diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py
index 8d0895bb..625c0162 100644
--- a/lightrag/api/routers/ollama_api.py
+++ b/lightrag/api/routers/ollama_api.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, HTTPException, Request
 from pydantic import BaseModel
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Type
 from lightrag.utils import logger
 import time
 import json
@@ -95,6 +95,47 @@ class OllamaTagResponse(BaseModel):
     models: List[OllamaModel]
 
 
+async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel:
+    """
+    Parse request body based on Content-Type header.
+    Supports both application/json and application/octet-stream.
+
+    Args:
+        request: The FastAPI Request object
+        model_class: The Pydantic model class to parse the request into
+
+    Returns:
+        An instance of the provided model_class
+    """
+    content_type = request.headers.get("content-type", "").lower()
+
+    try:
+        if content_type.startswith("application/json"):
+            # FastAPI already handles JSON parsing for us
+            body = await request.json()
+        elif content_type.startswith("application/octet-stream"):
+            # Manually parse octet-stream as JSON
+            body_bytes = await request.body()
+            body = json.loads(body_bytes.decode('utf-8'))
+        else:
+            # Try to parse as JSON for any other content type
+            body_bytes = await request.body()
+            body = json.loads(body_bytes.decode('utf-8'))
+
+        # Create an instance of the model
+        return model_class(**body)
+    except json.JSONDecodeError:
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid JSON in request body"
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Error parsing request body: {str(e)}"
+        )
+
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text using tiktoken"""
     tokens = TiktokenTokenizer().encode(text)
@@ -197,13 +238,17 @@ class OllamaAPI:
                 ]
             )
 
-        @self.router.post("/generate", dependencies=[Depends(combined_auth)])
-        async def generate(raw_request: Request, request: OllamaGenerateRequest):
+        @self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True)
+        async def generate(raw_request: Request):
             """Handle generate completion requests acting as an Ollama model
             For compatibility purpose, the request is not processed by LightRAG,
             and will be handled by underlying LLM model.
+            Supports both application/json and application/octet-stream Content-Types.
             """
             try:
+                # Parse the request body manually
+                request = await parse_request_body(raw_request, OllamaGenerateRequest)
+
                 query = request.prompt
                 start_time = time.time_ns()
                 prompt_tokens = estimate_tokens(query)
@@ -363,13 +408,17 @@ class OllamaAPI:
                 trace_exception(e)
                 raise HTTPException(status_code=500, detail=str(e))
 
-        @self.router.post("/chat", dependencies=[Depends(combined_auth)])
-        async def chat(raw_request: Request, request: OllamaChatRequest):
+        @self.router.post("/chat", dependencies=[Depends(combined_auth)], include_in_schema=True)
+        async def chat(raw_request: Request):
             """Process chat completion requests acting as an Ollama model
             Routes user queries through LightRAG by selecting query mode based on prefix indicators.
             Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
+            Supports both application/json and application/octet-stream Content-Types.
             """
             try:
+                # Parse the request body manually
+                request = await parse_request_body(raw_request, OllamaChatRequest)
+
                 # Get all messages
                 messages = request.messages
                 if not messages: