From 082a5a8fad3ee5d7cf23f6fc4d79cab4b2843e03 Mon Sep 17 00:00:00 2001 From: clssck Date: Fri, 5 Dec 2025 23:13:39 +0100 Subject: [PATCH] test(lightrag,api): add comprehensive test coverage and S3 support Add extensive test suites for API routes and utilities: - Implement test_search_routes.py (406 lines) for search endpoint validation - Implement test_upload_routes.py (724 lines) for document upload workflows - Implement test_s3_client.py (618 lines) for S3 storage operations - Implement test_citation_utils.py (352 lines) for citation extraction - Implement test_chunking.py (216 lines) for text chunking validation Add S3 storage client implementation: - Create lightrag/storage/s3_client.py with S3 operations - Add storage module initialization with exports - Integrate S3 client with document upload handling Enhance API routes and core functionality: - Add search_routes.py with full-text and graph search endpoints - Add upload_routes.py with multipart document upload support - Update operate.py with bulk operations and health checks - Enhance postgres_impl.py with bulk upsert and parameterized queries - Update lightrag_server.py to register new API routes - Improve utils.py with citation and formatting utilities Update dependencies and configuration: - Add S3 and test dependencies to pyproject.toml - Update docker-compose.test.yml for testing environment - Sync uv.lock with new dependencies Apply code quality improvements across all modified files: - Add type hints to function signatures - Update imports and router initialization - Fix logging and error handling --- docker-compose.test.yml | 30 + lightrag/api/lightrag_server.py | 55 +- lightrag/api/routers/__init__.py | 11 +- lightrag/api/routers/ollama_api.py | 4 +- lightrag/api/routers/query_routes.py | 52 +- lightrag/api/routers/search_routes.py | 155 ++++ lightrag/api/routers/upload_routes.py | 438 +++++++++++ lightrag/base.py | 13 +- lightrag/evaluation/e2e_test_harness.py | 1 + lightrag/kg/deprecated/chroma_impl.py | 5 +- lightrag/kg/postgres_impl.py | 212 ++++- lightrag/kg/redis_impl.py | 3 +- lightrag/llm/binding_options.py | 2 +- lightrag/llm/deprecated/siliconcloud.py | 3 +- lightrag/llm/llama_index_impl.py | 1 + lightrag/llm/nvidia_openai.py | 1 + lightrag/llm/zhipu.py | 4 +- lightrag/operate.py | 50 +- lightrag/storage/__init__.py | 5 + lightrag/storage/s3_client.py | 390 ++++++++++ .../lightrag_visualizer/graph_visualizer.py | 5 +- lightrag/utils.py | 73 +- pyproject.toml | 3 + reproduce/batch_eval.py | 1 - tests/test_chunking.py | 216 ++++++ tests/test_citation_utils.py | 352 +++++++++ tests/test_extraction_prompt_ab.py | 7 +- tests/test_prompt_accuracy.py | 1 - tests/test_prompt_quality_deep.py | 3 +- tests/test_s3_client.py | 618 +++++++++++++++ tests/test_search_routes.py | 406 ++++++++++ tests/test_upload_routes.py | 724 ++++++++++++++++++ uv.lock | 75 ++ 33 files changed, 3848 insertions(+), 71 deletions(-) create mode 100644 lightrag/api/routers/search_routes.py create mode 100644 lightrag/api/routers/upload_routes.py create mode 100644 lightrag/storage/__init__.py create mode 100644 lightrag/storage/s3_client.py create mode 100644 tests/test_citation_utils.py create mode 100644 tests/test_s3_client.py create mode 100644 tests/test_search_routes.py create mode 100644 tests/test_upload_routes.py diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 8286161f..c5eb215c 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -44,6 +44,26 @@ services: retries: 5 mem_limit: 2g + rustfs: + image: 
rustfs/rustfs:latest + container_name: rustfs-test + ports: + - "9000:9000" # S3 API + - "9001:9001" # Web console + environment: + RUSTFS_ACCESS_KEY: rustfsadmin + RUSTFS_SECRET_KEY: rustfsadmin + command: /data + volumes: + - rustfs_data:/data + healthcheck: + # RustFS returns AccessDenied for unauth requests, but that means it's alive + test: ["CMD-SHELL", "curl -s http://localhost:9000/ | grep -q 'AccessDenied' || curl -sf http://localhost:9000/"] + interval: 10s + timeout: 5s + retries: 5 + mem_limit: 512m + lightrag: container_name: lightrag-test build: @@ -118,9 +138,18 @@ services: - ORPHAN_CONNECTION_THRESHOLD=0.3 # Vector similarity pre-filter threshold - ORPHAN_CONFIDENCE_THRESHOLD=0.7 # LLM confidence required for connection - ORPHAN_CROSS_CONNECT=true # Allow orphan-to-orphan connections + + # S3/RustFS Configuration - Document staging and archival + - S3_ENDPOINT_URL=http://rustfs:9000 + - S3_ACCESS_KEY_ID=rustfsadmin + - S3_SECRET_ACCESS_KEY=rustfsadmin + - S3_BUCKET_NAME=lightrag + - S3_REGION=us-east-1 depends_on: postgres: condition: service_healthy + rustfs: + condition: service_healthy entrypoint: [] command: - python @@ -142,3 +171,4 @@ services: volumes: pgdata_test: + rustfs_data: diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index bb54e2a3..9a6f5b4c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -3,16 +3,18 @@ LightRAG FastAPI Server """ import argparse -from collections.abc import AsyncIterator import configparser -from contextlib import asynccontextmanager import logging import logging.config import os -from pathlib import Path import sys +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from pathlib import Path from typing import Annotated, Any, cast +import pipmaster as pm +import uvicorn from ascii_colors import ASCIIColors from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Request @@ -25,10 +27,9 @@ from fastapi.openapi.docs import ( from fastapi.responses import JSONResponse, RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi.staticfiles import StaticFiles -import pipmaster as pm -import uvicorn -from lightrag import LightRAG, __version__ as core_version +from lightrag import LightRAG +from lightrag import __version__ as core_version from lightrag.api import __api_version__ from lightrag.api.auth import auth_handler from lightrag.api.routers.document_routes import ( @@ -38,7 +39,9 @@ from lightrag.api.routers.document_routes import ( from lightrag.api.routers.graph_routes import create_graph_routes from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.api.routers.query_routes import create_query_routes +from lightrag.api.routers.search_routes import create_search_routes from lightrag.api.routers.table_routes import create_table_routes +from lightrag.api.routers.upload_routes import create_upload_routes from lightrag.api.utils_api import ( check_env_file, display_splash_screen, @@ -59,6 +62,7 @@ from lightrag.kg.shared_storage import ( get_default_workspace, get_namespace_data, ) +from lightrag.storage.s3_client import S3Client, S3Config from lightrag.types import GPTKeywordExtractionFormat from lightrag.utils import EmbeddingFunc, get_env_value, logger, set_verbose_debug @@ -318,6 +322,18 @@ def create_app(args): # Initialize document manager with workspace support for data isolation doc_manager = DocumentManager(args.input_dir, workspace=args.workspace) + # Initialize S3 
client if configured (for upload routes) + s3_client: S3Client | None = None + s3_endpoint_url = os.getenv('S3_ENDPOINT_URL', '') + if s3_endpoint_url: + try: + s3_config = S3Config(endpoint_url=s3_endpoint_url) + s3_client = S3Client(s3_config) + logger.info(f'S3 client configured for endpoint: {s3_endpoint_url}') + except ValueError as e: + logger.warning(f'S3 client not initialized: {e}') + s3_client = None + @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" @@ -329,6 +345,11 @@ def create_app(args): # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await rag.initialize_storages() + # Initialize S3 client if configured + if s3_client is not None: + await s3_client.initialize() + logger.info('S3 client initialized successfully') + # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -337,6 +358,11 @@ def create_app(args): yield finally: + # Finalize S3 client if initialized + if s3_client is not None: + await s3_client.finalize() + logger.info('S3 client finalized') + # Clean up database connections await rag.finalize_storages() @@ -1006,7 +1032,7 @@ def create_app(args): api_key, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k)) + app.include_router(create_query_routes(rag, api_key, args.top_k, s3_client)) app.include_router(create_graph_routes(rag, api_key)) # Register table routes if all storages are PostgreSQL @@ -1023,6 +1049,21 @@ def create_app(args): ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix='/api') + # Register upload routes if S3 is configured + if s3_client is not None: + app.include_router(create_upload_routes(rag, s3_client, api_key)) + logger.info('S3 upload routes registered at /upload') + else: + logger.info('S3 not configured - upload routes disabled') + + # Register BM25 search routes if PostgreSQL storage is configured + # Full-text search requires PostgreSQLDB for ts_rank queries + if args.kv_storage == 'PGKVStorage' and hasattr(rag, 'text_chunks') and hasattr(rag.text_chunks, 'db'): + app.include_router(create_search_routes(rag.text_chunks.db, api_key)) + logger.info('BM25 search routes registered at /search') + else: + logger.info('PostgreSQL not configured - BM25 search routes disabled') + # Custom Swagger UI endpoint for offline support @app.get('/docs', include_in_schema=False) async def custom_swagger_ui_html(): diff --git a/lightrag/api/routers/__init__.py b/lightrag/api/routers/__init__.py index 0beaebfd..4a685235 100644 --- a/lightrag/api/routers/__init__.py +++ b/lightrag/api/routers/__init__.py @@ -6,5 +6,14 @@ from .document_routes import router as document_router from .graph_routes import router as graph_router from .ollama_api import OllamaAPI from .query_routes import router as query_router +from .search_routes import create_search_routes +from .upload_routes import create_upload_routes -__all__ = ['OllamaAPI', 'document_router', 'graph_router', 'query_router'] +__all__ = [ + 'OllamaAPI', + 'document_router', + 'graph_router', + 'query_router', + 'create_search_routes', + 'create_upload_routes', +] diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index c4e412dd..13ea64c6 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -1,9 +1,9 @@ import asyncio -from collections.abc import AsyncIterator, Awaitable, Callable -from enum import Enum import json import 
re import time +from collections.abc import AsyncIterator, Awaitable, Callable +from enum import Enum from typing import Any, TypeVar, cast from fastapi import APIRouter, Depends, HTTPException, Request diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 38c11adf..7638c78a 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -285,6 +285,8 @@ class EnhancedReferenceItem(BaseModel): section_title: str | None = Field(default=None, description='Section or chapter title') page_range: str | None = Field(default=None, description='Page range (e.g., pp. 45-67)') excerpt: str | None = Field(default=None, description='Brief excerpt from the source') + s3_key: str | None = Field(default=None, description='S3 object key for source document') + presigned_url: str | None = Field(default=None, description='Presigned URL for direct access') async def _extract_and_stream_citations( @@ -294,6 +296,7 @@ async def _extract_and_stream_citations( rag, min_similarity: float, citation_mode: str, + s3_client=None, ): """Extract citations from response and yield NDJSON lines. @@ -305,10 +308,11 @@ async def _extract_and_stream_citations( Args: response: The full LLM response text chunks: List of chunk dictionaries from retrieval - references: List of reference dicts + references: List of reference dicts with s3_key and document metadata rag: The RAG instance (for embedding function) min_similarity: Minimum similarity threshold citation_mode: 'inline' or 'footnotes' + s3_client: Optional S3Client for generating presigned URLs Yields: NDJSON lines for citation metadata (no duplicate text) @@ -339,19 +343,31 @@ async def _extract_and_stream_citations( } ) - # Build enhanced sources with metadata + # Build enhanced sources with metadata and presigned URLs sources = [] for ref in citation_result.references: - sources.append( - { - 'reference_id': ref.reference_id, - 'file_path': ref.file_path, - 'document_title': ref.document_title, - 'section_title': ref.section_title, - 'page_range': ref.page_range, - 'excerpt': ref.excerpt, - } + source_item = { + 'reference_id': ref.reference_id, + 'file_path': ref.file_path, + 'document_title': ref.document_title, + 'section_title': ref.section_title, + 'page_range': ref.page_range, + 'excerpt': ref.excerpt, + 's3_key': getattr(ref, 's3_key', None), + 'presigned_url': None, + } + + # Generate presigned URL if S3 client is available and s3_key exists + s3_key = getattr(ref, 's3_key', None) or ( + ref.__dict__.get('s3_key') if hasattr(ref, '__dict__') else None ) + if s3_client and s3_key: + try: + source_item['presigned_url'] = await s3_client.get_presigned_url(s3_key) + except Exception as e: + logger.debug(f'Failed to generate presigned URL for {s3_key}: {e}') + + sources.append(source_item) # Format footnotes if requested footnotes = citation_result.footnotes if citation_mode == 'footnotes' else [] @@ -363,7 +379,7 @@ async def _extract_and_stream_citations( { 'citations_metadata': { 'markers': citation_markers, # Position-based markers for insertion - 'sources': sources, # Enhanced reference metadata + 'sources': sources, # Enhanced reference metadata with presigned URLs 'footnotes': footnotes, # Pre-formatted footnote strings 'uncited_count': len(citation_result.uncited_claims), } @@ -380,7 +396,15 @@ async def _extract_and_stream_citations( yield json.dumps({'citation_error': str(e)}) + '\n' -def create_query_routes(rag, api_key: str | None = None, top_k: int = DEFAULT_TOP_K): +def 
create_query_routes(rag, api_key: str | None = None, top_k: int = DEFAULT_TOP_K, s3_client=None): + """Create query routes with optional S3 client for presigned URL generation in citations. + + Args: + rag: LightRAG instance + api_key: Optional API key for authentication + top_k: Default top_k for retrieval + s3_client: Optional S3Client for generating presigned URLs in citation responses + """ combined_auth = get_combined_auth_dependency(api_key) @router.post( @@ -911,6 +935,7 @@ def create_query_routes(rag, api_key: str | None = None, top_k: int = DEFAULT_TO rag, request.citation_threshold or 0.7, citation_mode, + s3_client, ): yield line else: @@ -938,6 +963,7 @@ def create_query_routes(rag, api_key: str | None = None, top_k: int = DEFAULT_TO rag, request.citation_threshold or 0.7, citation_mode, + s3_client, ): yield line diff --git a/lightrag/api/routers/search_routes.py b/lightrag/api/routers/search_routes.py new file mode 100644 index 00000000..84921777 --- /dev/null +++ b/lightrag/api/routers/search_routes.py @@ -0,0 +1,155 @@ +""" +Search routes for BM25 full-text search. + +This module provides a direct keyword search endpoint that bypasses the LLM +and returns matching chunks directly. This is complementary to the /query +endpoint which uses semantic RAG. + +Use cases: +- /query (existing): Semantic RAG - "What causes X?" -> LLM-generated answer +- /search (this): Keyword search - "Find docs about X" -> Direct chunk results +""" + +from typing import Annotated, Any, ClassVar + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field + +from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.kg.postgres_impl import PostgreSQLDB +from lightrag.utils import logger + + +class SearchResult(BaseModel): + """A single search result (chunk match).""" + + id: str = Field(description='Chunk ID') + full_doc_id: str = Field(description='Parent document ID') + chunk_order_index: int = Field(description='Position in document') + tokens: int = Field(description='Token count') + content: str = Field(description='Chunk content') + file_path: str | None = Field(default=None, description='Source file path') + s3_key: str | None = Field(default=None, description='S3 key for source document') + char_start: int | None = Field(default=None, description='Character offset start in source document') + char_end: int | None = Field(default=None, description='Character offset end in source document') + score: float = Field(description='BM25 relevance score') + + +class SearchResponse(BaseModel): + """Response model for search endpoint.""" + + query: str = Field(description='Original search query') + results: list[SearchResult] = Field(description='Matching chunks') + count: int = Field(description='Number of results returned') + workspace: str = Field(description='Workspace searched') + + class Config: + json_schema_extra: ClassVar[dict[str, Any]] = { + 'example': { + 'query': 'machine learning algorithms', + 'results': [ + { + 'id': 'chunk-abc123', + 'full_doc_id': 'doc-xyz789', + 'chunk_order_index': 5, + 'tokens': 250, + 'content': 'Machine learning algorithms can be categorized into...', + 'file_path': 's3://lightrag/archive/default/doc-xyz789/report.pdf', + 's3_key': 'archive/default/doc-xyz789/report.pdf', + 'score': 0.85, + } + ], + 'count': 1, + 'workspace': 'default', + } + } + + +def create_search_routes( + db: PostgreSQLDB, + api_key: str | None = None, +) -> APIRouter: + """ + Create search routes for BM25 full-text search. 
+ + Args: + db: PostgreSQLDB instance for executing searches + api_key: Optional API key for authentication + + Returns: + FastAPI router with search endpoints + """ + router = APIRouter( + prefix='/search', + tags=['search'], + ) + + optional_api_key = get_combined_auth_dependency(api_key) + + @router.get( + '', + response_model=SearchResponse, + summary='BM25 keyword search', + description=""" + Perform BM25-style full-text search on document chunks. + + This endpoint provides direct keyword search without LLM processing. + It's faster than /query for simple keyword lookups and returns + matching chunks directly. + + The search uses PostgreSQL's native full-text search with ts_rank + for relevance scoring. + + **Use cases:** + - Quick keyword lookups + - Finding specific terms or phrases + - Browsing chunks containing specific content + - Export/citation workflows where you need exact matches + + **Differences from /query:** + - /query: Semantic search + LLM generation -> Natural language answer + - /search: Keyword search -> Direct chunk results (no LLM) + """, + ) + async def search( + q: Annotated[str, Query(description='Search query', min_length=1)], + limit: Annotated[int, Query(description='Max results to return', ge=1, le=100)] = 10, + workspace: Annotated[str, Query(description='Workspace to search')] = 'default', + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> SearchResponse: + """Perform BM25 full-text search on chunks.""" + try: + results = await db.full_text_search( + query=q, + workspace=workspace, + limit=limit, + ) + + search_results = [ + SearchResult( + id=r.get('id', ''), + full_doc_id=r.get('full_doc_id', ''), + chunk_order_index=r.get('chunk_order_index', 0), + tokens=r.get('tokens', 0), + content=r.get('content', ''), + file_path=r.get('file_path'), + s3_key=r.get('s3_key'), + char_start=r.get('char_start'), + char_end=r.get('char_end'), + score=float(r.get('score', 0)), + ) + for r in results + ] + + return SearchResponse( + query=q, + results=search_results, + count=len(search_results), + workspace=workspace, + ) + + except Exception as e: + logger.error(f'Search failed: {e}') + raise HTTPException(status_code=500, detail=f'Search failed: {e}') from e + + return router diff --git a/lightrag/api/routers/upload_routes.py b/lightrag/api/routers/upload_routes.py new file mode 100644 index 00000000..2d6fb20e --- /dev/null +++ b/lightrag/api/routers/upload_routes.py @@ -0,0 +1,438 @@ +""" +Upload routes for S3/RustFS document staging. 
+ +This module provides endpoints for: +- Uploading documents to S3 staging +- Listing staged documents +- Getting presigned URLs +""" + +import mimetypes +from typing import Annotated, Any, ClassVar + +from fastapi import ( + APIRouter, + Depends, + File, + Form, + HTTPException, + UploadFile, +) +from pydantic import BaseModel, Field + +from lightrag import LightRAG +from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.storage.s3_client import S3Client +from lightrag.utils import compute_mdhash_id, logger + + +class UploadResponse(BaseModel): + """Response model for document upload.""" + + status: str = Field(description='Upload status') + doc_id: str = Field(description='Document ID') + s3_key: str = Field(description='S3 object key') + s3_url: str = Field(description='S3 URL (s3://bucket/key)') + message: str | None = Field(default=None, description='Additional message') + + class Config: + json_schema_extra: ClassVar[dict[str, Any]] = { + 'example': { + 'status': 'uploaded', + 'doc_id': 'doc_abc123', + 's3_key': 'staging/default/doc_abc123/report.pdf', + 's3_url': 's3://lightrag/staging/default/doc_abc123/report.pdf', + 'message': 'Document staged for processing', + } + } + + +class StagedDocument(BaseModel): + """Model for a staged document.""" + + key: str = Field(description='S3 object key') + size: int = Field(description='File size in bytes') + last_modified: str = Field(description='Last modified timestamp') + + +class ListStagedResponse(BaseModel): + """Response model for listing staged documents.""" + + workspace: str = Field(description='Workspace name') + documents: list[StagedDocument] = Field(description='List of staged documents') + count: int = Field(description='Number of documents') + + +class PresignedUrlResponse(BaseModel): + """Response model for presigned URL.""" + + s3_key: str = Field(description='S3 object key') + presigned_url: str = Field(description='Presigned URL for direct access') + expiry_seconds: int = Field(description='URL expiry time in seconds') + + +class ProcessS3Request(BaseModel): + """Request model for processing a document from S3 staging.""" + + s3_key: str = Field(description='S3 key of the staged document') + doc_id: str | None = Field( + default=None, + description='Document ID (extracted from s3_key if not provided)', + ) + archive_after_processing: bool = Field( + default=True, + description='Move document to archive after successful processing', + ) + + class Config: + json_schema_extra: ClassVar[dict[str, Any]] = { + 'example': { + 's3_key': 'staging/default/doc_abc123/report.pdf', + 'doc_id': 'doc_abc123', + 'archive_after_processing': True, + } + } + + +class ProcessS3Response(BaseModel): + """Response model for S3 document processing.""" + + status: str = Field(description='Processing status') + track_id: str = Field(description='Track ID for monitoring processing progress') + doc_id: str = Field(description='Document ID') + s3_key: str = Field(description='Original S3 key') + archive_key: str | None = Field(default=None, description='Archive S3 key (if archived)') + message: str | None = Field(default=None, description='Additional message') + + class Config: + json_schema_extra: ClassVar[dict[str, Any]] = { + 'example': { + 'status': 'processing_started', + 'track_id': 'insert_20250101_120000_abc123', + 'doc_id': 'doc_abc123', + 's3_key': 'staging/default/doc_abc123/report.pdf', + 'archive_key': 'archive/default/doc_abc123/report.pdf', + 'message': 'Document processing started', + } + } + + +def 
create_upload_routes( + rag: LightRAG, + s3_client: S3Client, + api_key: str | None = None, +) -> APIRouter: + """ + Create upload routes for S3 document staging. + + Args: + rag: LightRAG instance + s3_client: Initialized S3Client instance + api_key: Optional API key for authentication + + Returns: + FastAPI router with upload endpoints + """ + router = APIRouter( + prefix='/upload', + tags=['upload'], + ) + + optional_api_key = get_combined_auth_dependency(api_key) + + @router.post( + '', + response_model=UploadResponse, + summary='Upload document to S3 staging', + description=""" + Upload a document to S3/RustFS staging area. + + The document will be staged at: s3://bucket/staging/{workspace}/{doc_id}/{filename} + + After upload, the document can be processed by calling the standard document + processing endpoints, which will: + 1. Fetch the document from S3 staging + 2. Process it through the RAG pipeline + 3. Move it to S3 archive + 4. Store processed data in PostgreSQL + """, + ) + async def upload_document( + file: Annotated[UploadFile, File(description='Document file to upload')], + workspace: Annotated[str, Form(description='Workspace name')] = 'default', + doc_id: Annotated[str | None, Form(description='Optional document ID (auto-generated if not provided)')] = None, + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> UploadResponse: + """Upload a document to S3 staging.""" + try: + # Read file content + content = await file.read() + + if not content: + raise HTTPException(status_code=400, detail='Empty file') + + # Generate doc_id if not provided + if not doc_id: + doc_id = compute_mdhash_id(content, prefix='doc_') + + # Determine content type + content_type = file.content_type + if not content_type: + content_type, _ = mimetypes.guess_type(file.filename or '') + content_type = content_type or 'application/octet-stream' + + # Upload to S3 staging + s3_key = await s3_client.upload_to_staging( + workspace=workspace, + doc_id=doc_id, + content=content, + filename=file.filename or f'{doc_id}.bin', + content_type=content_type, + metadata={ + 'original_size': str(len(content)), + 'content_type': content_type, + }, + ) + + s3_url = s3_client.get_s3_url(s3_key) + + logger.info(f'Document uploaded to staging: {s3_key}') + + return UploadResponse( + status='uploaded', + doc_id=doc_id, + s3_key=s3_key, + s3_url=s3_url, + message='Document staged for processing', + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f'Upload failed: {e}') + raise HTTPException(status_code=500, detail=f'Upload failed: {e}') from e + + @router.get( + '/staged', + response_model=ListStagedResponse, + summary='List staged documents', + description='List all documents in the staging area for a workspace.', + ) + async def list_staged( + workspace: str = 'default', + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> ListStagedResponse: + """List documents in staging.""" + try: + objects = await s3_client.list_staging(workspace) + + documents = [ + StagedDocument( + key=obj['key'], + size=obj['size'], + last_modified=obj['last_modified'], + ) + for obj in objects + ] + + return ListStagedResponse( + workspace=workspace, + documents=documents, + count=len(documents), + ) + + except Exception as e: + logger.error(f'Failed to list staged documents: {e}') + raise HTTPException(status_code=500, detail=f'Failed to list staged documents: {e}') from e + + @router.get( + '/presigned-url', + response_model=PresignedUrlResponse, + summary='Get presigned URL', + 
description='Generate a presigned URL for direct access to a document in S3.', + ) + async def get_presigned_url( + s3_key: str, + expiry: int = 3600, + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> PresignedUrlResponse: + """Get presigned URL for a document.""" + try: + # Verify object exists + if not await s3_client.object_exists(s3_key): + raise HTTPException(status_code=404, detail='Object not found') + + url = await s3_client.get_presigned_url(s3_key, expiry=expiry) + + return PresignedUrlResponse( + s3_key=s3_key, + presigned_url=url, + expiry_seconds=expiry, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f'Failed to generate presigned URL: {e}') + raise HTTPException(status_code=500, detail=f'Failed to generate presigned URL: {e}') from e + + @router.delete( + '/staged/{doc_id}', + summary='Delete staged document', + description='Delete a document from the staging area.', + ) + async def delete_staged( + doc_id: str, + workspace: str = 'default', + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> dict[str, str]: + """Delete a staged document.""" + try: + # List objects with this doc_id prefix + prefix = f'staging/{workspace}/{doc_id}/' + objects = await s3_client.list_staging(workspace) + + # Filter to this doc_id + to_delete = [obj['key'] for obj in objects if obj['key'].startswith(prefix)] + + if not to_delete: + raise HTTPException(status_code=404, detail='Document not found in staging') + + # Delete each object + for key in to_delete: + await s3_client.delete_object(key) + + return { + 'status': 'deleted', + 'doc_id': doc_id, + 'deleted_count': str(len(to_delete)), + } + + except HTTPException: + raise + except Exception as e: + logger.error(f'Failed to delete staged document: {e}') + raise HTTPException(status_code=500, detail=f'Failed to delete staged document: {e}') from e + + @router.post( + '/process', + response_model=ProcessS3Response, + summary='Process document from S3 staging', + description=""" + Fetch a document from S3 staging and process it through the RAG pipeline. + + This endpoint: + 1. Fetches the document content from S3 staging + 2. Processes it through the RAG pipeline (chunking, entity extraction, embedding) + 3. Stores processed data in PostgreSQL with s3_key reference + 4. 
Optionally moves the document from staging to archive + + The s3_key should be the full key returned from the upload endpoint, + e.g., "staging/default/doc_abc123/report.pdf" + """, + ) + async def process_from_s3( + request: ProcessS3Request, + _: Annotated[bool, Depends(optional_api_key)] = True, + ) -> ProcessS3Response: + """Process a staged document through the RAG pipeline.""" + try: + s3_key = request.s3_key + + # Verify object exists + if not await s3_client.object_exists(s3_key): + raise HTTPException( + status_code=404, + detail=f'Document not found in S3: {s3_key}', + ) + + # Fetch content from S3 + content_bytes, metadata = await s3_client.get_object(s3_key) + + # Extract doc_id from s3_key if not provided + # s3_key format: staging/{workspace}/{doc_id}/{filename} + doc_id = request.doc_id + if not doc_id: + parts = s3_key.split('/') + doc_id = parts[2] if len(parts) >= 3 else compute_mdhash_id(content_bytes, prefix='doc_') + + # Determine content type and decode appropriately + content_type = metadata.get('content_type', 'application/octet-stream') + s3_url = s3_client.get_s3_url(s3_key) + + # For text-based content, decode to string + if content_type.startswith('text/') or content_type in ( + 'application/json', + 'application/xml', + 'application/javascript', + ): + try: + text_content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + text_content = content_bytes.decode('latin-1') + else: + # For binary content (PDF, Word, etc.), we need document parsing + # For now, attempt UTF-8 decode or fail gracefully + try: + text_content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + raise HTTPException( + status_code=400, + detail=f'Cannot process binary content type: {content_type}. ' + 'Document parsing for PDF/Word not yet implemented.', + ) from None + + if not text_content.strip(): + raise HTTPException( + status_code=400, + detail='Document content is empty after decoding', + ) + + # Process through RAG pipeline + # Use s3_url as file_path for citation reference + logger.info(f'Processing S3 document: {s3_key} (doc_id: {doc_id})') + + track_id = await rag.ainsert( + input=text_content, + ids=doc_id, + file_paths=s3_url, + ) + + # Move to archive if requested + archive_key = None + if request.archive_after_processing: + try: + archive_key = await s3_client.move_to_archive(s3_key) + logger.info(f'Moved to archive: {s3_key} -> {archive_key}') + + # Update database chunks with archive s3_key + archive_url = s3_client.get_s3_url(archive_key) + updated_count = await rag.text_chunks.update_s3_key_by_doc_id( + full_doc_id=doc_id, + s3_key=archive_key, + archive_url=archive_url, + ) + logger.info(f'Updated {updated_count} chunks with archive s3_key: {archive_key}') + except Exception as e: + logger.warning(f'Failed to archive document: {e}') + # Don't fail the request, processing succeeded + + return ProcessS3Response( + status='processing_complete', + track_id=track_id, + doc_id=doc_id, + s3_key=s3_key, + archive_key=archive_key, + message='Document processed and stored in RAG pipeline', + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f'Failed to process S3 document: {e}') + raise HTTPException( + status_code=500, + detail=f'Failed to process S3 document: {e}', + ) from e + + return router diff --git a/lightrag/base.py b/lightrag/base.py index 04eb4171..d1b101ac 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -65,11 +65,22 @@ class OllamaServerInfos: return f'{self._lightrag_name}:{self._lightrag_tag}' -class 
TextChunkSchema(TypedDict): +class TextChunkSchema(TypedDict, total=False): + """Schema for text chunks with optional position metadata. + + Required fields: tokens, content, full_doc_id, chunk_order_index + Optional fields: file_path, s3_key, char_start, char_end + """ + tokens: int content: str full_doc_id: str chunk_order_index: int + # Optional fields for citation support + file_path: str | None + s3_key: str | None + char_start: int | None # Character offset start in source document + char_end: int | None # Character offset end in source document T = TypeVar('T') diff --git a/lightrag/evaluation/e2e_test_harness.py b/lightrag/evaluation/e2e_test_harness.py index 1447e912..2d3b2254 100644 --- a/lightrag/evaluation/e2e_test_harness.py +++ b/lightrag/evaluation/e2e_test_harness.py @@ -36,6 +36,7 @@ from pathlib import Path import httpx from dotenv import load_dotenv + from lightrag.utils import logger # Add parent directory to path diff --git a/lightrag/kg/deprecated/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py index 25b20419..d2a207a9 100644 --- a/lightrag/kg/deprecated/chroma_impl.py +++ b/lightrag/kg/deprecated/chroma_impl.py @@ -4,13 +4,12 @@ from dataclasses import dataclass from typing import Any, final import numpy as np +from chromadb import HttpClient, PersistentClient # type: ignore +from chromadb.config import Settings # type: ignore from lightrag.base import BaseVectorStorage from lightrag.utils import logger -from chromadb import HttpClient, PersistentClient # type: ignore -from chromadb.config import Settings # type: ignore - @final @dataclass diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 63ab0639..daaba491 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -835,6 +835,66 @@ class PostgreSQLDB: except Exception as e: logger.warning(f'Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}') + async def _migrate_add_s3_key_columns(self): + """Add s3_key column to LIGHTRAG_DOC_FULL and LIGHTRAG_DOC_CHUNKS tables if they don't exist""" + tables = [ + ('lightrag_doc_full', 'LIGHTRAG_DOC_FULL'), + ('lightrag_doc_chunks', 'LIGHTRAG_DOC_CHUNKS'), + ] + + for table_name_lower, table_name in tables: + try: + check_column_sql = f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{table_name_lower}' + AND column_name = 's3_key' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info(f'Adding s3_key column to {table_name} table') + add_column_sql = f""" + ALTER TABLE {table_name} + ADD COLUMN s3_key TEXT NULL + """ + await self.execute(add_column_sql) + logger.info(f'Successfully added s3_key column to {table_name} table') + else: + logger.info(f's3_key column already exists in {table_name} table') + except Exception as e: + logger.warning(f'Failed to add s3_key column to {table_name}: {e}') + + async def _migrate_add_chunk_position_columns(self): + """Add char_start and char_end columns to LIGHTRAG_DOC_CHUNKS table if they don't exist""" + columns = [ + ('char_start', 'INTEGER NULL'), + ('char_end', 'INTEGER NULL'), + ] + + for column_name, column_type in columns: + try: + check_column_sql = f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_doc_chunks' + AND column_name = '{column_name}' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info(f'Adding {column_name} column to LIGHTRAG_DOC_CHUNKS table') + add_column_sql = f""" + ALTER TABLE LIGHTRAG_DOC_CHUNKS 
+ ADD COLUMN {column_name} {column_type} + """ + await self.execute(add_column_sql) + logger.info(f'Successfully added {column_name} column to LIGHTRAG_DOC_CHUNKS table') + else: + logger.info(f'{column_name} column already exists in LIGHTRAG_DOC_CHUNKS table') + except Exception as e: + logger.warning(f'Failed to add {column_name} column to LIGHTRAG_DOC_CHUNKS: {e}') + async def _migrate_doc_status_add_track_id(self): """Add track_id column to LIGHTRAG_DOC_STATUS table if it doesn't exist and create index""" try: @@ -1114,9 +1174,16 @@ class PostgreSQLDB: ] # GIN indexes for array membership queries (chunk_ids lookups) + # and full-text search (BM25-style keyword search) gin_indexes = [ ('idx_lightrag_vdb_entity_chunk_ids_gin', 'LIGHTRAG_VDB_ENTITY', 'USING gin (chunk_ids)'), ('idx_lightrag_vdb_relation_chunk_ids_gin', 'LIGHTRAG_VDB_RELATION', 'USING gin (chunk_ids)'), + # Full-text search GIN index for BM25 keyword search on chunks + ( + 'idx_lightrag_doc_chunks_content_fts_gin', + 'LIGHTRAG_DOC_CHUNKS', + "USING gin (to_tsvector('english', content))", + ), ] # Create GIN indexes separately (different syntax) @@ -1197,6 +1264,18 @@ class PostgreSQLDB: 'Failed to migrate text chunks llm_cache_list field', ) + await self._run_migration( + self._migrate_add_s3_key_columns(), + 's3_key_columns', + 'Failed to add s3_key columns to doc tables', + ) + + await self._run_migration( + self._migrate_add_chunk_position_columns(), + 'chunk_position_columns', + 'Failed to add char_start/char_end columns to doc chunks table', + ) + await self._run_migration( self._migrate_field_lengths(), 'field_lengths', @@ -1617,6 +1696,55 @@ class PostgreSQLDB: logger.error(f'PostgreSQL executemany error: {e}, sql: {sql[:100]}...') raise + async def full_text_search( + self, + query: str, + workspace: str | None = None, + limit: int = 10, + language: str = 'english', + ) -> list[dict[str, Any]]: + """Perform BM25-style full-text search on document chunks. + + Uses PostgreSQL's native full-text search with ts_rank for ranking. 
+ + Args: + query: Search query string + workspace: Optional workspace filter (uses self.workspace if not provided) + limit: Maximum number of results to return + language: Language for text search configuration (default: english) + + Returns: + List of matching chunks with content, score, metadata, and s3_key + """ + ws = workspace or self.workspace + + # Use plainto_tsquery for simple query parsing (handles plain text queries) + # websearch_to_tsquery could be used for more advanced syntax + sql = f""" + SELECT + id, + full_doc_id, + chunk_order_index, + tokens, + content, + file_path, + s3_key, + char_start, + char_end, + ts_rank( + to_tsvector('{language}', content), + plainto_tsquery('{language}', $1) + ) AS score + FROM LIGHTRAG_DOC_CHUNKS + WHERE workspace = $2 + AND to_tsvector('{language}', content) @@ plainto_tsquery('{language}', $1) + ORDER BY score DESC + LIMIT $3 + """ + + results = await self.query(sql, [query, ws, limit], multirows=True) + return results if results else [] + class ClientManager: _instances: ClassVar[dict[str, Any]] = {'db': None, 'ref_count': 0} @@ -2105,6 +2233,9 @@ class PGKVStorage(BaseKVStorage): v['full_doc_id'], v['content'], v['file_path'], + v.get('s3_key'), # S3 key for document source + v.get('char_start'), # Character offset start in source document + v.get('char_end'), # Character offset end in source document json.dumps(v.get('llm_cache_list', [])), current_time, current_time, @@ -2121,6 +2252,7 @@ class PGKVStorage(BaseKVStorage): v['content'], v.get('file_path', ''), # Map file_path to doc_name self.workspace, + v.get('s3_key'), # S3 key for document source ) for k, v in data.items() ] @@ -2267,6 +2399,58 @@ class PGKVStorage(BaseKVStorage): except Exception as e: return {'status': 'error', 'message': str(e)} + async def update_s3_key_by_doc_id( + self, full_doc_id: str, s3_key: str, archive_url: str | None = None + ) -> int: + """Update s3_key for all chunks of a document after archiving. + + This method is called after a document is moved from S3 staging to archive, + to update the database chunks with the new archive location. 
+ + Args: + full_doc_id: Document ID to update + s3_key: Archive S3 key (e.g., 'archive/default/doc123/file.pdf') + archive_url: Optional full S3 URL to update file_path + + Returns: + Number of rows updated + """ + if archive_url: + # Update both s3_key and file_path + sql = """ + UPDATE LIGHTRAG_DOC_CHUNKS + SET s3_key = $1, file_path = $2, update_time = CURRENT_TIMESTAMP + WHERE workspace = $3 AND full_doc_id = $4 + """ + params = { + 's3_key': s3_key, + 'file_path': archive_url, + 'workspace': self.workspace, + 'full_doc_id': full_doc_id, + } + else: + # Update only s3_key + sql = """ + UPDATE LIGHTRAG_DOC_CHUNKS + SET s3_key = $1, update_time = CURRENT_TIMESTAMP + WHERE workspace = $2 AND full_doc_id = $3 + """ + params = { + 's3_key': s3_key, + 'workspace': self.workspace, + 'full_doc_id': full_doc_id, + } + + result = await self.db.execute(sql, params) + + # Parse the number of rows updated from result like "UPDATE 5" + try: + count = int(result.split()[-1]) if result else 0 + except (ValueError, AttributeError, IndexError): + count = 0 + logger.debug(f'[{self.workspace}] Updated {count} chunks with s3_key for doc {full_doc_id}') + return count + @final @dataclass @@ -5026,6 +5210,7 @@ TABLES = { doc_name VARCHAR(1024), content TEXT, meta JSONB, + s3_key TEXT NULL, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) @@ -5040,6 +5225,9 @@ TABLES = { tokens INTEGER, content TEXT, file_path TEXT NULL, + s3_key TEXT NULL, + char_start INTEGER NULL, + char_end INTEGER NULL, llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, @@ -5185,11 +5373,12 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage 'get_by_id_full_docs': """SELECT id, COALESCE(content, '') as content, - COALESCE(doc_name, '') as file_path + COALESCE(doc_name, '') as file_path, s3_key FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 """, 'get_by_id_text_chunks': """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path, + chunk_order_index, full_doc_id, file_path, s3_key, + char_start, char_end, COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time @@ -5201,11 +5390,12 @@ SQL_TEMPLATES = { FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 """, 'get_by_ids_full_docs': """SELECT id, COALESCE(content, '') as content, - COALESCE(doc_name, '') as file_path + COALESCE(doc_name, '') as file_path, s3_key FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2) """, 'get_by_ids_text_chunks': """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path, + chunk_order_index, full_doc_id, file_path, s3_key, + char_start, char_end, COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time @@ -5257,11 +5447,12 @@ SQL_TEMPLATES = { FROM LIGHTRAG_RELATION_CHUNKS WHERE workspace=$1 AND id = ANY($2) """, 'filter_keys': 'SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})', - 'upsert_doc_full': """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) - VALUES ($1, $2, $3, $4) + 'upsert_doc_full': """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace, s3_key) + VALUES ($1, $2, $3, $4, $5) ON 
CONFLICT (workspace,id) DO UPDATE SET content = $2, doc_name = $3, + s3_key = COALESCE($5, LIGHTRAG_DOC_FULL.s3_key), update_time = CURRENT_TIMESTAMP """, 'upsert_llm_response_cache': """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam) @@ -5275,15 +5466,18 @@ SQL_TEMPLATES = { update_time = CURRENT_TIMESTAMP """, 'upsert_text_chunk': """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, file_path, llm_cache_list, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + chunk_order_index, full_doc_id, content, file_path, s3_key, + char_start, char_end, llm_cache_list, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, full_doc_id=EXCLUDED.full_doc_id, content = EXCLUDED.content, file_path=EXCLUDED.file_path, + s3_key=COALESCE(EXCLUDED.s3_key, LIGHTRAG_DOC_CHUNKS.s3_key), + char_start=EXCLUDED.char_start, + char_end=EXCLUDED.char_end, llm_cache_list=EXCLUDED.llm_cache_list, update_time = EXCLUDED.update_time """, diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 59677f10..211143d8 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -115,7 +115,8 @@ class RedisConnectionManager: except RuntimeError: asyncio.run(cls.release_pool_async(redis_url)) return - loop.create_task(cls.release_pool_async(redis_url)) + task = loop.create_task(cls.release_pool_async(redis_url)) + return task @classmethod def close_all_pools(cls): diff --git a/lightrag/llm/binding_options.py b/lightrag/llm/binding_options.py index 7c01abb7..e10b63ff 100644 --- a/lightrag/llm/binding_options.py +++ b/lightrag/llm/binding_options.py @@ -285,7 +285,7 @@ class BindingOptions: sample_stream.write(f'# {arg_item["help"]}\n') # Handle JSON formatting for list and dict types - if arg_item['type'] == list[str] or arg_item['type'] == dict: + if arg_item['type'] is list[str] or arg_item['type'] is dict: default_value = json.dumps(arg_item['default']) else: default_value = arg_item['default'] diff --git a/lightrag/llm/deprecated/siliconcloud.py b/lightrag/llm/deprecated/siliconcloud.py index 9623fd0a..c1b5afa8 100644 --- a/lightrag/llm/deprecated/siliconcloud.py +++ b/lightrag/llm/deprecated/siliconcloud.py @@ -9,7 +9,6 @@ import struct import aiohttp import numpy as np -from lightrag.utils import logger from openai import ( APIConnectionError, APITimeoutError, @@ -22,6 +21,8 @@ from tenacity import ( wait_exponential, ) +from lightrag.utils import logger + @retry( stop=stop_after_attempt(3), diff --git a/lightrag/llm/llama_index_impl.py b/lightrag/llm/llama_index_impl.py index 502b7e3d..f83f896d 100644 --- a/lightrag/llm/llama_index_impl.py +++ b/lightrag/llm/llama_index_impl.py @@ -1,4 +1,5 @@ from typing import Any + import pipmaster as pm from llama_index.core.llms import ( ChatMessage, diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py index 877f5476..e1e60c0b 100644 --- a/lightrag/llm/nvidia_openai.py +++ b/lightrag/llm/nvidia_openai.py @@ -7,6 +7,7 @@ if not pm.is_installed('openai'): pm.install('openai') from typing import Literal + import numpy as np from openai import ( APIConnectionError, diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py index 9cc05e62..6dff14e8 100644 --- a/lightrag/llm/zhipu.py +++ b/lightrag/llm/zhipu.py @@ -1,10 +1,10 @@ import json import re -from 
lightrag.utils import verbose_debug - import pipmaster as pm # Pipmaster for dynamic library install +from lightrag.utils import verbose_debug + # install specific modules if not pm.is_installed('zhipuai'): pm.install('zhipuai') diff --git a/lightrag/operate.py b/lightrag/operate.py index 1f146466..781cfe96 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -7,10 +7,10 @@ import os import re import time from collections import Counter, defaultdict -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable from functools import partial from pathlib import Path -from typing import Any, Awaitable, Callable, Literal, cast, overload +from typing import Any, cast import json_repair from dotenv import load_dotenv @@ -31,9 +31,9 @@ from lightrag.constants import ( DEFAULT_KG_CHUNK_PICK_METHOD, DEFAULT_MAX_ENTITY_TOKENS, DEFAULT_MAX_FILE_PATHS, + DEFAULT_MAX_RELATION_TOKENS, DEFAULT_MAX_SOURCE_IDS_PER_ENTITY, DEFAULT_MAX_SOURCE_IDS_PER_RELATION, - DEFAULT_MAX_RELATION_TOKENS, DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_RELATED_CHUNK_NUMBER, DEFAULT_SOURCE_IDS_LIMIT_METHOD, @@ -231,11 +231,19 @@ def chunking_by_token_size( chunk_overlap_token_size: int = 100, chunk_token_size: int = 1200, ) -> list[dict[str, Any]]: + """Split content into chunks by token size, tracking character positions. + + Returns chunks with char_start and char_end for citation support. + """ tokens = tokenizer.encode(content) results: list[dict[str, Any]] = [] if split_by_character: raw_chunks = content.split(split_by_character) - new_chunks = [] + # Track character positions: (tokens, chunk_text, char_start, char_end) + new_chunks: list[tuple[int, str, int, int]] = [] + char_position = 0 + separator_len = len(split_by_character) + if split_by_character_only: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) @@ -250,32 +258,58 @@ def chunking_by_token_size( chunk_token_limit=chunk_token_size, chunk_preview=chunk[:120], ) - new_chunks.append((len(_tokens), chunk)) + chunk_start = char_position + chunk_end = char_position + len(chunk) + new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end)) + char_position = chunk_end + separator_len # Skip separator else: for chunk in raw_chunks: + chunk_start = char_position _tokens = tokenizer.encode(chunk) if len(_tokens) > chunk_token_size: + # Sub-chunking: approximate char positions within the chunk + sub_char_position = 0 for start in range(0, len(_tokens), chunk_token_size - chunk_overlap_token_size): chunk_content = tokenizer.decode(_tokens[start : start + chunk_token_size]) - new_chunks.append((min(chunk_token_size, len(_tokens) - start), chunk_content)) + # Approximate char position based on content length ratio + sub_start = chunk_start + sub_char_position + sub_end = sub_start + len(chunk_content) + new_chunks.append( + (min(chunk_token_size, len(_tokens) - start), chunk_content, sub_start, sub_end) + ) + sub_char_position += len(chunk_content) - (chunk_overlap_token_size * 4) # Approx overlap else: - new_chunks.append((len(_tokens), chunk)) - for index, (_len, chunk) in enumerate(new_chunks): + chunk_end = chunk_start + len(chunk) + new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end)) + char_position = chunk_start + len(chunk) + separator_len + + for index, (_len, chunk, char_start, char_end) in enumerate(new_chunks): results.append( { 'tokens': _len, 'content': chunk.strip(), 'chunk_order_index': index, + 'char_start': char_start, + 'char_end': char_end, } ) else: + # Token-based chunking: track 
character positions through decoded content + char_position = 0 for index, start in enumerate(range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)): chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size]) + # For overlapping chunks, approximate positions based on previous chunk + char_start = 0 if index == 0 else char_position + char_end = char_start + len(chunk_content) + char_position = char_start + len(chunk_content) - (chunk_overlap_token_size * 4) # Approx char overlap + results.append( { 'tokens': min(chunk_token_size, len(tokens) - start), 'content': chunk_content.strip(), 'chunk_order_index': index, + 'char_start': char_start, + 'char_end': char_end, } ) return results diff --git a/lightrag/storage/__init__.py b/lightrag/storage/__init__.py new file mode 100644 index 00000000..9482bdc7 --- /dev/null +++ b/lightrag/storage/__init__.py @@ -0,0 +1,5 @@ +"""Storage module for S3/object storage integration.""" + +from lightrag.storage.s3_client import S3Client, S3Config + +__all__ = ["S3Client", "S3Config"] diff --git a/lightrag/storage/s3_client.py b/lightrag/storage/s3_client.py new file mode 100644 index 00000000..8ca738ed --- /dev/null +++ b/lightrag/storage/s3_client.py @@ -0,0 +1,390 @@ +""" +Async S3 client wrapper for RustFS/MinIO/AWS S3 compatible object storage. + +This module provides staging and archive functionality for documents: +- Upload to staging: s3://bucket/staging/{workspace}/{doc_id} +- Move to archive: s3://bucket/archive/{workspace}/{doc_id} +- Generate presigned URLs for citations +""" + +import hashlib +import logging +import os +import threading +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, ClassVar + +import pipmaster as pm + +if not pm.is_installed("aioboto3"): + pm.install("aioboto3") + +import aioboto3 +from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from lightrag.utils import logger + +# Constants with environment variable support +S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "") +S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID", "") +S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY", "") +S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "lightrag") +S3_REGION = os.getenv("S3_REGION", "us-east-1") +S3_RETRY_ATTEMPTS = int(os.getenv("S3_RETRY_ATTEMPTS", "3")) +S3_CONNECT_TIMEOUT = int(os.getenv("S3_CONNECT_TIMEOUT", "10")) +S3_READ_TIMEOUT = int(os.getenv("S3_READ_TIMEOUT", "30")) +S3_PRESIGNED_URL_EXPIRY = int(os.getenv("S3_PRESIGNED_URL_EXPIRY", "3600")) # 1 hour + + +# Retry decorator for S3 operations +s3_retry = retry( + stop=stop_after_attempt(S3_RETRY_ATTEMPTS), + wait=wait_exponential(multiplier=1, min=1, max=8), + retry=retry_if_exception_type(ClientError), + before_sleep=before_sleep_log(logger, logging.WARNING), +) + + +@dataclass +class S3Config: + """Configuration for S3 client.""" + + endpoint_url: str = field(default_factory=lambda: S3_ENDPOINT_URL) + access_key_id: str = field(default_factory=lambda: S3_ACCESS_KEY_ID) + secret_access_key: str = field(default_factory=lambda: S3_SECRET_ACCESS_KEY) + bucket_name: str = field(default_factory=lambda: S3_BUCKET_NAME) + region: str = field(default_factory=lambda: S3_REGION) + connect_timeout: int = field(default_factory=lambda: S3_CONNECT_TIMEOUT) + read_timeout: int = field(default_factory=lambda: S3_READ_TIMEOUT) + 
presigned_url_expiry: int = field(default_factory=lambda: S3_PRESIGNED_URL_EXPIRY) + + def __post_init__(self): + if not self.access_key_id or not self.secret_access_key: + raise ValueError( + "S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY must be set" + ) + + +class S3ClientManager: + """Shared S3 session manager to avoid creating multiple sessions.""" + + _sessions: ClassVar[dict[str, aioboto3.Session]] = {} + _session_refs: ClassVar[dict[str, int]] = {} + _lock: ClassVar[threading.Lock] = threading.Lock() + + @classmethod + def get_session(cls, config: S3Config) -> aioboto3.Session: + """Get or create a session for the given S3 config.""" + # Use endpoint + access_key as session key + session_key = f"{config.endpoint_url}:{config.access_key_id}" + + with cls._lock: + if session_key not in cls._sessions: + cls._sessions[session_key] = aioboto3.Session( + aws_access_key_id=config.access_key_id, + aws_secret_access_key=config.secret_access_key, + region_name=config.region, + ) + cls._session_refs[session_key] = 0 + logger.info(f"Created shared S3 session for {config.endpoint_url}") + + cls._session_refs[session_key] += 1 + logger.debug( + f"S3 session {session_key} reference count: {cls._session_refs[session_key]}" + ) + + return cls._sessions[session_key] + + @classmethod + def release_session(cls, config: S3Config): + """Release a reference to the session.""" + session_key = f"{config.endpoint_url}:{config.access_key_id}" + + with cls._lock: + if session_key in cls._session_refs: + cls._session_refs[session_key] -= 1 + logger.debug( + f"S3 session {session_key} reference count: {cls._session_refs[session_key]}" + ) + + +@dataclass +class S3Client: + """ + Async S3 client for document staging and archival. + + Usage: + config = S3Config() + client = S3Client(config) + await client.initialize() + + # Upload to staging + s3_key = await client.upload_to_staging(workspace, doc_id, content, filename) + + # Move to archive after processing + archive_key = await client.move_to_archive(s3_key) + + # Get presigned URL for citations + url = await client.get_presigned_url(archive_key) + + await client.finalize() + """ + + config: S3Config + _session: aioboto3.Session = field(default=None, init=False, repr=False) + _initialized: bool = field(default=False, init=False, repr=False) + + async def initialize(self): + """Initialize the S3 client.""" + if self._initialized: + return + + self._session = S3ClientManager.get_session(self.config) + + # Ensure bucket exists + await self._ensure_bucket_exists() + + self._initialized = True + logger.info(f"S3 client initialized for bucket: {self.config.bucket_name}") + + async def finalize(self): + """Release resources.""" + if self._initialized: + S3ClientManager.release_session(self.config) + self._initialized = False + logger.info("S3 client finalized") + + @asynccontextmanager + async def _get_client(self): + """Get an S3 client from the session.""" + boto_config = BotoConfig( + connect_timeout=self.config.connect_timeout, + read_timeout=self.config.read_timeout, + retries={"max_attempts": S3_RETRY_ATTEMPTS}, + ) + + async with self._session.client( + "s3", + endpoint_url=self.config.endpoint_url if self.config.endpoint_url else None, + config=boto_config, + ) as client: + yield client + + async def _ensure_bucket_exists(self): + """Create bucket if it doesn't exist.""" + async with self._get_client() as client: + try: + await client.head_bucket(Bucket=self.config.bucket_name) + logger.debug(f"Bucket {self.config.bucket_name} exists") + except ClientError as e: + 
error_code = e.response.get("Error", {}).get("Code", "") + if error_code in ("404", "NoSuchBucket"): + logger.info(f"Creating bucket: {self.config.bucket_name}") + await client.create_bucket(Bucket=self.config.bucket_name) + else: + raise + + def _make_staging_key(self, workspace: str, doc_id: str, filename: str) -> str: + """Generate S3 key for staging area.""" + safe_filename = filename.replace("/", "_").replace("\\", "_") + return f"staging/{workspace}/{doc_id}/{safe_filename}" + + def _make_archive_key(self, workspace: str, doc_id: str, filename: str) -> str: + """Generate S3 key for archive area.""" + safe_filename = filename.replace("/", "_").replace("\\", "_") + return f"archive/{workspace}/{doc_id}/{safe_filename}" + + def _staging_to_archive_key(self, staging_key: str) -> str: + """Convert staging key to archive key.""" + if staging_key.startswith("staging/"): + return "archive/" + staging_key[8:] + return staging_key + + @s3_retry + async def upload_to_staging( + self, + workspace: str, + doc_id: str, + content: bytes | str, + filename: str, + content_type: str = "application/octet-stream", + metadata: dict[str, str] | None = None, + ) -> str: + """ + Upload document to staging area. + + Args: + workspace: Workspace/tenant identifier + doc_id: Document ID + content: File content (bytes or string) + filename: Original filename + content_type: MIME type + metadata: Optional metadata dict + + Returns: + S3 key for the uploaded object + """ + s3_key = self._make_staging_key(workspace, doc_id, filename) + + if isinstance(content, str): + content = content.encode("utf-8") + + # Calculate content hash for deduplication + content_hash = hashlib.sha256(content).hexdigest() + + upload_metadata = { + "workspace": workspace, + "doc_id": doc_id, + "original_filename": filename, + "content_hash": content_hash, + **(metadata or {}), + } + + async with self._get_client() as client: + await client.put_object( + Bucket=self.config.bucket_name, + Key=s3_key, + Body=content, + ContentType=content_type, + Metadata=upload_metadata, + ) + + logger.info(f"Uploaded to staging: {s3_key} ({len(content)} bytes)") + return s3_key + + @s3_retry + async def get_object(self, s3_key: str) -> tuple[bytes, dict[str, Any]]: + """ + Get object content and metadata. + + Returns: + Tuple of (content_bytes, metadata_dict) + """ + async with self._get_client() as client: + response = await client.get_object( + Bucket=self.config.bucket_name, + Key=s3_key, + ) + content = await response["Body"].read() + metadata = response.get("Metadata", {}) + + logger.debug(f"Retrieved object: {s3_key} ({len(content)} bytes)") + return content, metadata + + @s3_retry + async def move_to_archive(self, staging_key: str) -> str: + """ + Move object from staging to archive. 
+ + Args: + staging_key: Current S3 key in staging/ + + Returns: + New S3 key in archive/ + """ + archive_key = self._staging_to_archive_key(staging_key) + + async with self._get_client() as client: + # Copy to archive + await client.copy_object( + Bucket=self.config.bucket_name, + CopySource={"Bucket": self.config.bucket_name, "Key": staging_key}, + Key=archive_key, + ) + + # Delete from staging + await client.delete_object( + Bucket=self.config.bucket_name, + Key=staging_key, + ) + + logger.info(f"Moved to archive: {staging_key} -> {archive_key}") + return archive_key + + @s3_retry + async def delete_object(self, s3_key: str): + """Delete an object.""" + async with self._get_client() as client: + await client.delete_object( + Bucket=self.config.bucket_name, + Key=s3_key, + ) + logger.info(f"Deleted object: {s3_key}") + + @s3_retry + async def list_staging(self, workspace: str) -> list[dict[str, Any]]: + """ + List all objects in staging for a workspace. + + Returns: + List of dicts with key, size, last_modified + """ + prefix = f"staging/{workspace}/" + objects = [] + + async with self._get_client() as client: + paginator = client.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.config.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + objects.append( + { + "key": obj["Key"], + "size": obj["Size"], + "last_modified": obj["LastModified"].isoformat(), + } + ) + + return objects + + async def get_presigned_url( + self, s3_key: str, expiry: int | None = None + ) -> str: + """ + Generate a presigned URL for direct access. + + Args: + s3_key: S3 object key + expiry: URL expiry in seconds (default from config) + + Returns: + Presigned URL string + """ + expiry = expiry or self.config.presigned_url_expiry + + async with self._get_client() as client: + url = await client.generate_presigned_url( + "get_object", + Params={"Bucket": self.config.bucket_name, "Key": s3_key}, + ExpiresIn=expiry, + ) + + return url + + async def object_exists(self, s3_key: str) -> bool: + """Check if an object exists.""" + async with self._get_client() as client: + try: + await client.head_object( + Bucket=self.config.bucket_name, + Key=s3_key, + ) + return True + except ClientError as e: + if e.response.get("Error", {}).get("Code") == "404": + return False + raise + + def get_s3_url(self, s3_key: str) -> str: + """Get the S3 URL for an object (not presigned, for reference).""" + return f"s3://{self.config.bucket_name}/{s3_key}" diff --git a/lightrag/tools/lightrag_visualizer/graph_visualizer.py b/lightrag/tools/lightrag_visualizer/graph_visualizer.py index 433c5ed5..8fe2053b 100644 --- a/lightrag/tools/lightrag_visualizer/graph_visualizer.py +++ b/lightrag/tools/lightrag_visualizer/graph_visualizer.py @@ -1,6 +1,3 @@ -import networkx as nx -import numpy as np - import colorsys import os import tkinter as tk @@ -10,6 +7,8 @@ from tkinter import filedialog import community import glm import moderngl +import networkx as nx +import numpy as np from imgui_bundle import hello_imgui, imgui, immapp CUSTOM_FONT = 'font.ttf' diff --git a/lightrag/utils.py b/lightrag/utils.py index bf9ee415..d3040152 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -13,7 +13,7 @@ import sys import time import uuid import weakref -from collections.abc import Callable, Collection, Iterable, Sequence +from collections.abc import Awaitable, Callable, Collection, Iterable, Sequence from dataclasses import dataclass from datetime import datetime from functools import wraps @@ -21,7 
+21,6 @@ from hashlib import md5 from typing import ( TYPE_CHECKING, Any, - Awaitable, Protocol, cast, ) @@ -3044,6 +3043,33 @@ def convert_to_user_format( } +def _extract_document_title(file_path: str) -> str: + """Extract document title from file_path (just the filename without directory).""" + if not file_path: + return '' + # Handle S3 URLs (s3://bucket/path/file.pdf) and regular paths + if file_path.startswith('s3://'): + # Remove s3://bucket/ prefix and get filename + parts = file_path.split('/') + return parts[-1] if parts else '' + # Handle regular file paths + import os + + return os.path.basename(file_path) + + +def _generate_excerpt(content: str, max_length: int = 150) -> str: + """Generate a brief excerpt from content.""" + if not content: + return '' + # Strip whitespace and truncate + excerpt = content.strip()[:max_length] + # Add ellipsis if truncated + if len(content.strip()) > max_length: + excerpt = excerpt.rstrip() + '...' + return excerpt + + def generate_reference_list_from_chunks( chunks: list[dict], ) -> tuple[list[dict], list[dict]]: @@ -3052,30 +3078,44 @@ def generate_reference_list_from_chunks( This function extracts file_paths from chunks, counts their occurrences, sorts by frequency and first appearance order, creates reference_id mappings, - and builds a reference_list structure. + and builds a reference_list structure with enhanced metadata. + + Enhanced fields include: + - document_title: Human-readable filename extracted from file_path + - s3_key: S3 object key if available + - excerpt: First 150 chars from the first chunk of each document Args: - chunks: List of chunk dictionaries with file_path information + chunks: List of chunk dictionaries with file_path, s3_key, content, + and optional char_start/char_end information Returns: tuple: (reference_list, updated_chunks_with_reference_ids) - - reference_list: List of dicts with reference_id and file_path - - updated_chunks_with_reference_ids: Original chunks with reference_id field added + - reference_list: List of dicts with reference_id, file_path, and enhanced fields + - updated_chunks_with_reference_ids: Original chunks with reference_id and excerpt fields """ if not chunks: return [], [] # 1. Extract all valid file_paths and count their occurrences - file_path_counts = {} + # Also collect s3_key and first chunk content for each file_path + file_path_counts: dict[str, int] = {} + file_path_metadata: dict[str, dict] = {} # file_path -> {s3_key, first_excerpt} + for chunk in chunks: file_path = chunk.get('file_path', '') if file_path and file_path != 'unknown_source': file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1 + # Collect metadata from first chunk for each file_path + if file_path not in file_path_metadata: + file_path_metadata[file_path] = { + 's3_key': chunk.get('s3_key'), + 'first_excerpt': _generate_excerpt(chunk.get('content', '')), + } # 2. Sort file paths by frequency (descending), then by first appearance order - # Create a list of (file_path, count, first_index) tuples file_path_with_indices = [] - seen_paths = set() + seen_paths: set[str] = set() for i, chunk in enumerate(chunks): file_path = chunk.get('file_path', '') if file_path and file_path != 'unknown_source' and file_path not in seen_paths: @@ -3091,7 +3131,7 @@ def generate_reference_list_from_chunks( for i, file_path in enumerate(unique_file_paths): file_path_to_ref_id[file_path] = str(i + 1) - # 4. Add reference_id field to each chunk + # 4. 
Add reference_id and excerpt fields to each chunk updated_chunks = [] for chunk in chunks: chunk_copy = chunk.copy() @@ -3100,11 +3140,20 @@ def generate_reference_list_from_chunks( chunk_copy['reference_id'] = file_path_to_ref_id[file_path] else: chunk_copy['reference_id'] = '' + # Add excerpt from this chunk's content + chunk_copy['excerpt'] = _generate_excerpt(chunk_copy.get('content', '')) updated_chunks.append(chunk_copy) - # 5. Build reference_list + # 5. Build enhanced reference_list reference_list = [] for i, file_path in enumerate(unique_file_paths): - reference_list.append({'reference_id': str(i + 1), 'file_path': file_path}) + metadata = file_path_metadata.get(file_path, {}) + reference_list.append({ + 'reference_id': str(i + 1), + 'file_path': file_path, + 'document_title': _extract_document_title(file_path), + 's3_key': metadata.get('s3_key'), + 'excerpt': metadata.get('first_excerpt', ''), + }) return reference_list, updated_chunks diff --git a/pyproject.toml b/pyproject.toml index da493909..ec024c82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,6 +136,9 @@ test = [ "pytest-asyncio>=1.2.0", "pre-commit", "ruff", + "moto[s3]>=5.0", # AWS S3 mocking for S3 client tests + "httpx>=0.27", # Async HTTP client for FastAPI testing + "aioboto3>=12.0.0,<16.0.0", # Required for S3 client tests ] # Type-checking/lint extras diff --git a/reproduce/batch_eval.py b/reproduce/batch_eval.py index f7e1c2e8..d9d8fc26 100644 --- a/reproduce/batch_eval.py +++ b/reproduce/batch_eval.py @@ -1,6 +1,5 @@ import json import logging -import os import re from pathlib import Path diff --git a/tests/test_chunking.py b/tests/test_chunking.py index 57e7158b..cdaac9aa 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -1056,3 +1056,219 @@ def test_decode_preserves_content(): tokens = tokenizer.encode(original) decoded = tokenizer.decode(tokens) assert decoded == original, f'Failed to decode: {original}' + + +# ============================================================================ +# Character Position Tests (char_start, char_end for citations) +# ============================================================================ + + +@pytest.mark.offline +def test_char_positions_present(): + """Verify char_start and char_end are present in all chunks.""" + tokenizer = make_tokenizer() + content = 'alpha\n\nbeta\n\ngamma' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=20, + ) + + for chunk in chunks: + assert 'char_start' in chunk, 'char_start field missing' + assert 'char_end' in chunk, 'char_end field missing' + assert isinstance(chunk['char_start'], int), 'char_start should be int' + assert isinstance(chunk['char_end'], int), 'char_end should be int' + + +@pytest.mark.offline +def test_char_positions_basic_delimiter_split(): + """Test char_start/char_end with basic delimiter splitting.""" + tokenizer = make_tokenizer() + # "alpha\n\nbeta" = positions: alpha at 0-5, beta at 7-11 + content = 'alpha\n\nbeta' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=20, + ) + + assert len(chunks) == 2 + # First chunk "alpha" starts at 0 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == 5 + # Second chunk "beta" starts after "\n\n" (position 7) + assert chunks[1]['char_start'] == 7 + assert chunks[1]['char_end'] == 11 + + +@pytest.mark.offline +def test_char_positions_single_chunk(): + 
"""Test char_start/char_end for content that fits in single chunk.""" + tokenizer = make_tokenizer() + content = 'hello world' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=50, + ) + + assert len(chunks) == 1 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == len(content) + + +@pytest.mark.offline +def test_char_positions_token_based_no_overlap(): + """Test char_start/char_end with token-based chunking, no overlap.""" + tokenizer = make_tokenizer() + content = '0123456789abcdefghij' # 20 chars + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character=None, + split_by_character_only=False, + chunk_token_size=10, + chunk_overlap_token_size=0, + ) + + assert len(chunks) == 2 + # First chunk: chars 0-10 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == 10 + # Second chunk: chars 10-20 + assert chunks[1]['char_start'] == 10 + assert chunks[1]['char_end'] == 20 + + +@pytest.mark.offline +def test_char_positions_consecutive_delimiters(): + """Test char positions with multiple delimiter-separated chunks.""" + tokenizer = make_tokenizer() + # "a||b||c" with delimiter "||" + content = 'first||second||third' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='||', + split_by_character_only=True, + chunk_token_size=50, + ) + + assert len(chunks) == 3 + # "first" at 0-5 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == 5 + # "second" at 7-13 (after "||") + assert chunks[1]['char_start'] == 7 + assert chunks[1]['char_end'] == 13 + # "third" at 15-20 + assert chunks[2]['char_start'] == 15 + assert chunks[2]['char_end'] == 20 + + +@pytest.mark.offline +def test_char_positions_unicode(): + """Test char_start/char_end with unicode content.""" + tokenizer = make_tokenizer() + # Unicode characters should count as individual chars + content = '日本語\n\nテスト' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=50, + ) + + assert len(chunks) == 2 + # "日本語" = 3 characters, positions 0-3 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == 3 + # "テスト" starts at position 5 (after \n\n) + assert chunks[1]['char_start'] == 5 + assert chunks[1]['char_end'] == 8 + + +@pytest.mark.offline +def test_char_positions_empty_content(): + """Test char_start/char_end with empty content.""" + tokenizer = make_tokenizer() + + chunks = chunking_by_token_size( + tokenizer, + '', + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=10, + ) + + assert len(chunks) == 1 + assert chunks[0]['char_start'] == 0 + assert chunks[0]['char_end'] == 0 + + +@pytest.mark.offline +def test_char_positions_verify_content_match(): + """Verify that char_start/char_end can be used to extract original content.""" + tokenizer = make_tokenizer() + content = 'The quick\n\nbrown fox\n\njumps over' + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character='\n\n', + split_by_character_only=True, + chunk_token_size=50, + ) + + for chunk in chunks: + # Extract using char positions and compare (stripping whitespace) + extracted = content[chunk['char_start'] : chunk['char_end']].strip() + assert extracted == chunk['content'], f"Mismatch: '{extracted}' != '{chunk['content']}'" + + +@pytest.mark.offline +def test_char_positions_with_overlap_approximation(): + """Test char positions with overlapping 
chunks (positions are approximate). + + Note: The overlap approximation uses `chunk_overlap_token_size * 4` to estimate + character overlap. This can result in negative char_start for later chunks + when overlap is large relative to chunk size. This is expected behavior + for the approximation algorithm. + """ + tokenizer = make_tokenizer() + content = '0123456789abcdefghij' # 20 chars + + chunks = chunking_by_token_size( + tokenizer, + content, + split_by_character=None, + split_by_character_only=False, + chunk_token_size=10, + chunk_overlap_token_size=3, + ) + + # With overlap=3, step=7: chunks at 0, 7, 14 + assert len(chunks) == 3 + # First chunk always starts at 0 + assert chunks[0]['char_start'] == 0 + # char_start and char_end are integers (approximate positions) + for chunk in chunks: + assert isinstance(chunk['char_start'], int) + assert isinstance(chunk['char_end'], int) + # char_end should always be greater than char_start + for chunk in chunks: + assert chunk['char_end'] > chunk['char_start'], f"char_end ({chunk['char_end']}) should be > char_start ({chunk['char_start']})" diff --git a/tests/test_citation_utils.py b/tests/test_citation_utils.py new file mode 100644 index 00000000..486c016c --- /dev/null +++ b/tests/test_citation_utils.py @@ -0,0 +1,352 @@ +"""Tests for citation utility functions in lightrag/utils.py. + +This module tests the helper functions used for generating citations +and reference lists from document chunks. +""" + +import pytest + +from lightrag.utils import ( + _extract_document_title, + _generate_excerpt, + generate_reference_list_from_chunks, +) + + +# ============================================================================ +# Tests for _extract_document_title() +# ============================================================================ + + +class TestExtractDocumentTitle: + """Tests for _extract_document_title function.""" + + @pytest.mark.offline + def test_regular_path(self): + """Test extracting title from regular file path.""" + assert _extract_document_title('/path/to/document.pdf') == 'document.pdf' + + @pytest.mark.offline + def test_nested_path(self): + """Test extracting title from deeply nested path.""" + assert _extract_document_title('/a/b/c/d/e/report.docx') == 'report.docx' + + @pytest.mark.offline + def test_s3_path(self): + """Test extracting title from S3 URL.""" + assert _extract_document_title('s3://bucket/archive/default/doc123/report.pdf') == 'report.pdf' + + @pytest.mark.offline + def test_s3_path_simple(self): + """Test extracting title from simple S3 URL.""" + assert _extract_document_title('s3://mybucket/file.txt') == 'file.txt' + + @pytest.mark.offline + def test_empty_string(self): + """Test with empty string returns empty.""" + assert _extract_document_title('') == '' + + @pytest.mark.offline + def test_trailing_slash(self): + """Test path with trailing slash returns empty.""" + assert _extract_document_title('/path/to/dir/') == '' + + @pytest.mark.offline + def test_filename_only(self): + """Test with just a filename (no path).""" + assert _extract_document_title('document.pdf') == 'document.pdf' + + @pytest.mark.offline + def test_no_extension(self): + """Test filename without extension.""" + assert _extract_document_title('/path/to/README') == 'README' + + @pytest.mark.offline + def test_windows_style_path(self): + """Test Windows-style path (backslashes).""" + # os.path.basename handles this correctly on Unix + result = _extract_document_title('C:\\Users\\docs\\file.pdf') + # On Unix, this returns the whole 
string as basename doesn't split on backslash + assert 'file.pdf' in result or result == 'C:\\Users\\docs\\file.pdf' + + @pytest.mark.offline + def test_special_characters(self): + """Test filename with special characters.""" + assert _extract_document_title('/path/to/my file (1).pdf') == 'my file (1).pdf' + + @pytest.mark.offline + def test_unicode_filename(self): + """Test filename with unicode characters.""" + assert _extract_document_title('/path/to/文档.pdf') == '文档.pdf' + + +# ============================================================================ +# Tests for _generate_excerpt() +# ============================================================================ + + +class TestGenerateExcerpt: + """Tests for _generate_excerpt function.""" + + @pytest.mark.offline + def test_short_content(self): + """Test content shorter than max_length.""" + assert _generate_excerpt('Hello world') == 'Hello world' + + @pytest.mark.offline + def test_exact_length(self): + """Test content exactly at max_length.""" + content = 'a' * 150 + result = _generate_excerpt(content, max_length=150) + assert result == content # No ellipsis for exact length + + @pytest.mark.offline + def test_long_content_truncated(self): + """Test long content is truncated with ellipsis.""" + content = 'a' * 200 + result = _generate_excerpt(content, max_length=150) + assert len(result) == 153 # 150 chars + '...' + assert result.endswith('...') + + @pytest.mark.offline + def test_empty_string(self): + """Test empty string returns empty.""" + assert _generate_excerpt('') == '' + + @pytest.mark.offline + def test_whitespace_stripped(self): + """Test leading/trailing whitespace is stripped.""" + assert _generate_excerpt(' hello world ') == 'hello world' + + @pytest.mark.offline + def test_whitespace_only(self): + """Test whitespace-only content returns empty.""" + assert _generate_excerpt(' \n\t ') == '' + + @pytest.mark.offline + def test_custom_max_length(self): + """Test custom max_length parameter.""" + content = 'This is a test sentence for excerpts.' + result = _generate_excerpt(content, max_length=10) + # Note: rstrip() removes trailing space before adding ellipsis + assert result == 'This is a...' + + @pytest.mark.offline + def test_unicode_content(self): + """Test unicode content handling.""" + content = '日本語テキスト' * 50 # 350 chars + result = _generate_excerpt(content, max_length=150) + assert len(result) == 153 # 150 chars + '...' + + @pytest.mark.offline + def test_newlines_preserved(self): + """Test that newlines within content are preserved.""" + content = 'Line 1\nLine 2' + result = _generate_excerpt(content) + assert result == 'Line 1\nLine 2' + + @pytest.mark.offline + def test_very_short_max_length(self): + """Test with very short max_length.""" + result = _generate_excerpt('Hello world', max_length=5) + assert result == 'Hello...' 
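
Editorial aside before the reference-list tests: a minimal usage sketch of the enhanced generate_reference_list_from_chunks(), assuming the patched lightrag.utils from this PR is importable. Only the field names come from the patch; the chunk values, bucket name, and filenames below are illustrative assumptions.

    # Minimal sketch, assuming the patched lightrag.utils from this PR is installed.
    # Chunk values, bucket name, and filenames are illustrative only.
    from lightrag.utils import generate_reference_list_from_chunks

    chunks = [
        {
            'file_path': 's3://lightrag/archive/default/doc-1/report.pdf',
            's3_key': 'archive/default/doc-1/report.pdf',
            'content': 'Quarterly revenue grew 12% year over year.',
        },
        {'file_path': '/local/notes.txt', 'content': 'Meeting notes from Q3 planning.'},
    ]

    reference_list, updated_chunks = generate_reference_list_from_chunks(chunks)

    # Expected shape per the patched implementation:
    # reference_list[0] == {
    #     'reference_id': '1',
    #     'file_path': 's3://lightrag/archive/default/doc-1/report.pdf',
    #     'document_title': 'report.pdf',           # filename extracted from the S3 URL
    #     's3_key': 'archive/default/doc-1/report.pdf',
    #     'excerpt': 'Quarterly revenue grew 12% year over year.',
    # }
    # reference_list[1]['s3_key'] is None           # local file, no S3 key recorded
    # updated_chunks[0]['reference_id'] == '1'      # every chunk is tagged with its reference
    # updated_chunks[0]['excerpt'] == chunks[0]['content']  # short content is not truncated

When frequencies are equal, ordering falls back to first appearance, so the chunk order determines reference numbering; the tie-breaking test below exercises exactly that case.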
+ + +# ============================================================================ +# Tests for generate_reference_list_from_chunks() +# ============================================================================ + + +class TestGenerateReferenceListFromChunks: + """Tests for generate_reference_list_from_chunks function.""" + + @pytest.mark.offline + def test_empty_chunks(self): + """Test with empty chunk list.""" + ref_list, updated_chunks = generate_reference_list_from_chunks([]) + assert ref_list == [] + assert updated_chunks == [] + + @pytest.mark.offline + def test_single_chunk(self): + """Test with a single chunk.""" + chunks = [ + { + 'file_path': '/path/to/doc.pdf', + 'content': 'This is the content.', + 's3_key': 'archive/doc.pdf', + } + ] + ref_list, updated_chunks = generate_reference_list_from_chunks(chunks) + + assert len(ref_list) == 1 + assert ref_list[0]['reference_id'] == '1' + assert ref_list[0]['file_path'] == '/path/to/doc.pdf' + assert ref_list[0]['document_title'] == 'doc.pdf' + assert ref_list[0]['s3_key'] == 'archive/doc.pdf' + assert ref_list[0]['excerpt'] == 'This is the content.' + + assert len(updated_chunks) == 1 + assert updated_chunks[0]['reference_id'] == '1' + + @pytest.mark.offline + def test_multiple_chunks_same_file(self): + """Test multiple chunks from same file get same reference_id.""" + chunks = [ + {'file_path': '/path/doc.pdf', 'content': 'Chunk 1'}, + {'file_path': '/path/doc.pdf', 'content': 'Chunk 2'}, + {'file_path': '/path/doc.pdf', 'content': 'Chunk 3'}, + ] + ref_list, updated_chunks = generate_reference_list_from_chunks(chunks) + + assert len(ref_list) == 1 + assert ref_list[0]['reference_id'] == '1' + # All chunks should have same reference_id + for chunk in updated_chunks: + assert chunk['reference_id'] == '1' + + @pytest.mark.offline + def test_multiple_files_deduplication(self): + """Test multiple files are deduplicated with unique reference_ids.""" + chunks = [ + {'file_path': '/path/doc1.pdf', 'content': 'Content 1'}, + {'file_path': '/path/doc2.pdf', 'content': 'Content 2'}, + {'file_path': '/path/doc1.pdf', 'content': 'Content 1 more'}, + ] + ref_list, updated_chunks = generate_reference_list_from_chunks(chunks) + + assert len(ref_list) == 2 + # doc1 appears twice, so should be reference_id '1' (higher frequency) + # doc2 appears once, so should be reference_id '2' + ref_ids = {r['file_path']: r['reference_id'] for r in ref_list} + assert ref_ids['/path/doc1.pdf'] == '1' + assert ref_ids['/path/doc2.pdf'] == '2' + + @pytest.mark.offline + def test_prioritization_by_frequency(self): + """Test that references are prioritized by frequency.""" + chunks = [ + {'file_path': '/rare.pdf', 'content': 'Rare'}, + {'file_path': '/common.pdf', 'content': 'Common 1'}, + {'file_path': '/common.pdf', 'content': 'Common 2'}, + {'file_path': '/common.pdf', 'content': 'Common 3'}, + {'file_path': '/rare.pdf', 'content': 'Rare 2'}, + ] + ref_list, _ = generate_reference_list_from_chunks(chunks) + + # common.pdf appears 3 times, rare.pdf appears 2 times + # common.pdf should get reference_id '1' + assert ref_list[0]['file_path'] == '/common.pdf' + assert ref_list[0]['reference_id'] == '1' + assert ref_list[1]['file_path'] == '/rare.pdf' + assert ref_list[1]['reference_id'] == '2' + + @pytest.mark.offline + def test_unknown_source_filtered(self): + """Test that 'unknown_source' file paths are filtered out.""" + chunks = [ + {'file_path': '/path/doc.pdf', 'content': 'Valid'}, + {'file_path': 'unknown_source', 'content': 'Unknown'}, + {'file_path': 
'/path/doc2.pdf', 'content': 'Valid 2'}, + ] + ref_list, updated_chunks = generate_reference_list_from_chunks(chunks) + + # unknown_source should not be in reference list + assert len(ref_list) == 2 + file_paths = [r['file_path'] for r in ref_list] + assert 'unknown_source' not in file_paths + + # Chunk with unknown_source should have empty reference_id + assert updated_chunks[1]['reference_id'] == '' + + @pytest.mark.offline + def test_empty_file_path_filtered(self): + """Test that empty file paths are filtered out.""" + chunks = [ + {'file_path': '/path/doc.pdf', 'content': 'Valid'}, + {'file_path': '', 'content': 'No path'}, + {'content': 'Missing path key'}, + ] + ref_list, updated_chunks = generate_reference_list_from_chunks(chunks) + + assert len(ref_list) == 1 + assert ref_list[0]['file_path'] == '/path/doc.pdf' + + @pytest.mark.offline + def test_s3_key_included(self): + """Test that s3_key is included in reference list.""" + chunks = [ + { + 'file_path': 's3://bucket/archive/doc.pdf', + 'content': 'S3 content', + 's3_key': 'archive/doc.pdf', + } + ] + ref_list, _ = generate_reference_list_from_chunks(chunks) + + assert ref_list[0]['s3_key'] == 'archive/doc.pdf' + assert ref_list[0]['document_title'] == 'doc.pdf' + + @pytest.mark.offline + def test_excerpt_generated_from_first_chunk(self): + """Test that excerpt is generated from first chunk of each file.""" + chunks = [ + {'file_path': '/doc.pdf', 'content': 'First chunk content'}, + {'file_path': '/doc.pdf', 'content': 'Second chunk different'}, + ] + ref_list, _ = generate_reference_list_from_chunks(chunks) + + # Excerpt should be from first chunk + assert ref_list[0]['excerpt'] == 'First chunk content' + + @pytest.mark.offline + def test_excerpt_added_to_each_chunk(self): + """Test that each updated chunk has its own excerpt.""" + chunks = [ + {'file_path': '/doc.pdf', 'content': 'First chunk'}, + {'file_path': '/doc.pdf', 'content': 'Second chunk'}, + ] + _, updated_chunks = generate_reference_list_from_chunks(chunks) + + assert updated_chunks[0]['excerpt'] == 'First chunk' + assert updated_chunks[1]['excerpt'] == 'Second chunk' + + @pytest.mark.offline + def test_original_chunks_not_modified(self): + """Test that original chunks are not modified (returns copies).""" + original_chunks = [ + {'file_path': '/doc.pdf', 'content': 'Content'}, + ] + _, updated_chunks = generate_reference_list_from_chunks(original_chunks) + + # Original should not have reference_id + assert 'reference_id' not in original_chunks[0] + # Updated should have reference_id + assert 'reference_id' in updated_chunks[0] + + @pytest.mark.offline + def test_missing_s3_key_is_none(self): + """Test that missing s3_key results in None.""" + chunks = [ + {'file_path': '/local/doc.pdf', 'content': 'Local file'}, + ] + ref_list, _ = generate_reference_list_from_chunks(chunks) + + assert ref_list[0]['s3_key'] is None + + @pytest.mark.offline + def test_tie_breaking_by_first_appearance(self): + """Test that same-frequency files are ordered by first appearance.""" + chunks = [ + {'file_path': '/doc_b.pdf', 'content': 'B first'}, + {'file_path': '/doc_a.pdf', 'content': 'A second'}, + {'file_path': '/doc_b.pdf', 'content': 'B again'}, + {'file_path': '/doc_a.pdf', 'content': 'A again'}, + ] + ref_list, _ = generate_reference_list_from_chunks(chunks) + + # Both files appear twice, but doc_b appeared first + assert ref_list[0]['file_path'] == '/doc_b.pdf' + assert ref_list[0]['reference_id'] == '1' + assert ref_list[1]['file_path'] == '/doc_a.pdf' + assert 
ref_list[1]['reference_id'] == '2' diff --git a/tests/test_extraction_prompt_ab.py b/tests/test_extraction_prompt_ab.py index a030be48..6aa12fd7 100644 --- a/tests/test_extraction_prompt_ab.py +++ b/tests/test_extraction_prompt_ab.py @@ -19,9 +19,10 @@ import tiktoken # Add project root to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) -from lightrag.prompt import PROMPTS from lightrag.prompt_optimized import PROMPTS_OPTIMIZED +from lightrag.prompt import PROMPTS + # ============================================================================= # Sample Texts for Testing # ============================================================================= @@ -355,7 +356,7 @@ class TestExtractionPromptAB: """Compare prompts across all sample texts.""" results = [] - for key, sample in SAMPLE_TEXTS.items(): + for _key, sample in SAMPLE_TEXTS.items(): print(f"\nProcessing: {sample['name']}...") original = await run_extraction(PROMPTS, sample["text"]) @@ -410,7 +411,7 @@ async def main() -> None: results = [] - for key, sample in SAMPLE_TEXTS.items(): + for _key, sample in SAMPLE_TEXTS.items(): print(f"Processing: {sample['name']}...") original = await run_extraction(PROMPTS, sample["text"]) diff --git a/tests/test_prompt_accuracy.py b/tests/test_prompt_accuracy.py index c3d88af8..b0fa443a 100644 --- a/tests/test_prompt_accuracy.py +++ b/tests/test_prompt_accuracy.py @@ -16,7 +16,6 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from lightrag.prompt import PROMPTS - # ============================================================================= # Test Data # ============================================================================= diff --git a/tests/test_prompt_quality_deep.py b/tests/test_prompt_quality_deep.py index 130778a1..06c17bcd 100644 --- a/tests/test_prompt_quality_deep.py +++ b/tests/test_prompt_quality_deep.py @@ -16,7 +16,6 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from lightrag.prompt import PROMPTS - # ============================================================================= # Test Data: Entity Extraction (5 Domains) # ============================================================================= @@ -427,7 +426,7 @@ async def test_entity_extraction_deep() -> list[EntityResult]: """Deep test entity extraction on 5 domains.""" results = [] - for domain, data in ENTITY_TEST_TEXTS.items(): + for _domain, data in ENTITY_TEST_TEXTS.items(): print(f" Testing {data['name']}...") prompt = format_entity_prompt(data["text"]) diff --git a/tests/test_s3_client.py b/tests/test_s3_client.py new file mode 100644 index 00000000..872b3759 --- /dev/null +++ b/tests/test_s3_client.py @@ -0,0 +1,618 @@ +"""Tests for S3 client functionality in lightrag/storage/s3_client.py. + +This module tests S3 operations by mocking the aioboto3 session layer, +avoiding the moto/aiobotocore async incompatibility issue. 
+""" + +from contextlib import asynccontextmanager +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Note: The S3Client in lightrag uses aioboto3 which requires proper async mocking + + +@pytest.fixture +def aws_credentials(monkeypatch): + """Set mock AWS credentials.""" + monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'testing') + monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'testing') + monkeypatch.setenv('AWS_DEFAULT_REGION', 'us-east-1') + monkeypatch.setenv('S3_ACCESS_KEY_ID', 'testing') + monkeypatch.setenv('S3_SECRET_ACCESS_KEY', 'testing') + monkeypatch.setenv('S3_BUCKET_NAME', 'test-bucket') + monkeypatch.setenv('S3_REGION', 'us-east-1') + monkeypatch.setenv('S3_ENDPOINT_URL', '') + + +@pytest.fixture +def s3_config(aws_credentials): + """Create S3Config for testing.""" + from lightrag.storage.s3_client import S3Config + + return S3Config( + endpoint_url='', + access_key_id='testing', + secret_access_key='testing', + bucket_name='test-bucket', + region='us-east-1', + ) + + +def create_mock_s3_client(): + """Create a mock S3 client with common operations.""" + mock_client = MagicMock() + + # Storage for mock objects + mock_client._objects = {} + + # head_bucket - succeeds (bucket exists) + mock_client.head_bucket = AsyncMock(return_value={}) + + # put_object + async def mock_put_object(**kwargs): + key = kwargs['Key'] + body = kwargs['Body'] + metadata = kwargs.get('Metadata', {}) + content_type = kwargs.get('ContentType', 'application/octet-stream') + + # Read body if it's a file-like object + if hasattr(body, 'read'): + content = body.read() + else: + content = body + + mock_client._objects[key] = { + 'Body': content, + 'Metadata': metadata, + 'ContentType': content_type, + } + return {'ETag': '"mock-etag"'} + + mock_client.put_object = AsyncMock(side_effect=mock_put_object) + + # get_object + async def mock_get_object(**kwargs): + key = kwargs['Key'] + if key not in mock_client._objects: + from botocore.exceptions import ClientError + raise ClientError( + {'Error': {'Code': 'NoSuchKey', 'Message': 'Not found'}}, + 'GetObject' + ) + + obj = mock_client._objects[key] + body_mock = MagicMock() + body_mock.read = AsyncMock(return_value=obj['Body']) + + return { + 'Body': body_mock, + 'Metadata': obj['Metadata'], + 'ContentType': obj['ContentType'], + } + + mock_client.get_object = AsyncMock(side_effect=mock_get_object) + + # head_object (for object_exists) + async def mock_head_object(**kwargs): + key = kwargs['Key'] + if key not in mock_client._objects: + from botocore.exceptions import ClientError + raise ClientError( + {'Error': {'Code': '404', 'Message': 'Not found'}}, + 'HeadObject' + ) + return {'ContentLength': len(mock_client._objects[key]['Body'])} + + mock_client.head_object = AsyncMock(side_effect=mock_head_object) + + # delete_object + async def mock_delete_object(**kwargs): + key = kwargs['Key'] + if key in mock_client._objects: + del mock_client._objects[key] + return {} + + mock_client.delete_object = AsyncMock(side_effect=mock_delete_object) + + # copy_object + async def mock_copy_object(**kwargs): + source = kwargs['CopySource'] + dest_key = kwargs['Key'] + + # CopySource is like {'Bucket': 'bucket', 'Key': 'key'} + source_key = source['Key'] + + if source_key not in mock_client._objects: + from botocore.exceptions import ClientError + raise ClientError( + {'Error': {'Code': 'NoSuchKey', 'Message': 'Not found'}}, + 'CopyObject' + ) + + mock_client._objects[dest_key] = mock_client._objects[source_key].copy() + return {} + + 
mock_client.copy_object = AsyncMock(side_effect=mock_copy_object) + + # list_objects_v2 + async def mock_list_objects_v2(**kwargs): + prefix = kwargs.get('Prefix', '') + contents = [] + + for key, obj in mock_client._objects.items(): + if key.startswith(prefix): + contents.append({ + 'Key': key, + 'Size': len(obj['Body']), + 'LastModified': '2024-01-01T00:00:00Z', + }) + + return {'Contents': contents} if contents else {} + + mock_client.list_objects_v2 = AsyncMock(side_effect=mock_list_objects_v2) + + # get_paginator for list_staging - returns async paginator + class MockPaginator: + def __init__(self, objects_dict): + self._objects = objects_dict + + def paginate(self, **kwargs): + return MockPaginatorIterator(self._objects, kwargs.get('Prefix', '')) + + class MockPaginatorIterator: + def __init__(self, objects_dict, prefix): + self._objects = objects_dict + self._prefix = prefix + self._done = False + + def __aiter__(self): + return self + + async def __anext__(self): + if self._done: + raise StopAsyncIteration + + self._done = True + from datetime import datetime + contents = [] + for key, obj in self._objects.items(): + if key.startswith(self._prefix): + contents.append({ + 'Key': key, + 'Size': len(obj['Body']), + 'LastModified': datetime(2024, 1, 1), + }) + return {'Contents': contents} if contents else {} + + def mock_get_paginator(operation_name): + return MockPaginator(mock_client._objects) + + mock_client.get_paginator = MagicMock(side_effect=mock_get_paginator) + + # generate_presigned_url - the code awaits this, so return an awaitable + async def mock_generate_presigned_url(ClientMethod, Params, ExpiresIn=3600): + key = Params.get('Key', 'unknown') + bucket = Params.get('Bucket', 'bucket') + return f'https://{bucket}.s3.amazonaws.com/{key}?signature=mock' + + mock_client.generate_presigned_url = mock_generate_presigned_url + + return mock_client + + +@pytest.fixture +def mock_s3_session(): + """Create a mock aioboto3 session that returns a mock S3 client.""" + mock_session = MagicMock() + mock_client = create_mock_s3_client() + + @asynccontextmanager + async def mock_client_context(*args, **kwargs): + yield mock_client + + # Return a NEW context manager each time client() is called + mock_session.client = MagicMock(side_effect=lambda *args, **kwargs: mock_client_context()) + + return mock_session, mock_client + + +# ============================================================================ +# Unit Tests for Key Generation (no mocking needed) +# ============================================================================ + + +class TestKeyGeneration: + """Tests for S3 key generation methods.""" + + @pytest.mark.offline + def test_make_staging_key(self, s3_config): + """Test staging key format.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + key = client._make_staging_key('default', 'doc123', 'report.pdf') + assert key == 'staging/default/doc123/report.pdf' + + @pytest.mark.offline + def test_make_staging_key_sanitizes_slashes(self, s3_config): + """Test that slashes in filename are sanitized.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + key = client._make_staging_key('default', 'doc123', 'path/to/file.pdf') + assert key == 'staging/default/doc123/path_to_file.pdf' + assert '//' not in key + + @pytest.mark.offline + def test_make_staging_key_sanitizes_backslashes(self, s3_config): + """Test that backslashes in filename are sanitized.""" + from lightrag.storage.s3_client import 
S3Client + + client = S3Client(config=s3_config) + key = client._make_staging_key('default', 'doc123', 'path\\to\\file.pdf') + assert key == 'staging/default/doc123/path_to_file.pdf' + + @pytest.mark.offline + def test_make_archive_key(self, s3_config): + """Test archive key format.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + key = client._make_archive_key('workspace1', 'doc456', 'data.json') + assert key == 'archive/workspace1/doc456/data.json' + + @pytest.mark.offline + def test_staging_to_archive_key(self, s3_config): + """Test staging to archive key transformation.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + staging_key = 'staging/default/doc123/report.pdf' + archive_key = client._staging_to_archive_key(staging_key) + assert archive_key == 'archive/default/doc123/report.pdf' + + @pytest.mark.offline + def test_staging_to_archive_key_non_staging(self, s3_config): + """Test that non-staging keys are returned unchanged.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + key = 'archive/default/doc123/report.pdf' + result = client._staging_to_archive_key(key) + assert result == key + + @pytest.mark.offline + def test_get_s3_url(self, s3_config): + """Test S3 URL generation.""" + from lightrag.storage.s3_client import S3Client + + client = S3Client(config=s3_config) + url = client.get_s3_url('archive/default/doc123/report.pdf') + assert url == 's3://test-bucket/archive/default/doc123/report.pdf' + + +# ============================================================================ +# Integration Tests with Mocked S3 Session +# ============================================================================ + + +@pytest.mark.offline +class TestS3ClientOperations: + """Tests for S3 client operations using mocked session.""" + + @pytest.mark.asyncio + async def test_initialize_creates_bucket(self, s3_config, mock_s3_session): + """Test that initialize checks bucket exists.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + assert client._initialized is True + mock_client.head_bucket.assert_called_once() + + await client.finalize() + + @pytest.mark.asyncio + async def test_upload_to_staging(self, s3_config, mock_s3_session): + """Test uploading content to staging.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=b'Hello, World!', + filename='test.txt', + content_type='text/plain', + ) + + assert s3_key == 'staging/default/doc123/test.txt' + mock_client.put_object.assert_called_once() + + await client.finalize() + + @pytest.mark.asyncio + async def test_upload_string_content(self, s3_config, mock_s3_session): + """Test uploading string content (should be encoded to bytes).""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + 
s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content='String content', # String, not bytes + filename='test.txt', + ) + + # Verify we can retrieve it + content, metadata = await client.get_object(s3_key) + assert content == b'String content' + + await client.finalize() + + @pytest.mark.asyncio + async def test_get_object(self, s3_config, mock_s3_session): + """Test retrieving uploaded object.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + # Upload + test_content = b'Test content for retrieval' + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=test_content, + filename='test.txt', + ) + + # Retrieve + content, metadata = await client.get_object(s3_key) + + assert content == test_content + assert metadata.get('workspace') == 'default' + assert metadata.get('doc_id') == 'doc123' + assert 'content_hash' in metadata + + await client.finalize() + + @pytest.mark.asyncio + async def test_move_to_archive(self, s3_config, mock_s3_session): + """Test moving object from staging to archive.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + # Upload to staging + test_content = b'Content to archive' + staging_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=test_content, + filename='test.txt', + ) + + # Move to archive + archive_key = await client.move_to_archive(staging_key) + + assert archive_key == 'archive/default/doc123/test.txt' + + # Verify staging key no longer exists + assert not await client.object_exists(staging_key) + + # Verify archive key exists and has correct content + assert await client.object_exists(archive_key) + content, _ = await client.get_object(archive_key) + assert content == test_content + + await client.finalize() + + @pytest.mark.asyncio + async def test_delete_object(self, s3_config, mock_s3_session): + """Test deleting an object.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + # Upload + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=b'Content to delete', + filename='test.txt', + ) + + assert await client.object_exists(s3_key) + + # Delete + await client.delete_object(s3_key) + + # Verify deleted + assert not await client.object_exists(s3_key) + + await client.finalize() + + @pytest.mark.asyncio + async def test_list_staging(self, s3_config, mock_s3_session): + """Test listing objects in staging.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + # Upload multiple objects + await client.upload_to_staging('default', 'doc1', b'Content 1', 'file1.txt') + await client.upload_to_staging('default', 'doc2', b'Content 2', 'file2.txt') + await 
client.upload_to_staging('other', 'doc3', b'Content 3', 'file3.txt') + + # List only 'default' workspace + objects = await client.list_staging('default') + + assert len(objects) == 2 + keys = [obj['key'] for obj in objects] + assert 'staging/default/doc1/file1.txt' in keys + assert 'staging/default/doc2/file2.txt' in keys + assert 'staging/other/doc3/file3.txt' not in keys + + await client.finalize() + + @pytest.mark.asyncio + async def test_object_exists_true(self, s3_config, mock_s3_session): + """Test object_exists returns True for existing object.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=b'Test', + filename='test.txt', + ) + + assert await client.object_exists(s3_key) is True + + await client.finalize() + + @pytest.mark.asyncio + async def test_object_exists_false(self, s3_config, mock_s3_session): + """Test object_exists returns False for non-existing object.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + assert await client.object_exists('nonexistent/key') is False + + await client.finalize() + + @pytest.mark.asyncio + async def test_get_presigned_url(self, s3_config, mock_s3_session): + """Test generating presigned URL.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=b'Test', + filename='test.txt', + ) + + url = await client.get_presigned_url(s3_key) + + # URL should be a string containing the bucket + assert isinstance(url, str) + assert 'test-bucket' in url + + await client.finalize() + + @pytest.mark.asyncio + async def test_upload_with_metadata(self, s3_config, mock_s3_session): + """Test uploading with custom metadata.""" + from lightrag.storage.s3_client import S3Client, S3ClientManager + + mock_session, mock_client = mock_s3_session + + with patch.object(S3ClientManager, 'get_session', return_value=mock_session): + client = S3Client(config=s3_config) + await client.initialize() + + custom_metadata = {'author': 'test-user', 'version': '1.0'} + + s3_key = await client.upload_to_staging( + workspace='default', + doc_id='doc123', + content=b'Test', + filename='test.txt', + metadata=custom_metadata, + ) + + _, metadata = await client.get_object(s3_key) + + # Custom metadata should be included + assert metadata.get('author') == 'test-user' + assert metadata.get('version') == '1.0' + # Built-in metadata should also be present + assert metadata.get('workspace') == 'default' + + await client.finalize() + + +# ============================================================================ +# S3Config Tests +# ============================================================================ + + +class TestS3Config: + """Tests for S3Config validation.""" + + @pytest.mark.offline + def test_config_requires_credentials(self, monkeypatch): + """Test that 
S3Config raises error without credentials.""" + from lightrag.storage.s3_client import S3Config + + monkeypatch.setenv('S3_ACCESS_KEY_ID', '') + monkeypatch.setenv('S3_SECRET_ACCESS_KEY', '') + + with pytest.raises(ValueError, match='S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY must be set'): + S3Config( + access_key_id='', + secret_access_key='', + ) + + @pytest.mark.offline + def test_config_with_valid_credentials(self, aws_credentials): + """Test that S3Config initializes with valid credentials.""" + from lightrag.storage.s3_client import S3Config + + config = S3Config( + access_key_id='valid-key', + secret_access_key='valid-secret', + ) + + assert config.access_key_id == 'valid-key' + assert config.secret_access_key == 'valid-secret' diff --git a/tests/test_search_routes.py b/tests/test_search_routes.py new file mode 100644 index 00000000..f82774c4 --- /dev/null +++ b/tests/test_search_routes.py @@ -0,0 +1,406 @@ +"""Tests for search routes in lightrag/api/routers/search_routes.py. + +This module tests the BM25 full-text search endpoint using httpx AsyncClient +and FastAPI's TestClient pattern with mocked PostgreSQLDB. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# Mock the config module BEFORE importing search_routes to prevent +# argparse from trying to parse pytest arguments as server arguments +mock_global_args = MagicMock() +mock_global_args.token_secret = 'test-secret' +mock_global_args.jwt_secret_key = 'test-jwt-secret' +mock_global_args.jwt_algorithm = 'HS256' +mock_global_args.jwt_expire_hours = 24 +mock_global_args.username = None +mock_global_args.password = None +mock_global_args.guest_token = None + +# Pre-populate sys.modules with mocked config +mock_config_module = MagicMock() +mock_config_module.global_args = mock_global_args +sys.modules['lightrag.api.config'] = mock_config_module + +# Also mock the auth module to prevent initialization issues +mock_auth_module = MagicMock() +mock_auth_module.auth_handler = MagicMock() +sys.modules['lightrag.api.auth'] = mock_auth_module + +# Now import FastAPI components (after mocking) +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient +from pydantic import BaseModel, Field +from typing import Annotated, Any, ClassVar + +# Import the components we need from search_routes without triggering full init +# We'll recreate the essential parts for testing + + +class SearchResult(BaseModel): + """A single search result (chunk match).""" + + id: str = Field(description='Chunk ID') + full_doc_id: str = Field(description='Parent document ID') + chunk_order_index: int = Field(description='Position in document') + tokens: int = Field(description='Token count') + content: str = Field(description='Chunk content') + file_path: str | None = Field(default=None, description='Source file path') + s3_key: str | None = Field(default=None, description='S3 key for source document') + char_start: int | None = Field(default=None, description='Character offset start') + char_end: int | None = Field(default=None, description='Character offset end') + score: float = Field(description='BM25 relevance score') + + +class SearchResponse(BaseModel): + """Response model for search endpoint.""" + + query: str = Field(description='Original search query') + results: list[SearchResult] = Field(description='Matching chunks') + count: int = Field(description='Number of results returned') + workspace: str = Field(description='Workspace searched') + + +def create_test_search_routes(db: Any, api_key: str | None = 
None): + """Create search routes for testing (simplified version without auth dep).""" + from fastapi import APIRouter, HTTPException, Query + + router = APIRouter(prefix='/search', tags=['search']) + + @router.get('', response_model=SearchResponse) + async def search( + q: Annotated[str, Query(description='Search query', min_length=1)], + limit: Annotated[int, Query(description='Max results', ge=1, le=100)] = 10, + workspace: Annotated[str, Query(description='Workspace')] = 'default', + ) -> SearchResponse: + """Perform BM25 full-text search on chunks.""" + try: + results = await db.full_text_search( + query=q, + workspace=workspace, + limit=limit, + ) + + search_results = [ + SearchResult( + id=r.get('id', ''), + full_doc_id=r.get('full_doc_id', ''), + chunk_order_index=r.get('chunk_order_index', 0), + tokens=r.get('tokens', 0), + content=r.get('content', ''), + file_path=r.get('file_path'), + s3_key=r.get('s3_key'), + char_start=r.get('char_start'), + char_end=r.get('char_end'), + score=float(r.get('score', 0)), + ) + for r in results + ] + + return SearchResponse( + query=q, + results=search_results, + count=len(search_results), + workspace=workspace, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f'Search failed: {e}') from e + + return router + + +@pytest.fixture +def mock_db(): + """Create a mock PostgreSQLDB instance.""" + db = MagicMock() + db.full_text_search = AsyncMock() + return db + + +@pytest.fixture +def app(mock_db): + """Create FastAPI app with search routes.""" + app = FastAPI() + router = create_test_search_routes(db=mock_db, api_key=None) + app.include_router(router) + return app + + +@pytest.fixture +async def client(app): + """Create async HTTP client for testing.""" + async with AsyncClient( + transport=ASGITransport(app=app), + base_url='http://test', + ) as client: + yield client + + +# ============================================================================ +# Search Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestSearchEndpoint: + """Tests for GET /search endpoint.""" + + @pytest.mark.asyncio + async def test_search_returns_results(self, client, mock_db): + """Test that search returns properly formatted results.""" + mock_db.full_text_search.return_value = [ + { + 'id': 'chunk-1', + 'full_doc_id': 'doc-1', + 'chunk_order_index': 0, + 'tokens': 100, + 'content': 'This is test content about machine learning.', + 'file_path': '/path/to/doc.pdf', + 's3_key': 'archive/default/doc-1/doc.pdf', + 'char_start': 0, + 'char_end': 100, + 'score': 0.85, + } + ] + + response = await client.get('/search', params={'q': 'machine learning'}) + + assert response.status_code == 200 + data = response.json() + assert data['query'] == 'machine learning' + assert data['count'] == 1 + assert data['workspace'] == 'default' + assert len(data['results']) == 1 + + result = data['results'][0] + assert result['id'] == 'chunk-1' + assert result['content'] == 'This is test content about machine learning.' 
+ assert result['score'] == 0.85 + + @pytest.mark.asyncio + async def test_search_includes_char_positions(self, client, mock_db): + """Test that char_start/char_end are included in results.""" + mock_db.full_text_search.return_value = [ + { + 'id': 'chunk-1', + 'full_doc_id': 'doc-1', + 'chunk_order_index': 0, + 'tokens': 50, + 'content': 'Test content', + 'char_start': 100, + 'char_end': 200, + 'score': 0.75, + } + ] + + response = await client.get('/search', params={'q': 'test'}) + + assert response.status_code == 200 + result = response.json()['results'][0] + assert result['char_start'] == 100 + assert result['char_end'] == 200 + + @pytest.mark.asyncio + async def test_search_includes_s3_key(self, client, mock_db): + """Test that s3_key is included in results when present.""" + mock_db.full_text_search.return_value = [ + { + 'id': 'chunk-1', + 'full_doc_id': 'doc-1', + 'chunk_order_index': 0, + 'tokens': 50, + 'content': 'Test content', + 's3_key': 'archive/default/doc-1/report.pdf', + 'file_path': 's3://bucket/archive/default/doc-1/report.pdf', + 'score': 0.75, + } + ] + + response = await client.get('/search', params={'q': 'test'}) + + assert response.status_code == 200 + result = response.json()['results'][0] + assert result['s3_key'] == 'archive/default/doc-1/report.pdf' + assert result['file_path'] == 's3://bucket/archive/default/doc-1/report.pdf' + + @pytest.mark.asyncio + async def test_search_null_s3_key(self, client, mock_db): + """Test that null s3_key is handled correctly.""" + mock_db.full_text_search.return_value = [ + { + 'id': 'chunk-1', + 'full_doc_id': 'doc-1', + 'chunk_order_index': 0, + 'tokens': 50, + 'content': 'Test content', + 's3_key': None, + 'file_path': '/local/path/doc.pdf', + 'score': 0.75, + } + ] + + response = await client.get('/search', params={'q': 'test'}) + + assert response.status_code == 200 + result = response.json()['results'][0] + assert result['s3_key'] is None + assert result['file_path'] == '/local/path/doc.pdf' + + @pytest.mark.asyncio + async def test_search_empty_results(self, client, mock_db): + """Test that empty results are returned correctly.""" + mock_db.full_text_search.return_value = [] + + response = await client.get('/search', params={'q': 'nonexistent'}) + + assert response.status_code == 200 + data = response.json() + assert data['count'] == 0 + assert data['results'] == [] + + @pytest.mark.asyncio + async def test_search_limit_parameter(self, client, mock_db): + """Test that limit parameter is passed to database.""" + mock_db.full_text_search.return_value = [] + + await client.get('/search', params={'q': 'test', 'limit': 25}) + + mock_db.full_text_search.assert_called_once_with( + query='test', + workspace='default', + limit=25, + ) + + @pytest.mark.asyncio + async def test_search_workspace_parameter(self, client, mock_db): + """Test that workspace parameter is passed to database.""" + mock_db.full_text_search.return_value = [] + + await client.get('/search', params={'q': 'test', 'workspace': 'custom'}) + + mock_db.full_text_search.assert_called_once_with( + query='test', + workspace='custom', + limit=10, # default + ) + + @pytest.mark.asyncio + async def test_search_multiple_results(self, client, mock_db): + """Test that multiple results are returned correctly.""" + mock_db.full_text_search.return_value = [ + { + 'id': f'chunk-{i}', + 'full_doc_id': f'doc-{i}', + 'chunk_order_index': i, + 'tokens': 100, + 'content': f'Content {i}', + 'score': 0.9 - i * 0.1, + } + for i in range(5) + ] + + response = await client.get('/search', 
params={'q': 'test', 'limit': 5}) + + assert response.status_code == 200 + data = response.json() + assert data['count'] == 5 + assert len(data['results']) == 5 + # Results should be in order by score (as returned by db) + assert data['results'][0]['id'] == 'chunk-0' + assert data['results'][4]['id'] == 'chunk-4' + + +# ============================================================================ +# Validation Tests +# ============================================================================ + + +@pytest.mark.offline +class TestSearchValidation: + """Tests for search endpoint validation.""" + + @pytest.mark.asyncio + async def test_search_empty_query_rejected(self, client, mock_db): + """Test that empty query string is rejected with 422.""" + response = await client.get('/search', params={'q': ''}) + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_search_missing_query_rejected(self, client, mock_db): + """Test that missing query parameter is rejected with 422.""" + response = await client.get('/search') + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_search_limit_too_low_rejected(self, client, mock_db): + """Test that limit < 1 is rejected.""" + response = await client.get('/search', params={'q': 'test', 'limit': 0}) + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_search_limit_too_high_rejected(self, client, mock_db): + """Test that limit > 100 is rejected.""" + response = await client.get('/search', params={'q': 'test', 'limit': 101}) + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_search_limit_boundary_valid(self, client, mock_db): + """Test that limit at boundaries (1 and 100) are accepted.""" + mock_db.full_text_search.return_value = [] + + # Test lower boundary + response = await client.get('/search', params={'q': 'test', 'limit': 1}) + assert response.status_code == 200 + + # Test upper boundary + response = await client.get('/search', params={'q': 'test', 'limit': 100}) + assert response.status_code == 200 + + +# ============================================================================ +# Error Handling Tests +# ============================================================================ + + +@pytest.mark.offline +class TestSearchErrors: + """Tests for search endpoint error handling.""" + + @pytest.mark.asyncio + async def test_search_database_error(self, client, mock_db): + """Test that database errors return 500.""" + mock_db.full_text_search.side_effect = Exception('Database connection failed') + + response = await client.get('/search', params={'q': 'test'}) + + assert response.status_code == 500 + assert 'Database connection failed' in response.json()['detail'] + + @pytest.mark.asyncio + async def test_search_handles_missing_fields(self, client, mock_db): + """Test that missing fields in db results are handled with defaults.""" + mock_db.full_text_search.return_value = [ + { + 'id': 'chunk-1', + 'full_doc_id': 'doc-1', + # Missing many fields + } + ] + + response = await client.get('/search', params={'q': 'test'}) + + assert response.status_code == 200 + result = response.json()['results'][0] + # Should have defaults + assert result['chunk_order_index'] == 0 + assert result['tokens'] == 0 + assert result['content'] == '' + assert result['score'] == 0.0 diff --git a/tests/test_upload_routes.py b/tests/test_upload_routes.py new file mode 100644 index 00000000..4cbd2c93 --- /dev/null +++ b/tests/test_upload_routes.py @@ -0,0 +1,724 @@ +"""Tests for 
upload routes in lightrag/api/routers/upload_routes.py. + +This module tests the S3 document staging endpoints using httpx AsyncClient +and FastAPI's TestClient pattern with mocked S3Client and LightRAG. +""" + +import sys +from io import BytesIO +from typing import Annotated, Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# Mock the config module BEFORE importing to prevent argparse issues +mock_global_args = MagicMock() +mock_global_args.token_secret = 'test-secret' +mock_global_args.jwt_secret_key = 'test-jwt-secret' +mock_global_args.jwt_algorithm = 'HS256' +mock_global_args.jwt_expire_hours = 24 +mock_global_args.username = None +mock_global_args.password = None +mock_global_args.guest_token = None + +mock_config_module = MagicMock() +mock_config_module.global_args = mock_global_args +sys.modules['lightrag.api.config'] = mock_config_module + +mock_auth_module = MagicMock() +mock_auth_module.auth_handler = MagicMock() +sys.modules['lightrag.api.auth'] = mock_auth_module + +# Now import FastAPI components +from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile +from httpx import ASGITransport, AsyncClient +from pydantic import BaseModel, Field + + +# Recreate models for testing (to avoid import chain issues) +class UploadResponse(BaseModel): + """Response model for document upload.""" + + status: str + doc_id: str + s3_key: str + s3_url: str + message: str | None = None + + +class StagedDocument(BaseModel): + """Model for a staged document.""" + + key: str + size: int + last_modified: str + + +class ListStagedResponse(BaseModel): + """Response model for listing staged documents.""" + + workspace: str + documents: list[StagedDocument] + count: int + + +class PresignedUrlResponse(BaseModel): + """Response model for presigned URL.""" + + s3_key: str + presigned_url: str + expiry_seconds: int + + +class ProcessS3Request(BaseModel): + """Request model for processing a document from S3 staging.""" + + s3_key: str + doc_id: str | None = None + archive_after_processing: bool = True + + +class ProcessS3Response(BaseModel): + """Response model for S3 document processing.""" + + status: str + track_id: str + doc_id: str + s3_key: str + archive_key: str | None = None + message: str | None = None + + +def create_test_upload_routes( + rag: Any, + s3_client: Any, + api_key: str | None = None, +) -> APIRouter: + """Create upload routes for testing (simplified without auth).""" + router = APIRouter(prefix='/upload', tags=['upload']) + + @router.post('', response_model=UploadResponse) + async def upload_document( + file: Annotated[UploadFile, File(description='Document file')], + workspace: Annotated[str, Form(description='Workspace')] = 'default', + doc_id: Annotated[str | None, Form(description='Document ID')] = None, + ) -> UploadResponse: + """Upload a document to S3 staging.""" + try: + content = await file.read() + + if not content: + raise HTTPException(status_code=400, detail='Empty file') + + # Generate doc_id if not provided + if not doc_id: + import hashlib + + doc_id = 'doc_' + hashlib.md5(content).hexdigest()[:8] + + content_type = file.content_type or 'application/octet-stream' + + s3_key = await s3_client.upload_to_staging( + workspace=workspace, + doc_id=doc_id, + content=content, + filename=file.filename or f'{doc_id}.bin', + content_type=content_type, + metadata={ + 'original_size': str(len(content)), + 'content_type': content_type, + }, + ) + + s3_url = s3_client.get_s3_url(s3_key) + + return UploadResponse( + status='uploaded', + 
doc_id=doc_id, + s3_key=s3_key, + s3_url=s3_url, + message='Document staged for processing', + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f'Upload failed: {e}') from e + + @router.get('/staged', response_model=ListStagedResponse) + async def list_staged(workspace: str = 'default') -> ListStagedResponse: + """List documents in staging.""" + try: + objects = await s3_client.list_staging(workspace) + + documents = [ + StagedDocument( + key=obj['key'], + size=obj['size'], + last_modified=obj['last_modified'], + ) + for obj in objects + ] + + return ListStagedResponse( + workspace=workspace, + documents=documents, + count=len(documents), + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f'Failed to list staged documents: {e}') from e + + @router.get('/presigned-url', response_model=PresignedUrlResponse) + async def get_presigned_url( + s3_key: str, + expiry: int = 3600, + ) -> PresignedUrlResponse: + """Get presigned URL for a document.""" + try: + if not await s3_client.object_exists(s3_key): + raise HTTPException(status_code=404, detail='Object not found') + + url = await s3_client.get_presigned_url(s3_key, expiry=expiry) + + return PresignedUrlResponse( + s3_key=s3_key, + presigned_url=url, + expiry_seconds=expiry, + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f'Failed to generate presigned URL: {e}') from e + + @router.delete('/staged/{doc_id}') + async def delete_staged( + doc_id: str, + workspace: str = 'default', + ) -> dict[str, str]: + """Delete a staged document.""" + try: + prefix = f'staging/{workspace}/{doc_id}/' + objects = await s3_client.list_staging(workspace) + + to_delete = [obj['key'] for obj in objects if obj['key'].startswith(prefix)] + + if not to_delete: + raise HTTPException(status_code=404, detail='Document not found in staging') + + for key in to_delete: + await s3_client.delete_object(key) + + return { + 'status': 'deleted', + 'doc_id': doc_id, + 'deleted_count': str(len(to_delete)), + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f'Failed to delete staged document: {e}') from e + + @router.post('/process', response_model=ProcessS3Response) + async def process_from_s3(request: ProcessS3Request) -> ProcessS3Response: + """Process a staged document through the RAG pipeline.""" + try: + s3_key = request.s3_key + + if not await s3_client.object_exists(s3_key): + raise HTTPException( + status_code=404, + detail=f'Document not found in S3: {s3_key}', + ) + + content_bytes, metadata = await s3_client.get_object(s3_key) + + doc_id = request.doc_id + if not doc_id: + parts = s3_key.split('/') + doc_id = parts[2] if len(parts) >= 3 else 'doc_unknown' + + content_type = metadata.get('content_type', 'application/octet-stream') + s3_url = s3_client.get_s3_url(s3_key) + + # Try to decode as text + try: + text_content = content_bytes.decode('utf-8') + except UnicodeDecodeError: + raise HTTPException( + status_code=400, + detail=f'Cannot process binary content type: {content_type}.', + ) from None + + if not text_content.strip(): + raise HTTPException( + status_code=400, + detail='Document content is empty after decoding', + ) + + # Process through RAG + track_id = await rag.ainsert( + input=text_content, + ids=doc_id, + file_paths=s3_url, + ) + + archive_key = None + if request.archive_after_processing: + try: + archive_key = await s3_client.move_to_archive(s3_key) + 
except Exception: + pass # Don't fail if archive fails + + return ProcessS3Response( + status='processing_complete', + track_id=track_id, + doc_id=doc_id, + s3_key=s3_key, + archive_key=archive_key, + message='Document processed and stored in RAG pipeline', + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f'Failed to process S3 document: {e}') from e + + return router + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def mock_s3_client(): + """Create a mock S3Client.""" + client = MagicMock() + client.upload_to_staging = AsyncMock() + client.list_staging = AsyncMock() + client.get_presigned_url = AsyncMock() + client.object_exists = AsyncMock() + client.delete_object = AsyncMock() + client.get_object = AsyncMock() + client.move_to_archive = AsyncMock() + client.get_s3_url = MagicMock() + return client + + +@pytest.fixture +def mock_rag(): + """Create a mock LightRAG instance.""" + rag = MagicMock() + rag.ainsert = AsyncMock() + return rag + + +@pytest.fixture +def app(mock_rag, mock_s3_client): + """Create FastAPI app with upload routes.""" + app = FastAPI() + router = create_test_upload_routes( + rag=mock_rag, + s3_client=mock_s3_client, + api_key=None, + ) + app.include_router(router) + return app + + +@pytest.fixture +async def client(app): + """Create async HTTP client for testing.""" + async with AsyncClient( + transport=ASGITransport(app=app), + base_url='http://test', + ) as client: + yield client + + +# ============================================================================ +# Upload Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestUploadEndpoint: + """Tests for POST /upload endpoint.""" + + @pytest.mark.asyncio + async def test_upload_creates_staging_key(self, client, mock_s3_client): + """Test that upload creates correct S3 staging key.""" + mock_s3_client.upload_to_staging.return_value = 'staging/default/doc_abc123/test.txt' + mock_s3_client.get_s3_url.return_value = 's3://bucket/staging/default/doc_abc123/test.txt' + + files = {'file': ('test.txt', b'Hello, World!', 'text/plain')} + data = {'workspace': 'default', 'doc_id': 'doc_abc123'} + + response = await client.post('/upload', files=files, data=data) + + assert response.status_code == 200 + data = response.json() + assert data['status'] == 'uploaded' + assert data['doc_id'] == 'doc_abc123' + assert data['s3_key'] == 'staging/default/doc_abc123/test.txt' + + mock_s3_client.upload_to_staging.assert_called_once() + call_args = mock_s3_client.upload_to_staging.call_args + assert call_args.kwargs['workspace'] == 'default' + assert call_args.kwargs['doc_id'] == 'doc_abc123' + assert call_args.kwargs['content'] == b'Hello, World!' 
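+        # call_args.kwargs is the keyword-argument mapping recorded for the
+        # single upload_to_staging() call above (unittest.mock), so these
+        # asserts pin down exactly what would be written to S3 staging.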
+ + @pytest.mark.asyncio + async def test_upload_auto_generates_doc_id(self, client, mock_s3_client): + """Test that doc_id is auto-generated if not provided.""" + mock_s3_client.upload_to_staging.return_value = 'staging/default/doc_auto/test.txt' + mock_s3_client.get_s3_url.return_value = 's3://bucket/staging/default/doc_auto/test.txt' + + files = {'file': ('test.txt', b'Test content', 'text/plain')} + data = {'workspace': 'default'} + + response = await client.post('/upload', files=files, data=data) + + assert response.status_code == 200 + data = response.json() + assert data['doc_id'].startswith('doc_') + + @pytest.mark.asyncio + async def test_upload_empty_file_rejected(self, client, mock_s3_client): + """Test that empty files are rejected.""" + files = {'file': ('empty.txt', b'', 'text/plain')} + + response = await client.post('/upload', files=files) + + assert response.status_code == 400 + assert 'Empty file' in response.json()['detail'] + + @pytest.mark.asyncio + async def test_upload_returns_s3_url(self, client, mock_s3_client): + """Test that upload returns S3 URL.""" + mock_s3_client.upload_to_staging.return_value = 'staging/default/doc_xyz/file.pdf' + mock_s3_client.get_s3_url.return_value = 's3://mybucket/staging/default/doc_xyz/file.pdf' + + files = {'file': ('file.pdf', b'PDF content', 'application/pdf')} + + response = await client.post('/upload', files=files) + + assert response.status_code == 200 + assert response.json()['s3_url'] == 's3://mybucket/staging/default/doc_xyz/file.pdf' + + @pytest.mark.asyncio + async def test_upload_handles_s3_error(self, client, mock_s3_client): + """Test that S3 errors are handled.""" + mock_s3_client.upload_to_staging.side_effect = Exception('S3 connection failed') + + files = {'file': ('test.txt', b'Content', 'text/plain')} + + response = await client.post('/upload', files=files) + + assert response.status_code == 500 + assert 'S3 connection failed' in response.json()['detail'] + + +# ============================================================================ +# List Staged Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestListStagedEndpoint: + """Tests for GET /upload/staged endpoint.""" + + @pytest.mark.asyncio + async def test_list_staged_returns_documents(self, client, mock_s3_client): + """Test that list returns staged documents.""" + mock_s3_client.list_staging.return_value = [ + { + 'key': 'staging/default/doc1/file.pdf', + 'size': 1024, + 'last_modified': '2024-01-01T00:00:00Z', + }, + { + 'key': 'staging/default/doc2/report.docx', + 'size': 2048, + 'last_modified': '2024-01-02T00:00:00Z', + }, + ] + + response = await client.get('/upload/staged') + + assert response.status_code == 200 + data = response.json() + assert data['workspace'] == 'default' + assert data['count'] == 2 + assert len(data['documents']) == 2 + assert data['documents'][0]['key'] == 'staging/default/doc1/file.pdf' + + @pytest.mark.asyncio + async def test_list_staged_empty(self, client, mock_s3_client): + """Test that empty staging returns empty list.""" + mock_s3_client.list_staging.return_value = [] + + response = await client.get('/upload/staged') + + assert response.status_code == 200 + data = response.json() + assert data['count'] == 0 + assert data['documents'] == [] + + @pytest.mark.asyncio + async def test_list_staged_custom_workspace(self, client, mock_s3_client): + """Test listing documents in custom workspace.""" + mock_s3_client.list_staging.return_value = [] + + 
response = await client.get('/upload/staged', params={'workspace': 'custom'}) + + assert response.status_code == 200 + assert response.json()['workspace'] == 'custom' + mock_s3_client.list_staging.assert_called_once_with('custom') + + +# ============================================================================ +# Presigned URL Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestPresignedUrlEndpoint: + """Tests for GET /upload/presigned-url endpoint.""" + + @pytest.mark.asyncio + async def test_presigned_url_returns_url(self, client, mock_s3_client): + """Test that presigned URL is returned.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_presigned_url.return_value = 'https://s3.example.com/signed-url?token=xyz' + + response = await client.get( + '/upload/presigned-url', + params={'s3_key': 'staging/default/doc1/file.pdf'}, + ) + + assert response.status_code == 200 + data = response.json() + assert data['s3_key'] == 'staging/default/doc1/file.pdf' + assert data['presigned_url'] == 'https://s3.example.com/signed-url?token=xyz' + assert data['expiry_seconds'] == 3600 + + @pytest.mark.asyncio + async def test_presigned_url_custom_expiry(self, client, mock_s3_client): + """Test custom expiry time.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_presigned_url.return_value = 'https://signed-url' + + response = await client.get( + '/upload/presigned-url', + params={'s3_key': 'test/key', 'expiry': 7200}, + ) + + assert response.status_code == 200 + assert response.json()['expiry_seconds'] == 7200 + mock_s3_client.get_presigned_url.assert_called_once_with('test/key', expiry=7200) + + @pytest.mark.asyncio + async def test_presigned_url_not_found(self, client, mock_s3_client): + """Test 404 for non-existent object.""" + mock_s3_client.object_exists.return_value = False + + response = await client.get( + '/upload/presigned-url', + params={'s3_key': 'nonexistent/key'}, + ) + + assert response.status_code == 404 + assert 'not found' in response.json()['detail'].lower() + + +# ============================================================================ +# Delete Staged Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestDeleteStagedEndpoint: + """Tests for DELETE /upload/staged/{doc_id} endpoint.""" + + @pytest.mark.asyncio + async def test_delete_removes_document(self, client, mock_s3_client): + """Test that delete removes the document.""" + mock_s3_client.list_staging.return_value = [ + {'key': 'staging/default/doc123/file.pdf', 'size': 1024, 'last_modified': '2024-01-01'}, + ] + + response = await client.delete('/upload/staged/doc123') + + assert response.status_code == 200 + data = response.json() + assert data['status'] == 'deleted' + assert data['doc_id'] == 'doc123' + assert data['deleted_count'] == '1' + + mock_s3_client.delete_object.assert_called_once_with('staging/default/doc123/file.pdf') + + @pytest.mark.asyncio + async def test_delete_not_found(self, client, mock_s3_client): + """Test 404 when document not found.""" + mock_s3_client.list_staging.return_value = [] + + response = await client.delete('/upload/staged/nonexistent') + + assert response.status_code == 404 + assert 'not found' in response.json()['detail'].lower() + + @pytest.mark.asyncio + async def test_delete_multiple_objects(self, client, mock_s3_client): + """Test deleting document with multiple S3 objects.""" + 
mock_s3_client.list_staging.return_value = [ + {'key': 'staging/default/doc456/part1.pdf', 'size': 1024, 'last_modified': '2024-01-01'}, + {'key': 'staging/default/doc456/part2.pdf', 'size': 2048, 'last_modified': '2024-01-01'}, + ] + + response = await client.delete('/upload/staged/doc456') + + assert response.status_code == 200 + assert response.json()['deleted_count'] == '2' + assert mock_s3_client.delete_object.call_count == 2 + + +# ============================================================================ +# Process S3 Endpoint Tests +# ============================================================================ + + +@pytest.mark.offline +class TestProcessS3Endpoint: + """Tests for POST /upload/process endpoint.""" + + @pytest.mark.asyncio + async def test_process_fetches_and_archives(self, client, mock_s3_client, mock_rag): + """Test that process fetches content and archives.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_object.return_value = (b'Document content here', {'content_type': 'text/plain'}) + mock_s3_client.get_s3_url.return_value = 's3://bucket/staging/default/doc1/file.txt' + mock_s3_client.move_to_archive.return_value = 'archive/default/doc1/file.txt' + mock_rag.ainsert.return_value = 'track_123' + + response = await client.post( + '/upload/process', + json={'s3_key': 'staging/default/doc1/file.txt'}, + ) + + assert response.status_code == 200 + data = response.json() + assert data['status'] == 'processing_complete' + assert data['track_id'] == 'track_123' + assert data['archive_key'] == 'archive/default/doc1/file.txt' + + mock_rag.ainsert.assert_called_once() + mock_s3_client.move_to_archive.assert_called_once() + + @pytest.mark.asyncio + async def test_process_extracts_doc_id_from_key(self, client, mock_s3_client, mock_rag): + """Test that doc_id is extracted from s3_key.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_object.return_value = (b'Content', {'content_type': 'text/plain'}) + mock_s3_client.get_s3_url.return_value = 's3://bucket/key' + mock_s3_client.move_to_archive.return_value = 'archive/key' + mock_rag.ainsert.return_value = 'track_456' + + response = await client.post( + '/upload/process', + json={'s3_key': 'staging/workspace1/extracted_doc_id/file.txt'}, + ) + + assert response.status_code == 200 + assert response.json()['doc_id'] == 'extracted_doc_id' + + @pytest.mark.asyncio + async def test_process_not_found(self, client, mock_s3_client, mock_rag): + """Test 404 when S3 object not found.""" + mock_s3_client.object_exists.return_value = False + + response = await client.post( + '/upload/process', + json={'s3_key': 'staging/default/missing/file.txt'}, + ) + + assert response.status_code == 404 + assert 'not found' in response.json()['detail'].lower() + + @pytest.mark.asyncio + async def test_process_empty_content_rejected(self, client, mock_s3_client, mock_rag): + """Test that empty content is rejected.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_object.return_value = (b' \n ', {'content_type': 'text/plain'}) + + response = await client.post( + '/upload/process', + json={'s3_key': 'staging/default/doc/file.txt'}, + ) + + assert response.status_code == 400 + assert 'empty' in response.json()['detail'].lower() + + @pytest.mark.asyncio + async def test_process_binary_content_rejected(self, client, mock_s3_client, mock_rag): + """Test that binary content that can't be decoded is rejected.""" + mock_s3_client.object_exists.return_value = True + # Invalid UTF-8 bytes + 
mock_s3_client.get_object.return_value = (b'\x80\x81\x82\x83', {'content_type': 'application/pdf'}) + + response = await client.post( + '/upload/process', + json={'s3_key': 'staging/default/doc/file.pdf'}, + ) + + assert response.status_code == 400 + assert 'binary' in response.json()['detail'].lower() + + @pytest.mark.asyncio + async def test_process_skip_archive(self, client, mock_s3_client, mock_rag): + """Test that archiving can be skipped.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_object.return_value = (b'Content', {'content_type': 'text/plain'}) + mock_s3_client.get_s3_url.return_value = 's3://bucket/key' + mock_rag.ainsert.return_value = 'track_789' + + response = await client.post( + '/upload/process', + json={ + 's3_key': 'staging/default/doc/file.txt', + 'archive_after_processing': False, + }, + ) + + assert response.status_code == 200 + assert response.json()['archive_key'] is None + mock_s3_client.move_to_archive.assert_not_called() + + @pytest.mark.asyncio + async def test_process_uses_provided_doc_id(self, client, mock_s3_client, mock_rag): + """Test that provided doc_id is used.""" + mock_s3_client.object_exists.return_value = True + mock_s3_client.get_object.return_value = (b'Content', {'content_type': 'text/plain'}) + mock_s3_client.get_s3_url.return_value = 's3://bucket/key' + mock_s3_client.move_to_archive.return_value = 'archive/key' + mock_rag.ainsert.return_value = 'track_999' + + response = await client.post( + '/upload/process', + json={ + 's3_key': 'staging/default/doc/file.txt', + 'doc_id': 'custom_doc_id', + }, + ) + + assert response.status_code == 200 + assert response.json()['doc_id'] == 'custom_doc_id' + + # Verify RAG was called with custom doc_id + mock_rag.ainsert.assert_called_once() + call_kwargs = mock_rag.ainsert.call_args.kwargs + assert call_kwargs['ids'] == 'custom_doc_id' diff --git a/uv.lock b/uv.lock index 033094c7..55b0fa2a 100644 --- a/uv.lock +++ b/uv.lock @@ -2730,6 +2730,7 @@ pytest = [ { name = "ruff" }, ] test = [ + { name = "aioboto3" }, { name = "aiofiles" }, { name = "aiohttp" }, { name = "ascii-colors" }, @@ -2745,6 +2746,7 @@ test = [ { name = "httpx" }, { name = "jiter" }, { name = "json-repair" }, + { name = "moto", extra = ["s3"] }, { name = "nano-vectordb" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2779,6 +2781,7 @@ test = [ [package.metadata] requires-dist = [ { name = "aioboto3", marker = "extra == 'offline-llm'", specifier = ">=12.0.0,<16.0.0" }, + { name = "aioboto3", marker = "extra == 'test'", specifier = ">=12.0.0,<16.0.0" }, { name = "aiofiles", marker = "extra == 'api'" }, { name = "aiohttp" }, { name = "aiohttp", marker = "extra == 'api'" }, @@ -2802,6 +2805,7 @@ requires-dist = [ { name = "gunicorn", marker = "extra == 'api'" }, { name = "httpcore", marker = "extra == 'api'" }, { name = "httpx", marker = "extra == 'api'", specifier = ">=0.28.1" }, + { name = "httpx", marker = "extra == 'test'", specifier = ">=0.27" }, { name = "jiter", marker = "extra == 'api'" }, { name = "json-repair" }, { name = "json-repair", marker = "extra == 'api'" }, @@ -2810,6 +2814,7 @@ requires-dist = [ { name = "lightrag-hku", extras = ["api"], marker = "extra == 'test'" }, { name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" }, 
{ name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" }, + { name = "moto", extras = ["s3"], marker = "extra == 'test'", specifier = ">=5.0" }, { name = "nano-vectordb" }, { name = "nano-vectordb", marker = "extra == 'api'" }, { name = "neo4j", marker = "extra == 'offline-storage'", specifier = ">=5.0.0,<7.0.0" }, @@ -3332,6 +3337,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "moto" +version = "5.1.18" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3" }, + { name = "botocore" }, + { name = "cryptography" }, + { name = "jinja2" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "responses" }, + { name = "werkzeug" }, + { name = "xmltodict" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e3/6a/a73bef67261bfab55714390f07c7df97531d00cea730b7c0ace4d0ad7669/moto-5.1.18.tar.gz", hash = "sha256:45298ef7b88561b839f6fe3e9da2a6e2ecd10283c7bf3daf43a07a97465885f9", size = 8271655, upload-time = "2025-11-30T22:03:59.58Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/d4/6991df072b34741a0c115e8d21dc2fe142e4b497319d762e957f6677f001/moto-5.1.18-py3-none-any.whl", hash = "sha256:b65aa8fc9032c5c574415451e14fd7da4e43fd50b8bdcb5f10289ad382c25bcf", size = 6357278, upload-time = "2025-11-30T22:03:56.831Z" }, +] + +[package.optional-dependencies] +s3 = [ + { name = "py-partiql-parser" }, + { name = "pyyaml" }, +] + [[package]] name = "mpire" version = "2.10.2" @@ -4559,6 +4590,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/32/97ca2090f2f1b45b01b6aa7ae161cfe50671de097311975ca6eea3e7aabc/psutil-7.1.2-cp37-abi3-win_arm64.whl", hash = "sha256:3e988455e61c240cc879cb62a008c2699231bf3e3d061d7fce4234463fd2abb4", size = 243742, upload-time = "2025-10-25T10:47:17.302Z" }, ] +[[package]] +name = "py-partiql-parser" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/7a/a0f6bda783eb4df8e3dfd55973a1ac6d368a89178c300e1b5b91cd181e5e/py_partiql_parser-0.6.3.tar.gz", hash = "sha256:09cecf916ce6e3da2c050f0cb6106166de42c33d34a078ec2eb19377ea70389a", size = 17456, upload-time = "2025-10-18T13:56:13.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/33/a7cbfccc39056a5cf8126b7aab4c8bafbedd4f0ca68ae40ecb627a2d2cd3/py_partiql_parser-0.6.3-py2.py3-none-any.whl", hash = "sha256:deb0769c3346179d2f590dcbde556f708cdb929059fb654bad75f4cf6e07f582", size = 23752, upload-time = "2025-10-18T13:56:12.256Z" }, +] + [[package]] name = "pyarrow" version = "22.0.0" @@ -5553,6 +5593,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" }, ] +[[package]] +name = "responses" +version = "0.25.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/0e/95/89c054ad70bfef6da605338b009b2e283485835351a9935c7bfbfaca7ffc/responses-0.25.8.tar.gz", hash = "sha256:9374d047a575c8f781b94454db5cab590b6029505f488d12899ddb10a4af1cf4", size = 79320, upload-time = "2025-08-08T19:01:46.709Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/4c/cc276ce57e572c102d9542d383b2cfd551276581dc60004cb94fe8774c11/responses-0.25.8-py3-none-any.whl", hash = "sha256:0c710af92def29c8352ceadff0c3fe340ace27cf5af1bbe46fb71275bcd2831c", size = 34769, upload-time = "2025-08-08T19:01:45.018Z" }, +] + [[package]] name = "rich" version = "13.9.4" @@ -6911,6 +6965,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, +] + [[package]] name = "wrapt" version = "1.17.3" @@ -6989,6 +7055,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/0c/3662f4a66880196a590b202f0db82d919dd2f89e99a27fadef91c4a33d41/xlsxwriter-3.2.9-py3-none-any.whl", hash = "sha256:9a5db42bc5dff014806c58a20b9eae7322a134abb6fce3c92c181bfb275ec5b3", size = 175315, upload-time = "2025-09-16T00:16:20.108Z" }, ] +[[package]] +name = "xmltodict" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/aa/917ceeed4dbb80d2f04dbd0c784b7ee7bba8ae5a54837ef0e5e062cd3cfb/xmltodict-1.0.2.tar.gz", hash = "sha256:54306780b7c2175a3967cad1db92f218207e5bc1aba697d887807c0fb68b7649", size = 25725, upload-time = "2025-09-17T21:59:26.459Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/20/69a0e6058bc5ea74892d089d64dfc3a62ba78917ec5e2cfa70f7c92ba3a5/xmltodict-1.0.2-py3-none-any.whl", hash = "sha256:62d0fddb0dcbc9f642745d8bbf4d81fd17d6dfaec5a15b5c1876300aad92af0d", size = 13893, upload-time = "2025-09-17T21:59:24.859Z" }, +] + [[package]] name = "xxhash" version = "3.6.0"