LightRAG/lightrag/llm/llama_index_impl.py

from typing import Any

import pipmaster as pm

# Install required dependencies before importing llama_index modules
if not pm.is_installed('llama-index'):
    pm.install('llama-index')

import numpy as np
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms import (
    ChatMessage,
    ChatResponse,
    MessageRole,
)
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from lightrag.exceptions import (
    APIConnectionError,
    APITimeoutError,
    RateLimitError,
)
from lightrag.utils import (
    logger,
    wrap_embedding_func_with_attrs,
)


def configure_llama_index(settings: LlamaIndexSettings | None = None, **kwargs):
    """
    Configure LlamaIndex settings.

    Args:
        settings: LlamaIndex Settings instance. If None, uses default settings.
        **kwargs: Additional settings to override/configure
    """
    if settings is None:
        settings = LlamaIndexSettings()

    # Update settings with any provided kwargs
    for key, value in kwargs.items():
        if hasattr(settings, key):
            setattr(settings, key, value)
        else:
            logger.warning(f'Unknown LlamaIndex setting: {key}')

    # Set as global settings
    LlamaIndexSettings.set_global(settings)
    return settings
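
# Illustrative usage of configure_llama_index (a commented sketch, not executed
# on import). It assumes the installed llama_index version supports the
# `set_global` call used above and that `llama_index.llms.openai` is available;
# the model name is only a placeholder:
#
#     from llama_index.llms.openai import OpenAI
#
#     configure_llama_index(llm=OpenAI(model='gpt-4o-mini'))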


def format_chat_messages(messages: list[dict[str, Any]]) -> list[ChatMessage]:
    """Format chat messages into LlamaIndex format."""
    formatted_messages = []

    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '')

        if role == 'system':
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=content))
        elif role == 'assistant':
            formatted_messages.append(ChatMessage(role=MessageRole.ASSISTANT, content=content))
        elif role == 'user':
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))
        else:
            logger.warning(f'Unknown role {role}, treating as user message')
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))

    return formatted_messages
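
# Illustrative input/output for format_chat_messages (a sketch; the message
# contents are placeholders):
#
#     msgs = [
#         {'role': 'system', 'content': 'You are a helpful assistant.'},
#         {'role': 'user', 'content': 'Hello!'},
#     ]
#     chat_messages = format_chat_messages(msgs)
#     # -> [ChatMessage(role=SYSTEM, ...), ChatMessage(role=USER, ...)]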


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_complete_if_cache(
    model: Any,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    chat_kwargs: dict[str, Any] | None = None,
) -> str:
    """Complete the prompt using LlamaIndex."""
    if chat_kwargs is None:
        chat_kwargs = {}
    if history_messages is None:
        history_messages = []
    if enable_cot:
        logger.debug('enable_cot=True is not supported by the LlamaIndex implementation and will be ignored.')

    try:
        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=system_prompt))

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT,
                    content=msg['content'],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(messages=formatted_messages, **chat_kwargs)

        # In newer LlamaIndex versions, the response text is in message.content
        content = response.message.content
        return content
    except Exception as e:
        logger.error(f'Error in llama_index_complete_if_cache: {e!s}')
        raise
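
# Illustrative direct call of llama_index_complete_if_cache (a commented
# sketch; assumes the `llama_index.llms.openai` integration is installed and
# an API key is configured -- model name and prompts are placeholders):
#
#     import asyncio
#     from llama_index.llms.openai import OpenAI
#
#     async def _demo():
#         llm = OpenAI(model='gpt-4o-mini')
#         answer = await llama_index_complete_if_cache(
#             llm,
#             'What is LightRAG?',
#             system_prompt='Answer briefly.',
#             history_messages=[{'role': 'user', 'content': 'Hi'}],
#         )
#         print(answer)
#
#     asyncio.run(_demo())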


async def llama_index_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    keyword_extraction: bool = False,
    settings: LlamaIndexSettings | None = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex.

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        enable_cot: Whether to enable chain-of-thought (ignored by this implementation)
        keyword_extraction: Whether to extract keywords from the response
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments; the LlamaIndex LLM is expected as 'llm_instance'
    """
    if history_messages is None:
        history_messages = []
    kwargs.pop('keyword_extraction', None)
    # Remove the LLM instance from kwargs so it is not passed twice below
    llm_instance = kwargs.pop('llm_instance', None)
    result = await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        **kwargs,
    )
    return result
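
# Illustrative LightRAG-style call of llama_index_complete, where the LLM is
# supplied through kwargs as 'llm_instance' (a sketch; the model name is a
# placeholder):
#
#     response = await llama_index_complete(
#         'Summarize the document.',
#         system_prompt='You are a concise summarizer.',
#         llm_instance=OpenAI(model='gpt-4o-mini'),
#     )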


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding | None = None,
    settings: LlamaIndexSettings | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex.

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if settings:
        configure_llama_index(settings)

    if embed_model is None:
        raise ValueError('embed_model must be provided')

    # Use _get_text_embeddings for batch processing
    embeddings = embed_model._get_text_embeddings(texts)
    return np.array(embeddings)
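
# Illustrative call of llama_index_embed (a commented sketch; assumes the
# `llama_index.embeddings.openai` integration is installed, an API key is
# configured, and the chosen model matches the embedding_dim declared above):
#
#     from llama_index.embeddings.openai import OpenAIEmbedding
#
#     vectors = await llama_index_embed(
#         ['first passage', 'second passage'],
#         embed_model=OpenAIEmbedding(model='text-embedding-3-small'),
#     )
#     # vectors.shape == (2, 1536)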