feat: 支持 LiteLLM 客户端的 application/octet-stream 请求
修改 Ollama API 路由处理程序,使 /chat 和 /generate 端点能够接受 Content-Type 为 application/octet-stream 的请求。通过绕过 FastAPI 的自动请求验证机制,手动解析请求体,解决了 LiteLLM 客户端连接时出现的 422 错误。此更改保持了对现有 application/json 请求的向后兼容性。
This commit is contained in:
parent
ad81e59d9a
commit
7b07d4c917
1 changed files with 54 additions and 5 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional, Type
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
|
@ -95,6 +95,47 @@ class OllamaTagResponse(BaseModel):
|
||||||
models: List[OllamaModel]
|
models: List[OllamaModel]
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_request_body(request: Request, model_class: Type[BaseModel]) -> BaseModel:
    """Parse the request body into ``model_class`` based on the Content-Type header.

    Supports both ``application/json`` and ``application/octet-stream``
    (and, as a best-effort fallback, any other content type whose body is
    JSON). This works around clients such as LiteLLM that send JSON
    payloads with ``Content-Type: application/octet-stream``, which
    FastAPI's automatic request validation would otherwise reject with a
    422 error.

    Args:
        request: The FastAPI Request object.
        model_class: The Pydantic model class to parse the request into.

    Returns:
        An instance of the provided model_class.

    Raises:
        HTTPException: 400 if the body is not valid JSON or fails
            model validation.
    """
    content_type = request.headers.get("content-type", "").lower()

    try:
        if content_type.startswith("application/json"):
            # Let Starlette parse (and cache) the JSON body for us.
            body = await request.json()
        else:
            # application/octet-stream — or any other content type:
            # manually parse the raw body bytes as JSON. Both cases need
            # the identical read/decode/parse sequence, so they share one
            # branch instead of duplicating it.
            body_bytes = await request.body()
            body = json.loads(body_bytes.decode("utf-8"))

        # Validate the parsed payload against the Pydantic model.
        return model_class(**body)
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON in request body"
        )
    except Exception as e:
        # Covers pydantic validation errors and unicode decode failures.
        raise HTTPException(
            status_code=400,
            detail=f"Error parsing request body: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
def estimate_tokens(text: str) -> int:
|
def estimate_tokens(text: str) -> int:
|
||||||
"""Estimate the number of tokens in text using tiktoken"""
|
"""Estimate the number of tokens in text using tiktoken"""
|
||||||
tokens = TiktokenTokenizer().encode(text)
|
tokens = TiktokenTokenizer().encode(text)
|
||||||
|
|
@ -197,13 +238,17 @@ class OllamaAPI:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.router.post("/generate", dependencies=[Depends(combined_auth)])
|
@self.router.post("/generate", dependencies=[Depends(combined_auth)], include_in_schema=True)
|
||||||
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
async def generate(raw_request: Request):
|
||||||
"""Handle generate completion requests acting as an Ollama model
|
"""Handle generate completion requests acting as an Ollama model
|
||||||
For compatibility purpose, the request is not processed by LightRAG,
|
For compatibility purpose, the request is not processed by LightRAG,
|
||||||
and will be handled by underlying LLM model.
|
and will be handled by underlying LLM model.
|
||||||
|
Supports both application/json and application/octet-stream Content-Types.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Parse the request body manually
|
||||||
|
request = await parse_request_body(raw_request, OllamaGenerateRequest)
|
||||||
|
|
||||||
query = request.prompt
|
query = request.prompt
|
||||||
start_time = time.time_ns()
|
start_time = time.time_ns()
|
||||||
prompt_tokens = estimate_tokens(query)
|
prompt_tokens = estimate_tokens(query)
|
||||||
|
|
@ -363,13 +408,17 @@ class OllamaAPI:
|
||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@self.router.post("/chat", dependencies=[Depends(combined_auth)])
|
@self.router.post("/chat", dependencies=[Depends(combined_auth)], include_in_schema=True)
|
||||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
async def chat(raw_request: Request):
|
||||||
"""Process chat completion requests acting as an Ollama model
|
"""Process chat completion requests acting as an Ollama model
|
||||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||||
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
||||||
|
Supports both application/json and application/octet-stream Content-Types.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Parse the request body manually
|
||||||
|
request = await parse_request_body(raw_request, OllamaChatRequest)
|
||||||
|
|
||||||
# Get all messages
|
# Get all messages
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
if not messages:
|
if not messages:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue