LightRAG/lightrag/llm/llama_index_impl.py
clssck 59e89772de refactor: consolidate to PostgreSQL-only backend and modernize stack
Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs
Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes
Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples
Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios
Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
2025-12-12 16:28:49 +01:00


from __future__ import annotations

from typing import Any

import pipmaster as pm
from llama_index.core.llms import (
    ChatMessage,
    ChatResponse,
    MessageRole,
)

from lightrag.utils import logger

# Install required dependencies
if not pm.is_installed('llama-index'):
    pm.install('llama-index')

import numpy as np
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from lightrag.exceptions import (
    APIConnectionError,
    APITimeoutError,
    RateLimitError,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
)


def configure_llama_index(**kwargs) -> None:
    """
    Configure LlamaIndex global settings.

    Args:
        **kwargs: Settings to configure on the global Settings singleton.
            Common settings: llm, embed_model, chunk_size, chunk_overlap
    """
    # LlamaIndexSettings is a singleton - configure it directly
    for key, value in kwargs.items():
        if hasattr(LlamaIndexSettings, key):
            setattr(LlamaIndexSettings, key, value)
        else:
            logger.warning(f'Unknown LlamaIndex setting: {key}')
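

# Illustrative sketch (not part of the upstream module): one way to apply the
# helper above. The OpenAI classes and model names are assumptions -- any
# LlamaIndex LLM / embedding model works; chunk_size and chunk_overlap are the
# common settings named in the docstring.
def _example_configure_llama_index() -> None:
    """Hedged usage sketch for configure_llama_index()."""
    from llama_index.embeddings.openai import OpenAIEmbedding  # assumed provider package
    from llama_index.llms.openai import OpenAI  # assumed provider package

    configure_llama_index(
        llm=OpenAI(model='gpt-4o-mini'),  # hypothetical model choice
        embed_model=OpenAIEmbedding(model='text-embedding-3-small'),  # hypothetical model choice
        chunk_size=1024,
        chunk_overlap=50,
    )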


def format_chat_messages(messages):
    """Format chat messages into LlamaIndex format."""
    formatted_messages = []
    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '')

        if role == 'system':
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=content))
        elif role == 'assistant':
            formatted_messages.append(ChatMessage(role=MessageRole.ASSISTANT, content=content))
        elif role == 'user':
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))
        else:
            logger.warning(f'Unknown role {role}, treating as user message')
            formatted_messages.append(ChatMessage(role=MessageRole.USER, content=content))

    return formatted_messages
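

# Illustrative sketch (not part of the upstream module): format_chat_messages()
# maps plain role/content dicts onto LlamaIndex ChatMessage objects; unknown
# roles fall back to USER with a warning.
def _example_format_chat_messages() -> list[ChatMessage]:
    """Hedged usage sketch for format_chat_messages()."""
    history = [
        {'role': 'system', 'content': 'You are a concise assistant.'},
        {'role': 'user', 'content': 'Which storage backend does LightRAG use?'},
        {'role': 'assistant', 'content': 'A PostgreSQL-only backend.'},
    ]
    return format_chat_messages(history)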


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_complete_if_cache(
    model: Any,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict] | None = None,
    enable_cot: bool = False,
    chat_kwargs=None,
) -> str:
    """Complete the prompt using LlamaIndex."""
    if chat_kwargs is None:
        chat_kwargs = {}
    if history_messages is None:
        history_messages = []

    if enable_cot:
        logger.debug('enable_cot=True is not supported for LlamaIndex implementation and will be ignored.')

    try:
        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(ChatMessage(role=MessageRole.SYSTEM, content=system_prompt))

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER if msg['role'] == 'user' else MessageRole.ASSISTANT,
                    content=msg['content'],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(messages=formatted_messages, **chat_kwargs)

        # In newer versions, the response is in message.content
        content = response.message.content
        return content or ''
    except Exception as e:
        logger.error(f'Error in llama_index_complete_if_cache: {e!s}')
        raise
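

# Illustrative sketch (not part of the upstream module): driving the low-level
# completion helper directly with a LlamaIndex LLM instance. The OpenAI import
# and model name are assumptions -- any LlamaIndex LLM exposing `achat` should
# work; chat_kwargs is forwarded verbatim to achat().
async def _example_complete_if_cache() -> str:
    """Hedged usage sketch for llama_index_complete_if_cache()."""
    from llama_index.llms.openai import OpenAI  # assumed provider package

    llm = OpenAI(model='gpt-4o-mini')  # hypothetical model choice
    return await llama_index_complete_if_cache(
        llm,
        'Summarize the document in one sentence.',
        system_prompt='You are a concise assistant.',
        history_messages=[{'role': 'user', 'content': 'Keep answers short.'}],
        chat_kwargs={'temperature': 0.0},  # assumption: provider accepts temperature
    )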


async def llama_index_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    keyword_extraction=False,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex.

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        enable_cot: Ignored by the LlamaIndex implementation
        keyword_extraction: Whether to extract keywords from the response
        **kwargs: Additional arguments; the LlamaIndex LLM is expected under
            the `llm_instance` key
    """
    if history_messages is None:
        history_messages = []

    # Strip arguments that llama_index_complete_if_cache does not accept before
    # forwarding the remaining kwargs.
    kwargs.pop('keyword_extraction', None)
    llm_instance = kwargs.pop('llm_instance', None)

    result = await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        enable_cot=enable_cot,
        **kwargs,
    )
    return result
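

# Illustrative sketch (not part of the upstream module): the higher-level entry
# point expects the LLM instance under the `llm_instance` keyword, as retrieved
# from **kwargs above. The OpenAI import and model name are assumptions.
async def _example_complete() -> str:
    """Hedged usage sketch for llama_index_complete()."""
    from llama_index.llms.openai import OpenAI  # assumed provider package

    return await llama_index_complete(
        'List three benefits of full-text search caching.',
        system_prompt='Answer as a bulleted list.',
        llm_instance=OpenAI(model='gpt-4o-mini'),  # hypothetical model choice
    )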


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex.

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        **kwargs: Additional arguments (currently unused)

    Returns:
        numpy array of shape (len(texts), embedding_dim)
    """
    if embed_model is None:
        raise ValueError('embed_model must be provided')

    # Use the (synchronous) _get_text_embeddings for batch processing
    embeddings = embed_model._get_text_embeddings(texts)
    return np.array(embeddings)
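

# Illustrative sketch (not part of the upstream module): calling the embedding
# wrapper with a concrete LlamaIndex embedding model. The OpenAIEmbedding import
# and model name are assumptions; any BaseEmbedding subclass should work.
async def _example_embed() -> np.ndarray:
    """Hedged usage sketch for llama_index_embed()."""
    from llama_index.embeddings.openai import OpenAIEmbedding  # assumed provider package

    embed_model = OpenAIEmbedding(model='text-embedding-3-small')  # hypothetical model choice
    return await llama_index_embed(['hello world', 'lightrag embeddings'], embed_model=embed_model)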