feat: 支持 LiteLLM 客户端的 application/octet-stream 请求

修改 Ollama API 路由处理程序,使 /chat 和 /generate 端点能够接受 Content-Type 为 application/octet-stream 的请求。通过绕过 FastAPI 的自动请求验证机制,手动解析请求体,解决了 LiteLLM 客户端连接时出现的 422 错误。此更改保持了对现有 application/json 请求的向后兼容性。
This commit is contained in:
yangdx 2025-06-11 13:42:30 +08:00
parent ad81e59d9a
commit 7b07d4c917

View file

@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Type
from lightrag.utils import logger from lightrag.utils import logger
import time import time
import json import json
@ -95,6 +95,47 @@ class OllamaTagResponse(BaseModel):
models: List[OllamaModel] models: List[OllamaModel]
async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel:
    """
    Parse the raw request body into the given Pydantic model.

    Accepts application/json, application/octet-stream, and any other
    Content-Type whose body is JSON-encoded. Bypassing FastAPI's automatic
    request validation this way avoids 422 errors for clients (e.g. LiteLLM)
    that send JSON payloads with an octet-stream Content-Type.

    Args:
        request: The FastAPI Request object
        model_class: The Pydantic model class to parse the request into

    Returns:
        An instance of the provided model_class

    Raises:
        HTTPException: 400 if the body is not valid JSON, cannot be decoded
            as UTF-8, or fails model validation.
    """
    try:
        # Every supported Content-Type carries a JSON payload, so decode the
        # raw bytes uniformly instead of branching per Content-Type
        # (request.json() is itself json.loads(await request.body())).
        body_bytes = await request.body()
        body = json.loads(body_bytes.decode("utf-8"))
        # Create an instance of the model; validation errors fall through
        # to the generic handler below.
        return model_class(**body)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON in request body"
        ) from e
    except Exception as e:
        # Covers UnicodeDecodeError and Pydantic validation errors.
        raise HTTPException(
            status_code=400,
            detail=f"Error parsing request body: {str(e)}"
        ) from e
def estimate_tokens(text: str) -> int: def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text using tiktoken""" """Estimate the number of tokens in text using tiktoken"""
tokens = TiktokenTokenizer().encode(text) tokens = TiktokenTokenizer().encode(text)
@ -197,13 +238,17 @@ class OllamaAPI:
] ]
) )
@self.router.post("/generate", dependencies=[Depends(combined_auth)]) @self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True)
async def generate(raw_request: Request, request: OllamaGenerateRequest): async def generate(raw_request: Request):
"""Handle generate completion requests acting as an Ollama model """Handle generate completion requests acting as an Ollama model
For compatibility purpose, the request is not processed by LightRAG, For compatibility purpose, the request is not processed by LightRAG,
and will be handled by underlying LLM model. and will be handled by underlying LLM model.
Supports both application/json and application/octet-stream Content-Types.
""" """
try: try:
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaGenerateRequest)
query = request.prompt query = request.prompt
start_time = time.time_ns() start_time = time.time_ns()
prompt_tokens = estimate_tokens(query) prompt_tokens = estimate_tokens(query)
@ -363,13 +408,17 @@ class OllamaAPI:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@self.router.post("/chat", dependencies=[Depends(combined_auth)]) @self.router.post("/chat", dependencies=[Depends(combined_auth)], include_in_schema=True)
async def chat(raw_request: Request, request: OllamaChatRequest): async def chat(raw_request: Request):
"""Process chat completion requests acting as an Ollama model """Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators. Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
Supports both application/json and application/octet-stream Content-Types.
""" """
try: try:
# Parse the request body manually
request = await parse_request_body(raw_request, OllamaChatRequest)
# Get all messages # Get all messages
messages = request.messages messages = request.messages
if not messages: if not messages: