Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation

Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling (a minimal hypothetical sketch follows below)

Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities

Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies

Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling
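The new lightrag/storage/s3_client.py is not included in this view, so the following is only a minimal sketch of what an S3-backed upload helper could look like, assuming boto3 is the S3 dependency added to pyproject.toml. The class name, bucket handling, and the put_document/get_document helpers are hypothetical illustrations, not the actual LightRAG API.

# Hypothetical sketch only -- the real lightrag/storage/s3_client.py is not shown here.
import boto3


class S3Client:
    """Minimal S3 wrapper, assuming boto3 as the underlying dependency."""

    def __init__(self, bucket: str, endpoint_url: str | None = None):
        self._bucket = bucket
        # endpoint_url allows pointing at a local S3 stand-in during testing
        self._s3 = boto3.client('s3', endpoint_url=endpoint_url)

    def put_document(self, key: str, data: bytes, content_type: str = 'application/octet-stream') -> None:
        # Upload raw document bytes under the given key
        self._s3.put_object(Bucket=self._bucket, Key=key, Body=data, ContentType=content_type)

    def get_document(self, key: str) -> bytes:
        # Fetch the stored object and return its payload
        response = self._s3.get_object(Bucket=self._bucket, Key=key)
        return response['Body'].read()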
195 lines · 5.6 KiB · Python
from typing import Any

import pipmaster as pm

# Install required dependencies before importing llama_index
if not pm.is_installed('llama-index'):
    pm.install('llama-index')

import numpy as np
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms import (
    ChatMessage,
    ChatResponse,
    MessageRole,
)
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from lightrag.exceptions import (
    APIConnectionError,
    APITimeoutError,
    RateLimitError,
)
from lightrag.utils import (
    logger,
    wrap_embedding_func_with_attrs,
)

def configure_llama_index(settings: LlamaIndexSettings | None = None, **kwargs):
    """
    Configure LlamaIndex settings.

    Args:
        settings: LlamaIndex Settings instance. If None, uses default settings.
        **kwargs: Additional settings to override/configure
    """
    if settings is None:
        settings = LlamaIndexSettings()

    # Update settings with any provided kwargs
    for key, value in kwargs.items():
        if hasattr(settings, key):
            setattr(settings, key, value)
        else:
            logger.warning(f'Unknown LlamaIndex setting: {key}')

    # Set as global settings
    LlamaIndexSettings.set_global(settings)
    return settings

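# --- Usage sketch (comments only; `my_settings` is a placeholder for a
# --- pre-built LlamaIndex Settings object, `chunk_size` an example attribute) --
#
#   configure_llama_index(my_settings, chunk_size=512)
#
# Any kwarg that matches an attribute on the Settings object is applied;
# unknown keys are logged with a warning and skipped.
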
def format_chat_messages(messages: list[dict]) -> list[ChatMessage]:
    """Format chat messages into LlamaIndex format."""
    formatted_messages = []

    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '')

        if role == 'system':
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=content))
        elif role == 'assistant':
            formatted_messages.append(ChatMessage(role=MessageRole.ASSISTANT, content=content))
        elif role == 'user':
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))
        else:
            logger.warning(f'Unknown role {role}, treating as user message')
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))

    return formatted_messages

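# --- Usage sketch (comments only) ---------------------------------------------
# format_chat_messages expects OpenAI-style role/content dicts and returns
# LlamaIndex ChatMessage objects, e.g.:
#
#   msgs = format_chat_messages([
#       {'role': 'system', 'content': 'You are a helpful assistant.'},
#       {'role': 'user', 'content': 'Hello!'},
#   ])
#   # -> [ChatMessage(role=SYSTEM, ...), ChatMessage(role=USER, ...)]
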
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_complete_if_cache(
    model: Any,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict] | None = None,
    enable_cot: bool = False,
    chat_kwargs: dict | None = None,
) -> str:
    """Complete the prompt using LlamaIndex."""
    if chat_kwargs is None:
        chat_kwargs = {}
    if history_messages is None:
        history_messages = []
    if enable_cot:
        logger.debug('enable_cot=True is not supported for LlamaIndex implementation and will be ignored.')
    try:
        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=system_prompt))

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT,
                    content=msg['content'],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(messages=formatted_messages, **chat_kwargs)

        # In newer versions, the response is in message.content
        content = response.message.content
        return content

    except Exception as e:
        logger.error(f'Error in llama_index_complete_if_cache: {e!s}')
        raise

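# --- Usage sketch (comments only; the OpenAI LLM class is an assumption and
# --- requires the separate llama-index-llms-openai package) --------------------
#
#   from llama_index.llms.openai import OpenAI
#
#   llm = OpenAI(model='gpt-4o-mini')
#   answer = await llama_index_complete_if_cache(
#       llm,
#       'Summarize LightRAG in one sentence.',
#       system_prompt='Be concise.',
#       history_messages=[{'role': 'user', 'content': 'Hi'},
#                         {'role': 'assistant', 'content': 'Hello!'}],
#   )
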
async def llama_index_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict] | None = None,
    enable_cot: bool = False,
    keyword_extraction: bool = False,
    settings: LlamaIndexSettings | None = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex.

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        enable_cot: Ignored for the LlamaIndex backend
        keyword_extraction: Whether to extract keywords from response
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments (must include the LlamaIndex LLM as 'llm_instance')
    """
    if history_messages is None:
        history_messages = []

    kwargs.pop('keyword_extraction', None)
    # Pop the LLM instance so it is not passed twice (positionally and via **kwargs)
    llm_instance = kwargs.pop('llm_instance', None)
    result = await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        **kwargs,
    )
    return result

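# --- Usage sketch (comments only) ---------------------------------------------
# The LlamaIndex LLM is passed through kwargs as 'llm_instance'; `my_llm` below
# stands for any LlamaIndex LLM object that supports .achat().
#
#   text = await llama_index_complete(
#       'What is a knowledge graph?',
#       system_prompt='Answer briefly.',
#       llm_instance=my_llm,
#   )
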
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding | None = None,
    settings: LlamaIndexSettings | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex.

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if settings:
        configure_llama_index(settings)

    if embed_model is None:
        raise ValueError('embed_model must be provided')

    # Use _get_text_embeddings for batch processing
    embeddings = embed_model._get_text_embeddings(texts)
    return np.array(embeddings)
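# --- Usage sketch (comments only; the OpenAIEmbedding class is an assumption and
# --- requires the separate llama-index-embeddings-openai package) --------------
#
#   from llama_index.embeddings.openai import OpenAIEmbedding
#
#   embed_model = OpenAIEmbedding(model='text-embedding-3-small')
#   vectors = await llama_index_embed(['hello world'], embed_model=embed_model)
#   # vectors is a numpy array of shape (1, embedding_dim)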