Add extensive test suites for API routes and utilities: - Implement test_search_routes.py (406 lines) for search endpoint validation - Implement test_upload_routes.py (724 lines) for document upload workflows - Implement test_s3_client.py (618 lines) for S3 storage operations - Implement test_citation_utils.py (352 lines) for citation extraction - Implement test_chunking.py (216 lines) for text chunking validation Add S3 storage client implementation: - Create lightrag/storage/s3_client.py with S3 operations - Add storage module initialization with exports - Integrate S3 client with document upload handling Enhance API routes and core functionality: - Add search_routes.py with full-text and graph search endpoints - Add upload_routes.py with multipart document upload support - Update operate.py with bulk operations and health checks - Enhance postgres_impl.py with bulk upsert and parameterized queries - Update lightrag_server.py to register new API routes - Improve utils.py with citation and formatting utilities Update dependencies and configuration: - Add S3 and test dependencies to pyproject.toml - Update docker-compose.test.yml for testing environment - Sync uv.lock with new dependencies Apply code quality improvements across all modified files: - Add type hints to function signatures - Update imports and router initialization - Fix logging and error handling
307 lines
12 KiB
Python
307 lines
12 KiB
Python
import asyncio
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Any, final
|
|
|
|
import numpy as np
|
|
from chromadb import HttpClient, PersistentClient # type: ignore
|
|
from chromadb.config import Settings # type: ignore
|
|
|
|
from lightrag.base import BaseVectorStorage
|
|
from lightrag.utils import logger
|
|
|
|
|
|
@final
|
|
@dataclass
|
|
class ChromaVectorDBStorage(BaseVectorStorage):
|
|
"""ChromaDB vector storage implementation."""
|
|
|
|
def __post_init__(self):
|
|
try:
|
|
config = self.global_config.get('vector_db_storage_cls_kwargs', {})
|
|
cosine_threshold = config.get('cosine_better_than_threshold')
|
|
if cosine_threshold is None:
|
|
raise ValueError('cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs')
|
|
self.cosine_better_than_threshold = cosine_threshold
|
|
|
|
user_collection_settings = config.get('collection_settings', {})
|
|
# Default HNSW index settings for ChromaDB
|
|
default_collection_settings = {
|
|
# Distance metric used for similarity search (cosine similarity)
|
|
'hnsw:space': 'cosine',
|
|
# Number of nearest neighbors to explore during index construction
|
|
# Higher values = better recall but slower indexing
|
|
'hnsw:construction_ef': 128,
|
|
# Number of nearest neighbors to explore during search
|
|
# Higher values = better recall but slower search
|
|
'hnsw:search_ef': 128,
|
|
# Number of connections per node in the HNSW graph
|
|
# Higher values = better recall but more memory usage
|
|
'hnsw:M': 16,
|
|
# Number of vectors to process in one batch during indexing
|
|
'hnsw:batch_size': 100,
|
|
# Number of updates before forcing index synchronization
|
|
# Lower values = more frequent syncs but slower indexing
|
|
'hnsw:sync_threshold': 1000,
|
|
}
|
|
collection_settings = {
|
|
**default_collection_settings,
|
|
**user_collection_settings,
|
|
}
|
|
|
|
local_path = config.get('local_path', None)
|
|
if local_path:
|
|
self._client = PersistentClient(
|
|
path=local_path,
|
|
settings=Settings(
|
|
allow_reset=True,
|
|
anonymized_telemetry=False,
|
|
),
|
|
)
|
|
else:
|
|
auth_provider = config.get('auth_provider', 'chromadb.auth.token_authn.TokenAuthClientProvider')
|
|
auth_credentials = config.get('auth_token', 'secret-token')
|
|
headers = {}
|
|
|
|
if 'token_authn' in auth_provider:
|
|
headers = {config.get('auth_header_name', 'X-Chroma-Token'): auth_credentials}
|
|
elif 'basic_authn' in auth_provider:
|
|
auth_credentials = config.get('auth_credentials', 'admin:admin')
|
|
|
|
self._client = HttpClient(
|
|
host=config.get('host', 'localhost'),
|
|
port=config.get('port', 8000),
|
|
headers=headers,
|
|
settings=Settings(
|
|
chroma_api_impl='rest',
|
|
chroma_client_auth_provider=auth_provider,
|
|
chroma_client_auth_credentials=auth_credentials,
|
|
allow_reset=True,
|
|
anonymized_telemetry=False,
|
|
),
|
|
)
|
|
|
|
self._collection = self._client.get_or_create_collection(
|
|
name=self.namespace,
|
|
metadata={
|
|
**collection_settings,
|
|
'dimension': self.embedding_func.embedding_dim,
|
|
},
|
|
)
|
|
# Use batch size from collection settings if specified
|
|
self._max_batch_size = self.global_config.get(
|
|
'embedding_batch_num', collection_settings.get('hnsw:batch_size', 32)
|
|
)
|
|
except Exception as e:
|
|
logger.error(f'ChromaDB initialization failed: {e!s}')
|
|
raise
|
|
|
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
|
logger.debug(f'Inserting {len(data)} to {self.namespace}')
|
|
if not data:
|
|
return
|
|
|
|
try:
|
|
import time
|
|
|
|
current_time = int(time.time())
|
|
|
|
ids = list(data.keys())
|
|
documents = [v['content'] for v in data.values()]
|
|
metadatas = [
|
|
{**{k: v for k, v in item.items() if k in self.meta_fields}, 'created_at': current_time}
|
|
for item in data.values()
|
|
]
|
|
|
|
# Process in batches
|
|
batches = [documents[i : i + self._max_batch_size] for i in range(0, len(documents), self._max_batch_size)]
|
|
|
|
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
|
embeddings_list = []
|
|
|
|
# Pre-allocate embeddings_list with known size
|
|
embeddings_list = [None] * len(embedding_tasks)
|
|
|
|
# Use asyncio.gather instead of as_completed if order doesn't matter
|
|
embeddings_results = await asyncio.gather(*embedding_tasks)
|
|
embeddings_list = list(embeddings_results)
|
|
|
|
embeddings = np.concatenate(embeddings_list)
|
|
|
|
# Upsert in batches
|
|
for i in range(0, len(ids), self._max_batch_size):
|
|
batch_slice = slice(i, i + self._max_batch_size)
|
|
|
|
self._collection.upsert(
|
|
ids=ids[batch_slice],
|
|
embeddings=embeddings[batch_slice].tolist(),
|
|
documents=documents[batch_slice],
|
|
metadatas=metadatas[batch_slice],
|
|
)
|
|
|
|
return ids
|
|
|
|
except Exception as e:
|
|
logger.error(f'Error during ChromaDB upsert: {e!s}')
|
|
raise
|
|
|
|
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
|
try:
|
|
embedding = await self.embedding_func([query], _priority=5) # higher priority for query
|
|
|
|
results = self._collection.query(
|
|
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
|
|
n_results=top_k * 2, # Request more results to allow for filtering
|
|
include=['metadatas', 'distances', 'documents'],
|
|
)
|
|
|
|
# Filter results by cosine similarity threshold and take top k
|
|
# We request 2x results initially to have enough after filtering
|
|
# ChromaDB returns cosine similarity (1 = identical, 0 = orthogonal)
|
|
# We convert to distance (0 = identical, 1 = orthogonal) via (1 - similarity)
|
|
# Only keep results with distance below threshold, then take top k
|
|
return [
|
|
{
|
|
'id': results['ids'][0][i],
|
|
'distance': 1 - results['distances'][0][i],
|
|
'content': results['documents'][0][i],
|
|
'created_at': results['metadatas'][0][i].get('created_at'),
|
|
**results['metadatas'][0][i],
|
|
}
|
|
for i in range(len(results['ids'][0]))
|
|
if (1 - results['distances'][0][i]) >= self.cosine_better_than_threshold
|
|
][:top_k]
|
|
|
|
except Exception as e:
|
|
logger.error(f'Error during ChromaDB query: {e!s}')
|
|
raise
|
|
|
|
async def index_done_callback(self) -> None:
|
|
# ChromaDB handles persistence automatically
|
|
pass
|
|
|
|
async def delete_entity(self, entity_name: str) -> None:
|
|
"""Delete an entity by its ID.
|
|
|
|
Args:
|
|
entity_name: The ID of the entity to delete
|
|
"""
|
|
try:
|
|
logger.info(f'Deleting entity with ID {entity_name} from {self.namespace}')
|
|
self._collection.delete(ids=[entity_name])
|
|
except Exception as e:
|
|
logger.error(f'Error during entity deletion: {e!s}')
|
|
raise
|
|
|
|
async def delete_entity_relation(self, entity_name: str) -> None:
|
|
"""Delete an entity and its relations by ID.
|
|
In vector DB context, this is equivalent to delete_entity.
|
|
|
|
Args:
|
|
entity_name: The ID of the entity to delete
|
|
"""
|
|
await self.delete_entity(entity_name)
|
|
|
|
async def delete(self, ids: list[str]) -> None:
|
|
"""Delete vectors with specified IDs
|
|
|
|
Args:
|
|
ids: List of vector IDs to be deleted
|
|
"""
|
|
try:
|
|
self._collection.delete(ids=ids)
|
|
logger.debug(f'Successfully deleted {len(ids)} vectors from {self.namespace}')
|
|
except Exception as e:
|
|
logger.error(f'Error while deleting vectors from {self.namespace}: {e!s}')
|
|
raise
|
|
|
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
|
"""Get vector data by its ID
|
|
|
|
Args:
|
|
id: The unique identifier of the vector
|
|
|
|
Returns:
|
|
The vector data if found, or None if not found
|
|
"""
|
|
try:
|
|
# Query the collection for a single vector by ID
|
|
result = self._collection.get(ids=[id], include=['metadatas', 'embeddings', 'documents'])
|
|
|
|
if not result or not result['ids'] or len(result['ids']) == 0:
|
|
return None
|
|
|
|
# Format the result to match the expected structure
|
|
return {
|
|
'id': result['ids'][0],
|
|
'vector': result['embeddings'][0],
|
|
'content': result['documents'][0],
|
|
'created_at': result['metadatas'][0].get('created_at'),
|
|
**result['metadatas'][0],
|
|
}
|
|
except Exception as e:
|
|
logger.error(f'Error retrieving vector data for ID {id}: {e!s}')
|
|
return None
|
|
|
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
|
"""Get multiple vector data by their IDs
|
|
|
|
Args:
|
|
ids: List of unique identifiers
|
|
|
|
Returns:
|
|
List of vector data objects that were found
|
|
"""
|
|
if not ids:
|
|
return []
|
|
|
|
try:
|
|
# Query the collection for multiple vectors by IDs
|
|
result = self._collection.get(ids=ids, include=['metadatas', 'embeddings', 'documents'])
|
|
|
|
if not result or not result['ids'] or len(result['ids']) == 0:
|
|
return []
|
|
|
|
# Format the results to match the expected structure and preserve ordering
|
|
formatted_map: dict[str, dict[str, Any]] = {}
|
|
for i, result_id in enumerate(result['ids']):
|
|
record = {
|
|
'id': result_id,
|
|
'vector': result['embeddings'][i],
|
|
'content': result['documents'][i],
|
|
'created_at': result['metadatas'][i].get('created_at'),
|
|
**result['metadatas'][i],
|
|
}
|
|
formatted_map[str(result_id)] = record
|
|
|
|
ordered_results: list[dict[str, Any] | None] = []
|
|
for requested_id in ids:
|
|
ordered_results.append(formatted_map.get(str(requested_id)))
|
|
|
|
return ordered_results
|
|
except Exception as e:
|
|
logger.error(f'Error retrieving vector data for IDs {ids}: {e}')
|
|
return []
|
|
|
|
async def drop(self) -> dict[str, str]:
|
|
"""Drop all vector data from storage and clean up resources
|
|
|
|
This method will delete all documents from the ChromaDB collection.
|
|
|
|
Returns:
|
|
dict[str, str]: Operation status and message
|
|
- On success: {"status": "success", "message": "data dropped"}
|
|
- On failure: {"status": "error", "message": "<error details>"}
|
|
"""
|
|
try:
|
|
# Get all IDs in the collection
|
|
result = self._collection.get(include=[])
|
|
if result and result['ids'] and len(result['ids']) > 0:
|
|
# Delete all documents
|
|
self._collection.delete(ids=result['ids'])
|
|
|
|
logger.info(f'Process {os.getpid()} drop ChromaDB collection {self.namespace}')
|
|
return {'status': 'success', 'message': 'data dropped'}
|
|
except Exception as e:
|
|
logger.error(f'Error dropping ChromaDB collection {self.namespace}: {e}')
|
|
return {'status': 'error', 'message': str(e)}
|