import asyncio import json import re import time from collections.abc import AsyncIterator, Awaitable, Callable from enum import Enum from typing import Any, TypeVar, cast from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel from lightrag import LightRAG, QueryParam from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.base import OllamaServerInfos from lightrag.constants import DEFAULT_TOP_K from lightrag.utils import TiktokenTokenizer, logger # query mode according to query prefix (bypass is not LightRAG quer mode) class SearchMode(str, Enum): naive = 'naive' local = 'local' global_ = 'global' hybrid = 'hybrid' mix = 'mix' bypass = 'bypass' context = 'context' class OllamaMessage(BaseModel): role: str content: str images: list[str] | None = None class OllamaChatRequest(BaseModel): model: str messages: list[OllamaMessage] stream: bool = True options: dict[str, Any] | None = None system: str | None = None class OllamaChatResponse(BaseModel): model: str created_at: str message: OllamaMessage done: bool class OllamaGenerateRequest(BaseModel): model: str prompt: str system: str | None = None stream: bool = False options: dict[str, Any] | None = None class OllamaGenerateResponse(BaseModel): model: str created_at: str response: str done: bool context: list[int] | None total_duration: int | None load_duration: int | None prompt_eval_count: int | None prompt_eval_duration: int | None eval_count: int | None eval_duration: int | None class OllamaVersionResponse(BaseModel): version: str class OllamaModelDetails(BaseModel): parent_model: str format: str family: str families: list[str] parameter_size: str quantization_level: str class OllamaModel(BaseModel): name: str model: str size: int digest: str modified_at: str details: OllamaModelDetails class OllamaTagResponse(BaseModel): models: list[OllamaModel] class OllamaRunningModelDetails(BaseModel): parent_model: str format: str family: str families: list[str] parameter_size: str quantization_level: str class OllamaRunningModel(BaseModel): name: str model: str size: int digest: str details: OllamaRunningModelDetails expires_at: str size_vram: int class OllamaPsResponse(BaseModel): models: list[OllamaRunningModel] T = TypeVar('T', bound=BaseModel) async def parse_request_body(request: Request, model_class: type[T]) -> T: """ Parse request body based on Content-Type header. Supports both application/json and application/octet-stream. Args: request: The FastAPI Request object model_class: The Pydantic model class to parse the request into Returns: An instance of the provided model_class """ content_type = request.headers.get('content-type', '').lower() try: if content_type.startswith('application/json'): # FastAPI already handles JSON parsing for us body = await request.json() elif content_type.startswith('application/octet-stream'): # Manually parse octet-stream as JSON body_bytes = await request.body() body = json.loads(body_bytes.decode('utf-8')) else: # Try to parse as JSON for any other content type body_bytes = await request.body() body = json.loads(body_bytes.decode('utf-8')) # Create an instance of the model return model_class(**body) except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail='Invalid JSON in request body') from e except Exception as e: raise HTTPException(status_code=400, detail=f'Error parsing request body: {e!s}') from e def estimate_tokens(text: str) -> int: """Estimate the number of tokens in text using tiktoken""" tokens = TiktokenTokenizer().encode(text) return len(tokens) def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, str | None]: """Parse query prefix to determine search mode Returns tuple of (cleaned_query, search_mode, only_need_context, user_prompt) Examples: - "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams") - "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.hybrid, False, "use mermaid format for diagrams") - "/local query string" -> (cleaned_query, SearchMode.local, False, None) """ # Initialize user_prompt as None user_prompt = None # First check if there's a bracket format for user prompt bracket_pattern = r'^/([a-z]*)\[(.*?)\](.*)' bracket_match = re.match(bracket_pattern, query) if bracket_match: mode_prefix = bracket_match.group(1) user_prompt = bracket_match.group(2) remaining_query = bracket_match.group(3).lstrip() # Reconstruct query, removing the bracket part query = f'/{mode_prefix} {remaining_query}'.strip() # Unified handling of mode and only_need_context determination mode_map = { '/local ': (SearchMode.local, False), '/global ': ( SearchMode.global_, False, ), # global_ is used because 'global' is a Python keyword '/naive ': (SearchMode.naive, False), '/hybrid ': (SearchMode.hybrid, False), '/mix ': (SearchMode.mix, False), '/bypass ': (SearchMode.bypass, False), '/context': ( SearchMode.mix, True, ), '/localcontext': (SearchMode.local, True), '/globalcontext': (SearchMode.global_, True), '/hybridcontext': (SearchMode.hybrid, True), '/naivecontext': (SearchMode.naive, True), '/mixcontext': (SearchMode.mix, True), } for prefix, (mode, only_need_context) in mode_map.items(): if query.startswith(prefix): # After removing prefix and leading spaces cleaned_query = query[len(prefix) :].lstrip() return cleaned_query, mode, only_need_context, user_prompt return query, SearchMode.mix, False, user_prompt class OllamaAPI: def __init__(self, rag: LightRAG, top_k: int = DEFAULT_TOP_K, api_key: str | None = None): self.rag = rag # Ensure server infos are always present for typing and runtime safety self.ollama_server_infos: OllamaServerInfos = rag.ollama_server_infos or OllamaServerInfos() self.top_k = top_k self.api_key = api_key self.router = APIRouter(tags=['ollama']) self.setup_routes() def setup_routes(self): # Create combined auth dependency for Ollama API routes combined_auth = get_combined_auth_dependency(self.api_key) @self.router.get('/version', dependencies=[Depends(combined_auth)]) async def get_version(): """Get Ollama version information""" return OllamaVersionResponse(version='0.9.3') @self.router.get('/tags', dependencies=[Depends(combined_auth)]) async def get_tags(): """Return available models acting as an Ollama server""" return OllamaTagResponse( models=[ OllamaModel( name=self.ollama_server_infos.LIGHTRAG_MODEL, model=self.ollama_server_infos.LIGHTRAG_MODEL, modified_at=self.ollama_server_infos.LIGHTRAG_CREATED_AT, size=self.ollama_server_infos.LIGHTRAG_SIZE, digest=self.ollama_server_infos.LIGHTRAG_DIGEST, details=OllamaModelDetails( parent_model='', format='gguf', family=self.ollama_server_infos.LIGHTRAG_NAME, families=[self.ollama_server_infos.LIGHTRAG_NAME], parameter_size='13B', quantization_level='Q4_0', ), ) ] ) @self.router.get('/ps', dependencies=[Depends(combined_auth)]) async def get_running_models(): """List Running Models - returns currently running models""" return OllamaPsResponse( models=[ OllamaRunningModel( name=self.ollama_server_infos.LIGHTRAG_MODEL, model=self.ollama_server_infos.LIGHTRAG_MODEL, size=self.ollama_server_infos.LIGHTRAG_SIZE, digest=self.ollama_server_infos.LIGHTRAG_DIGEST, details=OllamaRunningModelDetails( parent_model='', format='gguf', family='llama', families=['llama'], parameter_size='7.2B', quantization_level='Q4_0', ), expires_at='2050-12-31T14:38:31.83753-07:00', size_vram=self.ollama_server_infos.LIGHTRAG_SIZE, ) ] ) @self.router.post('/generate', dependencies=[Depends(combined_auth)], include_in_schema=True) async def generate(raw_request: Request): """Handle generate completion requests acting as an Ollama model For compatibility purpose, the request is not processed by LightRAG, and will be handled by underlying LLM model. Supports both application/json and application/octet-stream Content-Types. """ try: # Parse the request body manually request = await parse_request_body(raw_request, OllamaGenerateRequest) query = request.prompt start_time = time.time_ns() prompt_tokens = estimate_tokens(query) if request.system: self.rag.llm_model_kwargs['system_prompt'] = request.system if request.stream: llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func) response = cast( str | AsyncIterator[str], await llm_model_func(query, stream=True, **self.rag.llm_model_kwargs), ) async def stream_generator(): first_chunk_time = None last_chunk_time = time.time_ns() total_response = '' # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': response, 'done': False, } yield f'{json.dumps(data, ensure_ascii=False)}\n' completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': '', 'done': True, 'done_reason': 'stop', 'context': [], 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } yield f'{json.dumps(data, ensure_ascii=False)}\n' else: stream = cast(AsyncIterator[str], response) try: async for chunk in stream: if chunk: if first_chunk_time is None: first_chunk_time = time.time_ns() last_chunk_time = time.time_ns() total_response += chunk data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': chunk, 'done': False, } yield f'{json.dumps(data, ensure_ascii=False)}\n' except (asyncio.CancelledError, Exception) as e: error_msg = str(e) if isinstance(e, asyncio.CancelledError): error_msg = 'Stream was cancelled by server' else: error_msg = f'Provider error: {error_msg}' logger.error(f'Stream error: {error_msg}') # Send error message to client error_data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': f'\n\nError: {error_msg}', 'error': f'\n\nError: {error_msg}', 'done': False, } yield f'{json.dumps(error_data, ensure_ascii=False)}\n' # Send final message to close the stream final_data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': '', 'done': True, } yield f'{json.dumps(final_data, ensure_ascii=False)}\n' return if first_chunk_time is None: first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': '', 'done': True, 'done_reason': 'stop', 'context': [], 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } yield f'{json.dumps(data, ensure_ascii=False)}\n' return return StreamingResponse( stream_generator(), media_type='application/x-ndjson', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'Content-Type': 'application/x-ndjson', 'X-Accel-Buffering': 'no', # Ensure proper handling of streaming responses in Nginx proxy }, ) else: first_chunk_time = time.time_ns() llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func) response_text = cast(str, await llm_model_func(query, stream=False, **self.rag.llm_model_kwargs)) last_chunk_time = time.time_ns() if not response_text: response_text = 'No response generated' completion_tokens = estimate_tokens(str(response_text)) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time return { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'response': str(response_text), 'done': True, 'done_reason': 'stop', 'context': [], 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } except Exception as e: logger.error(f'Ollama generate error: {e!s}', exc_info=True) raise HTTPException(status_code=500, detail=str(e)) from e @self.router.post('/chat', dependencies=[Depends(combined_auth)], include_in_schema=True) async def chat(raw_request: Request): """Process chat completion requests by acting as an Ollama model. Routes user queries through LightRAG by selecting query mode based on query prefix. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. Supports both application/json and application/octet-stream Content-Types. """ try: # Parse the request body manually request = await parse_request_body(raw_request, OllamaChatRequest) # Get all messages messages = request.messages if not messages: raise HTTPException(status_code=400, detail='No messages provided') # Validate that the last message is from a user if messages[-1].role != 'user': raise HTTPException(status_code=400, detail='Last message must be from user role') # Get the last message as query and previous messages as history query = messages[-1].content # Convert OllamaMessage objects to dictionaries conversation_history = [{'role': msg.role, 'content': msg.content} for msg in messages[:-1]] # Check for query prefix cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(query) start_time = time.time_ns() prompt_tokens = estimate_tokens(cleaned_query) param_dict = { 'mode': mode.value, 'stream': request.stream, 'only_need_context': only_need_context, 'conversation_history': conversation_history, 'top_k': self.top_k, } # Add user_prompt to param_dict if user_prompt is not None: param_dict['user_prompt'] = user_prompt # Create QueryParam object from the parsed parameters query_param = QueryParam(**cast(dict[str, Any], param_dict)) # Execute query using the configured RAG instance # If stream is enabled, return StreamingResponse if request.stream: # Determine if the request is prefix with "/bypass" if mode == SearchMode.bypass: if request.system: self.rag.llm_model_kwargs['system_prompt'] = request.system llm_model_func = cast( Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func ) response = cast( str | AsyncIterator[str], await llm_model_func( cleaned_query, stream=True, history_messages=conversation_history, **self.rag.llm_model_kwargs, ), ) else: aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery) response = cast(str | AsyncIterator[str], await aquery_func(cleaned_query, param=query_param)) async def stream_generator(): first_chunk_time = None last_chunk_time = time.time_ns() total_response = '' # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': response, 'images': None, }, 'done': False, } yield f'{json.dumps(data, ensure_ascii=False)}\n' completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': '', 'images': None, }, 'done_reason': 'stop', 'done': True, 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } yield f'{json.dumps(data, ensure_ascii=False)}\n' else: stream = cast(AsyncIterator[str], response) try: async for chunk in stream: if chunk: if first_chunk_time is None: first_chunk_time = time.time_ns() last_chunk_time = time.time_ns() total_response += chunk data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': chunk, 'images': None, }, 'done': False, } yield f'{json.dumps(data, ensure_ascii=False)}\n' except (asyncio.CancelledError, Exception) as e: error_msg = str(e) if isinstance(e, asyncio.CancelledError): error_msg = 'Stream was cancelled by server' else: error_msg = f'Provider error: {error_msg}' logger.error(f'Stream error: {error_msg}') # Send error message to client error_data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': f'\n\nError: {error_msg}', 'images': None, }, 'error': f'\n\nError: {error_msg}', 'done': False, } yield f'{json.dumps(error_data, ensure_ascii=False)}\n' # Send final message to close the stream final_data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': '', 'images': None, }, 'done': True, } yield f'{json.dumps(final_data, ensure_ascii=False)}\n' return if first_chunk_time is None: first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time data = { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': '', 'images': None, }, 'done_reason': 'stop', 'done': True, 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } yield f'{json.dumps(data, ensure_ascii=False)}\n' return StreamingResponse( stream_generator(), media_type='application/x-ndjson', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'Content-Type': 'application/x-ndjson', 'X-Accel-Buffering': 'no', # Ensure proper handling of streaming responses in Nginx proxy }, ) else: first_chunk_time = time.time_ns() # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task match_result = re.search(r'\n\nUSER:', cleaned_query, re.MULTILINE) if match_result or mode == SearchMode.bypass: if request.system: self.rag.llm_model_kwargs['system_prompt'] = request.system llm_model_func = cast( Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func ) response_text = cast( str, await llm_model_func( cleaned_query, stream=False, history_messages=conversation_history, **self.rag.llm_model_kwargs, ), ) else: aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery) response_text = cast(str, await aquery_func(cleaned_query, param=query_param)) last_chunk_time = time.time_ns() if not response_text: response_text = 'No response generated' completion_tokens = estimate_tokens(str(response_text)) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time eval_time = last_chunk_time - first_chunk_time return { 'model': self.ollama_server_infos.LIGHTRAG_MODEL, 'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT, 'message': { 'role': 'assistant', 'content': str(response_text), 'images': None, }, 'done_reason': 'stop', 'done': True, 'total_duration': total_time, 'load_duration': 0, 'prompt_eval_count': prompt_tokens, 'prompt_eval_duration': prompt_eval_time, 'eval_count': completion_tokens, 'eval_duration': eval_time, } except Exception as e: logger.error(f'Ollama chat error: {e!s}', exc_info=True) raise HTTPException(status_code=500, detail=str(e)) from e