LightRAG/lightrag/api/routers/ollama_api.py
clssck 082a5a8fad test(lightrag,api): add comprehensive test coverage and S3 support
Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation
Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling
Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities
Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies
Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling
2025-12-05 23:13:39 +01:00


import asyncio
import json
import re
import time
from collections.abc import AsyncIterator, Awaitable, Callable
from enum import Enum
from typing import Any, TypeVar, cast
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import OllamaServerInfos
from lightrag.constants import DEFAULT_TOP_K
from lightrag.utils import TiktokenTokenizer, logger
# Query mode selected according to the query prefix (bypass is not a LightRAG query mode)
class SearchMode(str, Enum):
naive = 'naive'
local = 'local'
global_ = 'global'
hybrid = 'hybrid'
mix = 'mix'
bypass = 'bypass'
context = 'context'
class OllamaMessage(BaseModel):
role: str
content: str
images: list[str] | None = None
class OllamaChatRequest(BaseModel):
model: str
messages: list[OllamaMessage]
stream: bool = True
options: dict[str, Any] | None = None
system: str | None = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str
prompt: str
system: str | None = None
stream: bool = False
options: dict[str, Any] | None = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
response: str
done: bool
context: list[int] | None
total_duration: int | None
load_duration: int | None
prompt_eval_count: int | None
prompt_eval_duration: int | None
eval_count: int | None
eval_duration: int | None
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: list[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: list[OllamaModel]
class OllamaRunningModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: list[str]
parameter_size: str
quantization_level: str
class OllamaRunningModel(BaseModel):
name: str
model: str
size: int
digest: str
details: OllamaRunningModelDetails
expires_at: str
size_vram: int
class OllamaPsResponse(BaseModel):
models: list[OllamaRunningModel]
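# Illustrative response shape (a sketch, not part of the original module): the
# /tags endpoint below serializes OllamaTagResponse into Ollama-compatible JSON
# roughly like the following; all field values here are made-up placeholders.
#
#   {
#     "models": [
#       {
#         "name": "lightrag:latest",
#         "model": "lightrag:latest",
#         "size": 7365960935,
#         "digest": "<digest placeholder>",
#         "modified_at": "2024-01-15T00:00:00Z",
#         "details": {
#           "parent_model": "",
#           "format": "gguf",
#           "family": "llama",
#           "families": ["llama"],
#           "parameter_size": "13B",
#           "quantization_level": "Q4_0"
#         }
#       }
#     ]
#   }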
T = TypeVar('T', bound=BaseModel)
async def parse_request_body(request: Request, model_class: type[T]) -> T:
"""
Parse request body based on Content-Type header.
Supports both application/json and application/octet-stream.
Args:
request: The FastAPI Request object
model_class: The Pydantic model class to parse the request into
Returns:
An instance of the provided model_class
"""
content_type = request.headers.get('content-type', '').lower()
try:
if content_type.startswith('application/json'):
# FastAPI already handles JSON parsing for us
body = await request.json()
elif content_type.startswith('application/octet-stream'):
# Manually parse octet-stream as JSON
body_bytes = await request.body()
body = json.loads(body_bytes.decode('utf-8'))
else:
# Try to parse as JSON for any other content type
body_bytes = await request.body()
body = json.loads(body_bytes.decode('utf-8'))
# Create an instance of the model
return model_class(**body)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail='Invalid JSON in request body') from e
except Exception as e:
raise HTTPException(status_code=400, detail=f'Error parsing request body: {e!s}') from e
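# A minimal client-side sketch (not part of this module) showing the two
# Content-Types that parse_request_body accepts; the host, port and payload
# below are assumptions for illustration only.
#
#   import json
#   import httpx
#
#   payload = {'model': 'lightrag:latest', 'prompt': 'hello', 'stream': False}
#   # application/json: the body is read with request.json()
#   httpx.post('http://localhost:9621/api/generate', json=payload)
#   # application/octet-stream: the raw body is decoded and parsed as JSON
#   httpx.post(
#       'http://localhost:9621/api/generate',
#       content=json.dumps(payload).encode('utf-8'),
#       headers={'Content-Type': 'application/octet-stream'},
#   )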
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text using tiktoken"""
tokens = TiktokenTokenizer().encode(text)
return len(tokens)
def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, str | None]:
"""Parse query prefix to determine search mode
Returns a tuple of (cleaned_query, search_mode, only_need_context, user_prompt)
Examples:
- "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
- "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.mix, False, "use mermaid format for diagrams")
- "/local query string" -> (cleaned_query, SearchMode.local, False, None)
"""
# Initialize user_prompt as None
user_prompt = None
# First check if there's a bracket format for user prompt
bracket_pattern = r'^/([a-z]*)\[(.*?)\](.*)'
bracket_match = re.match(bracket_pattern, query)
if bracket_match:
mode_prefix = bracket_match.group(1)
user_prompt = bracket_match.group(2)
remaining_query = bracket_match.group(3).lstrip()
# Reconstruct query, removing the bracket part
query = f'/{mode_prefix} {remaining_query}'.strip()
# Unified handling of mode and only_need_context determination
mode_map = {
'/local ': (SearchMode.local, False),
'/global ': (
SearchMode.global_,
False,
), # global_ is used because 'global' is a Python keyword
'/naive ': (SearchMode.naive, False),
'/hybrid ': (SearchMode.hybrid, False),
'/mix ': (SearchMode.mix, False),
'/bypass ': (SearchMode.bypass, False),
'/context': (
SearchMode.mix,
True,
),
'/localcontext': (SearchMode.local, True),
'/globalcontext': (SearchMode.global_, True),
'/hybridcontext': (SearchMode.hybrid, True),
'/naivecontext': (SearchMode.naive, True),
'/mixcontext': (SearchMode.mix, True),
}
for prefix, (mode, only_need_context) in mode_map.items():
if query.startswith(prefix):
# After removing prefix and leading spaces
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode, only_need_context, user_prompt
return query, SearchMode.mix, False, user_prompt
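# Worked examples for parse_query_mode, derived from the mode_map above
# (shown as comments so they are not executed on import):
#
#   parse_query_mode('/local what is LightRAG?')
#   -> ('what is LightRAG?', SearchMode.local, False, None)
#
#   parse_query_mode('/mixcontext what is LightRAG?')
#   -> ('what is LightRAG?', SearchMode.mix, True, None)
#
#   parse_query_mode('/global[answer in bullet points] what is LightRAG?')
#   -> ('what is LightRAG?', SearchMode.global_, False, 'answer in bullet points')
#
#   parse_query_mode('no prefix falls back to mix mode')
#   -> ('no prefix falls back to mix mode', SearchMode.mix, False, None)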
class OllamaAPI:
def __init__(self, rag: LightRAG, top_k: int = DEFAULT_TOP_K, api_key: str | None = None):
self.rag = rag
# Ensure server infos are always present for typing and runtime safety
self.ollama_server_infos: OllamaServerInfos = rag.ollama_server_infos or OllamaServerInfos()
self.top_k = top_k
self.api_key = api_key
self.router = APIRouter(tags=['ollama'])
self.setup_routes()
def setup_routes(self):
# Create combined auth dependency for Ollama API routes
combined_auth = get_combined_auth_dependency(self.api_key)
@self.router.get('/version', dependencies=[Depends(combined_auth)])
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version='0.9.3')
@self.router.get('/tags', dependencies=[Depends(combined_auth)])
async def get_tags():
"""Return available models acting as an Ollama server"""
return OllamaTagResponse(
models=[
OllamaModel(
name=self.ollama_server_infos.LIGHTRAG_MODEL,
model=self.ollama_server_infos.LIGHTRAG_MODEL,
modified_at=self.ollama_server_infos.LIGHTRAG_CREATED_AT,
size=self.ollama_server_infos.LIGHTRAG_SIZE,
digest=self.ollama_server_infos.LIGHTRAG_DIGEST,
details=OllamaModelDetails(
parent_model='',
format='gguf',
family=self.ollama_server_infos.LIGHTRAG_NAME,
families=[self.ollama_server_infos.LIGHTRAG_NAME],
parameter_size='13B',
quantization_level='Q4_0',
),
)
]
)
@self.router.get('/ps', dependencies=[Depends(combined_auth)])
async def get_running_models():
"""List Running Models - returns currently running models"""
return OllamaPsResponse(
models=[
OllamaRunningModel(
name=self.ollama_server_infos.LIGHTRAG_MODEL,
model=self.ollama_server_infos.LIGHTRAG_MODEL,
size=self.ollama_server_infos.LIGHTRAG_SIZE,
digest=self.ollama_server_infos.LIGHTRAG_DIGEST,
details=OllamaRunningModelDetails(
parent_model='',
format='gguf',
family='llama',
families=['llama'],
parameter_size='7.2B',
quantization_level='Q4_0',
),
expires_at='2050-12-31T14:38:31.83753-07:00',
size_vram=self.ollama_server_infos.LIGHTRAG_SIZE,
)
]
)
@self.router.post('/generate', dependencies=[Depends(combined_auth)], include_in_schema=True)
async def generate(raw_request: Request):
"""Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG,
and will be handled by underlying LLM model.
Supports both application/json and application/octet-stream Content-Types.
"""
try:
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaGenerateRequest)
query = request.prompt
start_time = time.time_ns()
prompt_tokens = estimate_tokens(query)
if request.system:
self.rag.llm_model_kwargs['system_prompt'] = request.system
if request.stream:
llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func)
response = cast(
str | AsyncIterator[str],
await llm_model_func(query, stream=True, **self.rag.llm_model_kwargs),
)
async def stream_generator():
first_chunk_time = None
last_chunk_time = time.time_ns()
total_response = ''
# Handle both plain-string and streaming (async iterator) responses
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = response
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': response,
'done': False,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': '',
'done': True,
'done_reason': 'stop',
'context': [],
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
else:
stream = cast(AsyncIterator[str], response)
try:
async for chunk in stream:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': chunk,
'done': False,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = 'Stream was cancelled by server'
else:
error_msg = f'Provider error: {error_msg}'
logger.error(f'Stream error: {error_msg}')
# Send error message to client
error_data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': f'\n\nError: {error_msg}',
'error': f'\n\nError: {error_msg}',
'done': False,
}
yield f'{json.dumps(error_data, ensure_ascii=False)}\n'
# Send final message to close the stream
final_data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': '',
'done': True,
}
yield f'{json.dumps(final_data, ensure_ascii=False)}\n'
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': '',
'done': True,
'done_reason': 'stop',
'context': [],
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
return
return StreamingResponse(
stream_generator(),
media_type='application/x-ndjson',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Content-Type': 'application/x-ndjson',
'X-Accel-Buffering': 'no', # Ensure proper handling of streaming responses in Nginx proxy
},
)
else:
first_chunk_time = time.time_ns()
llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func)
response_text = cast(str, await llm_model_func(query, stream=False, **self.rag.llm_model_kwargs))
last_chunk_time = time.time_ns()
if not response_text:
response_text = 'No response generated'
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'response': str(response_text),
'done': True,
'done_reason': 'stop',
'context': [],
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
except Exception as e:
logger.error(f'Ollama generate error: {e!s}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) from e
@self.router.post('/chat', dependencies=[Depends(combined_auth)], include_in_schema=True)
async def chat(raw_request: Request):
"""Process chat completion requests by acting as an Ollama model.
Routes user queries through LightRAG by selecting query mode based on query prefix.
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
Supports both application/json and application/octet-stream Content-Types.
"""
try:
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaChatRequest)
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail='No messages provided')
# Validate that the last message is from a user
if messages[-1].role != 'user':
raise HTTPException(status_code=400, detail='Last message must be from user role')
# Get the last message as query and previous messages as history
query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [{'role': msg.role, 'content': msg.content} for msg in messages[:-1]]
# Check for query prefix
cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(query)
start_time = time.time_ns()
prompt_tokens = estimate_tokens(cleaned_query)
param_dict = {
'mode': mode.value,
'stream': request.stream,
'only_need_context': only_need_context,
'conversation_history': conversation_history,
'top_k': self.top_k,
}
# Add user_prompt to param_dict
if user_prompt is not None:
param_dict['user_prompt'] = user_prompt
query_param = QueryParam(**cast(Any, param_dict))
if request.stream:
# Determine if the request is prefix with "/bypass"
if mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs['system_prompt'] = request.system
llm_model_func = cast(
Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func
)
response = cast(
str | AsyncIterator[str],
await llm_model_func(
cleaned_query,
stream=True,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
),
)
else:
aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery)
response = cast(str | AsyncIterator[str], await aquery_func(cleaned_query, param=query_param))
async def stream_generator():
first_chunk_time = None
last_chunk_time = time.time_ns()
total_response = ''
# Handle both plain-string and streaming (async iterator) responses
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = start_time
last_chunk_time = time.time_ns()
total_response = response
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': response,
'images': None,
},
'done': False,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': '',
'images': None,
},
'done_reason': 'stop',
'done': True,
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
else:
stream = cast(AsyncIterator[str], response)
try:
async for chunk in stream:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': chunk,
'images': None,
},
'done': False,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = 'Stream was cancelled by server'
else:
error_msg = f'Provider error: {error_msg}'
logger.error(f'Stream error: {error_msg}')
# Send error message to client
error_data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': f'\n\nError: {error_msg}',
'images': None,
},
'error': f'\n\nError: {error_msg}',
'done': False,
}
yield f'{json.dumps(error_data, ensure_ascii=False)}\n'
# Send final message to close the stream
final_data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': '',
'images': None,
},
'done': True,
}
yield f'{json.dumps(final_data, ensure_ascii=False)}\n'
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': '',
'images': None,
},
'done_reason': 'stop',
'done': True,
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
yield f'{json.dumps(data, ensure_ascii=False)}\n'
return StreamingResponse(
stream_generator(),
media_type='application/x-ndjson',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Content-Type': 'application/x-ndjson',
'X-Accel-Buffering': 'no', # Ensure proper handling of streaming responses in Nginx proxy
},
)
else:
first_chunk_time = time.time_ns()
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
if match_result or mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs['system_prompt'] = request.system
llm_model_func = cast(
Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func
)
response_text = cast(
str,
await llm_model_func(
cleaned_query,
stream=False,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
),
)
else:
aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery)
response_text = cast(str, await aquery_func(cleaned_query, param=query_param))
last_chunk_time = time.time_ns()
if not response_text:
response_text = 'No response generated'
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
'model': self.ollama_server_infos.LIGHTRAG_MODEL,
'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
'message': {
'role': 'assistant',
'content': str(response_text),
'images': None,
},
'done_reason': 'stop',
'done': True,
'total_duration': total_time,
'load_duration': 0,
'prompt_eval_count': prompt_tokens,
'prompt_eval_duration': prompt_eval_time,
'eval_count': completion_tokens,
'eval_duration': eval_time,
}
except Exception as e:
logger.error(f'Ollama chat error: {e!s}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) from e
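# A minimal mounting sketch (an assumption, not part of this module): the server
# entry point is expected to register this router roughly as follows, with `rag`
# being an already-initialized LightRAG instance and '/api' the assumed prefix.
#
#   from fastapi import FastAPI
#   from lightrag.api.routers.ollama_api import OllamaAPI
#
#   app = FastAPI()
#   ollama_api = OllamaAPI(rag, api_key=None)
#   app.include_router(ollama_api.router, prefix='/api')
#
# Once mounted, Ollama-compatible clients (e.g. Open WebUI) can reach
# /api/chat, /api/generate, /api/tags, /api/ps and /api/version.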