Add S3 storage client and API routes for document management:
- Implement s3_routes.py with file upload, download, and delete endpoints
- Enhance s3_client.py with improved error handling and operations
- Add S3 browser UI component with file viewing and management
- Implement FileViewer and PDFViewer components for storage preview
- Add Resizable and Sheet UI components for layout control

Update backend infrastructure:
- Add bulk operations and parameterized queries to postgres_impl.py
- Enhance document routes with improved type hints
- Update API server registration for the new S3 routes
- Refine upload routes and utility functions

Modernize web UI:
- Integrate the S3 browser into the main application layout
- Update localization files for storage UI strings
- Add storage settings to the application configuration
- Sync package dependencies and lock files

Remove obsolete reproduction script:
- Delete reproduce_citation.py (replaced by the test suite)

Update configuration:
- Enhance pyrightconfig.json for stricter type checking
731 lines · 32 KiB · Python
import asyncio
import json
import re
import time
from collections.abc import AsyncIterator, Awaitable, Callable
from enum import Enum
from typing import Any, TypeVar, cast

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from lightrag import LightRAG, QueryParam
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.base import OllamaServerInfos
from lightrag.constants import DEFAULT_TOP_K
from lightrag.utils import TiktokenTokenizer, logger

# Query mode is selected according to the query prefix (bypass is not a LightRAG query mode)
class SearchMode(str, Enum):
    naive = 'naive'
    local = 'local'
    global_ = 'global'
    hybrid = 'hybrid'
    mix = 'mix'
    bypass = 'bypass'
    context = 'context'

class OllamaMessage(BaseModel):
    role: str
    content: str
    images: list[str] | None = None


class OllamaChatRequest(BaseModel):
    model: str
    messages: list[OllamaMessage]
    stream: bool = True
    options: dict[str, Any] | None = None
    system: str | None = None


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool


class OllamaGenerateRequest(BaseModel):
    model: str
    prompt: str
    system: str | None = None
    stream: bool = False
    options: dict[str, Any] | None = None


class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool
    context: list[int] | None
    total_duration: int | None
    load_duration: int | None
    prompt_eval_count: int | None
    prompt_eval_duration: int | None
    eval_count: int | None
    eval_duration: int | None


class OllamaVersionResponse(BaseModel):
    version: str


class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: list[str]
    parameter_size: str
    quantization_level: str


class OllamaModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    details: OllamaModelDetails


class OllamaTagResponse(BaseModel):
    models: list[OllamaModel]


class OllamaRunningModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: list[str]
    parameter_size: str
    quantization_level: str


class OllamaRunningModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    details: OllamaRunningModelDetails
    expires_at: str
    size_vram: int


class OllamaPsResponse(BaseModel):
    models: list[OllamaRunningModel]


T = TypeVar('T', bound=BaseModel)

async def parse_request_body(request: Request, model_class: type[T]) -> T:
    """
    Parse request body based on Content-Type header.
    Supports both application/json and application/octet-stream.

    Args:
        request: The FastAPI Request object
        model_class: The Pydantic model class to parse the request into

    Returns:
        An instance of the provided model_class
    """
    content_type = request.headers.get('content-type', '').lower()

    try:
        if content_type.startswith('application/json'):
            # FastAPI already handles JSON parsing for us
            body = await request.json()
        elif content_type.startswith('application/octet-stream'):
            # Manually parse octet-stream as JSON
            body_bytes = await request.body()
            body = json.loads(body_bytes.decode('utf-8'))
        else:
            # Try to parse as JSON for any other content type
            body_bytes = await request.body()
            body = json.loads(body_bytes.decode('utf-8'))

        # Create an instance of the model
        return model_class(**body)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail='Invalid JSON in request body') from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'Error parsing request body: {e!s}') from e

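# Illustrative usage (a sketch, not part of the route wiring below): inside a
# FastAPI handler that accepts a raw Request, the body can be parsed into a
# typed model regardless of whether the client sent application/json or
# application/octet-stream, e.g.:
#
#     request = await parse_request_body(raw_request, OllamaChatRequest)
#
# This is how the /generate and /chat handlers below obtain their typed
# request objects.
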
def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text using tiktoken"""
    tokens = TiktokenTokenizer().encode(text)
    return len(tokens)

def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, str | None]:
    """Parse the query prefix to determine the search mode.

    Returns a tuple of (cleaned_query, search_mode, only_need_context, user_prompt)

    Examples:
        - "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
        - "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.mix, False, "use mermaid format for diagrams")
        - "/local query string" -> (cleaned_query, SearchMode.local, False, None)
    """
    # Initialize user_prompt as None
    user_prompt = None

    # First check whether the query carries a bracketed user prompt
    bracket_pattern = r'^/([a-z]*)\[(.*?)\](.*)'
    bracket_match = re.match(bracket_pattern, query)

    if bracket_match:
        mode_prefix = bracket_match.group(1)
        user_prompt = bracket_match.group(2)
        remaining_query = bracket_match.group(3).lstrip()

        # Reconstruct the query with the bracketed part removed
        query = f'/{mode_prefix} {remaining_query}'.strip()

    # Unified handling of mode and only_need_context determination
    mode_map = {
        '/local ': (SearchMode.local, False),
        '/global ': (
            SearchMode.global_,
            False,
        ),  # global_ is used because 'global' is a Python keyword
        '/naive ': (SearchMode.naive, False),
        '/hybrid ': (SearchMode.hybrid, False),
        '/mix ': (SearchMode.mix, False),
        '/bypass ': (SearchMode.bypass, False),
        '/context': (
            SearchMode.mix,
            True,
        ),
        '/localcontext': (SearchMode.local, True),
        '/globalcontext': (SearchMode.global_, True),
        '/hybridcontext': (SearchMode.hybrid, True),
        '/naivecontext': (SearchMode.naive, True),
        '/mixcontext': (SearchMode.mix, True),
    }

    for prefix, (mode, only_need_context) in mode_map.items():
        if query.startswith(prefix):
            # Strip the prefix and any leading spaces
            cleaned_query = query[len(prefix) :].lstrip()
            return cleaned_query, mode, only_need_context, user_prompt

    return query, SearchMode.mix, False, user_prompt

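# Illustrative behaviour (a sketch of expected results, derived from the mapping above):
#
#     parse_query_mode('/local What is LightRAG?')
#     -> ('What is LightRAG?', SearchMode.local, False, None)
#
#     parse_query_mode('/context What is LightRAG?')
#     -> ('What is LightRAG?', SearchMode.mix, True, None)
#
#     parse_query_mode('/mix[answer briefly] What is LightRAG?')
#     -> ('What is LightRAG?', SearchMode.mix, False, 'answer briefly')
#
#     parse_query_mode('What is LightRAG?')  # no prefix: defaults to mix
#     -> ('What is LightRAG?', SearchMode.mix, False, None)
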
class OllamaAPI:
    def __init__(self, rag: LightRAG, top_k: int = DEFAULT_TOP_K, api_key: str | None = None):
        self.rag = rag
        # Ensure server infos are always present for typing and runtime safety
        self.ollama_server_infos: OllamaServerInfos = rag.ollama_server_infos or OllamaServerInfos()
        self.top_k = top_k
        self.api_key = api_key
        self.router = APIRouter(tags=['ollama'])
        self.setup_routes()

    def setup_routes(self):
        # Create combined auth dependency for Ollama API routes
        combined_auth = get_combined_auth_dependency(self.api_key)

        @self.router.get('/version', dependencies=[Depends(combined_auth)])
        async def get_version():
            """Get Ollama version information"""
            return OllamaVersionResponse(version='0.9.3')

        @self.router.get('/tags', dependencies=[Depends(combined_auth)])
        async def get_tags():
            """Return available models acting as an Ollama server"""
            return OllamaTagResponse(
                models=[
                    OllamaModel(
                        name=self.ollama_server_infos.LIGHTRAG_MODEL,
                        model=self.ollama_server_infos.LIGHTRAG_MODEL,
                        modified_at=self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        size=self.ollama_server_infos.LIGHTRAG_SIZE,
                        digest=self.ollama_server_infos.LIGHTRAG_DIGEST,
                        details=OllamaModelDetails(
                            parent_model='',
                            format='gguf',
                            family=self.ollama_server_infos.LIGHTRAG_NAME,
                            families=[self.ollama_server_infos.LIGHTRAG_NAME],
                            parameter_size='13B',
                            quantization_level='Q4_0',
                        ),
                    )
                ]
            )

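        # Illustrative check (assumptions: the LightRAG API server mounts this
        # router under the Ollama-compatible /api prefix and listens on its
        # default port; adjust host, port, and prefix to your deployment):
        #
        #     curl http://localhost:9621/api/tags
        #
        # which should return the single emulated model described above.
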
        @self.router.get('/ps', dependencies=[Depends(combined_auth)])
        async def get_running_models():
            """List Running Models - returns currently running models"""
            return OllamaPsResponse(
                models=[
                    OllamaRunningModel(
                        name=self.ollama_server_infos.LIGHTRAG_MODEL,
                        model=self.ollama_server_infos.LIGHTRAG_MODEL,
                        size=self.ollama_server_infos.LIGHTRAG_SIZE,
                        digest=self.ollama_server_infos.LIGHTRAG_DIGEST,
                        details=OllamaRunningModelDetails(
                            parent_model='',
                            format='gguf',
                            family='llama',
                            families=['llama'],
                            parameter_size='7.2B',
                            quantization_level='Q4_0',
                        ),
                        expires_at='2050-12-31T14:38:31.83753-07:00',
                        size_vram=self.ollama_server_infos.LIGHTRAG_SIZE,
                    )
                ]
            )

        @self.router.post('/generate', dependencies=[Depends(combined_auth)], include_in_schema=True)
        async def generate(raw_request: Request):
            """Handle generate completion requests while acting as an Ollama model.

            For compatibility purposes, the request is not processed by LightRAG;
            it is handled directly by the underlying LLM model.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaGenerateRequest)

                query = request.prompt
                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(query)

                if request.system:
                    self.rag.llm_model_kwargs['system_prompt'] = request.system

                if request.stream:
                    llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func)
                    response = cast(
                        str | AsyncIterator[str],
                        await llm_model_func(query, stream=True, **self.rag.llm_model_kwargs),
                    )

                    async def stream_generator():
                        first_chunk_time = None
                        last_chunk_time = time.time_ns()
                        total_response = ''

                        # Ensure response is an async generator
                        if isinstance(response, str):
                            # If it's a string, send it in two parts
                            first_chunk_time = start_time
                            last_chunk_time = time.time_ns()
                            total_response = response

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'response': response,
                                'done': False,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'

                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'response': '',
                                'done': True,
                                'done_reason': 'stop',
                                'context': [],
                                'total_duration': total_time,
                                'load_duration': 0,
                                'prompt_eval_count': prompt_tokens,
                                'prompt_eval_duration': prompt_eval_time,
                                'eval_count': completion_tokens,
                                'eval_duration': eval_time,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'
                        else:
                            stream = cast(AsyncIterator[str], response)
                            try:
                                async for chunk in stream:
                                    if chunk:
                                        if first_chunk_time is None:
                                            first_chunk_time = time.time_ns()

                                        last_chunk_time = time.time_ns()

                                        total_response += chunk
                                        data = {
                                            'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                            'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                            'response': chunk,
                                            'done': False,
                                        }
                                        yield f'{json.dumps(data, ensure_ascii=False)}\n'
                            except (asyncio.CancelledError, Exception) as e:
                                error_msg = str(e)
                                if isinstance(e, asyncio.CancelledError):
                                    error_msg = 'Stream was cancelled by server'
                                else:
                                    error_msg = f'Provider error: {error_msg}'

                                logger.error(f'Stream error: {error_msg}')

                                # Send error message to client
                                error_data = {
                                    'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                    'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    'response': f'\n\nError: {error_msg}',
                                    'error': f'\n\nError: {error_msg}',
                                    'done': False,
                                }
                                yield f'{json.dumps(error_data, ensure_ascii=False)}\n'

                                # Send final message to close the stream
                                final_data = {
                                    'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                    'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    'response': '',
                                    'done': True,
                                }
                                yield f'{json.dumps(final_data, ensure_ascii=False)}\n'
                                return
                            if first_chunk_time is None:
                                first_chunk_time = start_time
                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'response': '',
                                'done': True,
                                'done_reason': 'stop',
                                'context': [],
                                'total_duration': total_time,
                                'load_duration': 0,
                                'prompt_eval_count': prompt_tokens,
                                'prompt_eval_duration': prompt_eval_time,
                                'eval_count': completion_tokens,
                                'eval_duration': eval_time,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'
                            return

                    return StreamingResponse(
                        stream_generator(),
                        media_type='application/x-ndjson',
                        headers={
                            'Cache-Control': 'no-cache',
                            'Connection': 'keep-alive',
                            'Content-Type': 'application/x-ndjson',
                            'X-Accel-Buffering': 'no',  # Ensure proper handling of streaming responses in Nginx proxy
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()
                    llm_model_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func)
                    response_text = cast(str, await llm_model_func(query, stream=False, **self.rag.llm_model_kwargs))
                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = 'No response generated'

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                        'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        'response': str(response_text),
                        'done': True,
                        'done_reason': 'stop',
                        'context': [],
                        'total_duration': total_time,
                        'load_duration': 0,
                        'prompt_eval_count': prompt_tokens,
                        'prompt_eval_duration': prompt_eval_time,
                        'eval_count': completion_tokens,
                        'eval_duration': eval_time,
                    }
            except Exception as e:
                logger.error(f'Ollama generate error: {e!s}', exc_info=True)
                raise HTTPException(status_code=500, detail=str(e)) from e

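        # Illustrative request (a sketch; host, port, path prefix, and model name
        # are assumptions about a typical deployment, not confirmed by this file):
        #
        #     curl http://localhost:9621/api/generate \
        #       -d '{"model": "lightrag:latest", "prompt": "Hello", "stream": false}'
        #
        # With "stream": false the handler above returns a single JSON object;
        # with "stream": true it returns NDJSON chunks followed by a final
        # statistics record.
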
        @self.router.post('/chat', dependencies=[Depends(combined_auth)], include_in_schema=True)
        async def chat(raw_request: Request):
            """Process chat completion requests while acting as an Ollama model.

            Routes user queries through LightRAG, selecting the query mode from the query prefix.
            Detects OpenWebUI session-related requests (metadata generation tasks) and forwards them directly to the LLM.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaChatRequest)

                # Get all messages
                messages = request.messages
                if not messages:
                    raise HTTPException(status_code=400, detail='No messages provided')

                # Validate that the last message is from a user
                if messages[-1].role != 'user':
                    raise HTTPException(status_code=400, detail='Last message must be from user role')

                # Get the last message as the query and previous messages as history
                query = messages[-1].content
                # Convert OllamaMessage objects to dictionaries
                conversation_history = [{'role': msg.role, 'content': msg.content} for msg in messages[:-1]]

                # Check for a query prefix
                cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(query)

                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(cleaned_query)

                param_dict = {
                    'mode': mode.value,
                    'stream': request.stream,
                    'only_need_context': only_need_context,
                    'conversation_history': conversation_history,
                    'top_k': self.top_k,
                }

                # Add user_prompt to param_dict
                if user_prompt is not None:
                    param_dict['user_prompt'] = user_prompt

                # Create QueryParam object from the parsed parameters
                query_param = QueryParam(**cast(dict[str, Any], param_dict))

                # Execute query using the configured RAG instance
                # If stream is enabled, return StreamingResponse

                if request.stream:
                    # Determine whether the request is prefixed with "/bypass"
                    if mode == SearchMode.bypass:
                        if request.system:
                            self.rag.llm_model_kwargs['system_prompt'] = request.system
                        llm_model_func = cast(
                            Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func
                        )
                        response = cast(
                            str | AsyncIterator[str],
                            await llm_model_func(
                                cleaned_query,
                                stream=True,
                                history_messages=conversation_history,
                                **self.rag.llm_model_kwargs,
                            ),
                        )
                    else:
                        aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery)
                        response = cast(str | AsyncIterator[str], await aquery_func(cleaned_query, param=query_param))

                    async def stream_generator():
                        first_chunk_time = None
                        last_chunk_time = time.time_ns()
                        total_response = ''

                        # Ensure response is an async generator
                        if isinstance(response, str):
                            # If it's a string, send it in two parts
                            first_chunk_time = start_time
                            last_chunk_time = time.time_ns()
                            total_response = response

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'message': {
                                    'role': 'assistant',
                                    'content': response,
                                    'images': None,
                                },
                                'done': False,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'

                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'message': {
                                    'role': 'assistant',
                                    'content': '',
                                    'images': None,
                                },
                                'done_reason': 'stop',
                                'done': True,
                                'total_duration': total_time,
                                'load_duration': 0,
                                'prompt_eval_count': prompt_tokens,
                                'prompt_eval_duration': prompt_eval_time,
                                'eval_count': completion_tokens,
                                'eval_duration': eval_time,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'
                        else:
                            stream = cast(AsyncIterator[str], response)
                            try:
                                async for chunk in stream:
                                    if chunk:
                                        if first_chunk_time is None:
                                            first_chunk_time = time.time_ns()

                                        last_chunk_time = time.time_ns()

                                        total_response += chunk
                                        data = {
                                            'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                            'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                            'message': {
                                                'role': 'assistant',
                                                'content': chunk,
                                                'images': None,
                                            },
                                            'done': False,
                                        }
                                        yield f'{json.dumps(data, ensure_ascii=False)}\n'
                            except (asyncio.CancelledError, Exception) as e:
                                error_msg = str(e)
                                if isinstance(e, asyncio.CancelledError):
                                    error_msg = 'Stream was cancelled by server'
                                else:
                                    error_msg = f'Provider error: {error_msg}'

                                logger.error(f'Stream error: {error_msg}')

                                # Send error message to client
                                error_data = {
                                    'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                    'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    'message': {
                                        'role': 'assistant',
                                        'content': f'\n\nError: {error_msg}',
                                        'images': None,
                                    },
                                    'error': f'\n\nError: {error_msg}',
                                    'done': False,
                                }
                                yield f'{json.dumps(error_data, ensure_ascii=False)}\n'

                                # Send final message to close the stream
                                final_data = {
                                    'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                    'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    'message': {
                                        'role': 'assistant',
                                        'content': '',
                                        'images': None,
                                    },
                                    'done': True,
                                }
                                yield f'{json.dumps(final_data, ensure_ascii=False)}\n'
                                return

                            if first_chunk_time is None:
                                first_chunk_time = start_time
                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                                'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                'message': {
                                    'role': 'assistant',
                                    'content': '',
                                    'images': None,
                                },
                                'done_reason': 'stop',
                                'done': True,
                                'total_duration': total_time,
                                'load_duration': 0,
                                'prompt_eval_count': prompt_tokens,
                                'prompt_eval_duration': prompt_eval_time,
                                'eval_count': completion_tokens,
                                'eval_duration': eval_time,
                            }
                            yield f'{json.dumps(data, ensure_ascii=False)}\n'

                    return StreamingResponse(
                        stream_generator(),
                        media_type='application/x-ndjson',
                        headers={
                            'Cache-Control': 'no-cache',
                            'Connection': 'keep-alive',
                            'Content-Type': 'application/x-ndjson',
                            'X-Accel-Buffering': 'no',  # Ensure proper handling of streaming responses in Nginx proxy
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()

                    # Determine whether the request is prefixed with "/bypass" or comes from
                    # Open WebUI's session title / session keyword generation task
                    match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
                    if match_result or mode == SearchMode.bypass:
                        if request.system:
                            self.rag.llm_model_kwargs['system_prompt'] = request.system

                        llm_model_func = cast(
                            Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.llm_model_func
                        )
                        response_text = cast(
                            str,
                            await llm_model_func(
                                cleaned_query,
                                stream=False,
                                history_messages=conversation_history,
                                **self.rag.llm_model_kwargs,
                            ),
                        )
                    else:
                        aquery_func = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], self.rag.aquery)
                        response_text = cast(str, await aquery_func(cleaned_query, param=query_param))

                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = 'No response generated'

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        'model': self.ollama_server_infos.LIGHTRAG_MODEL,
                        'created_at': self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        'message': {
                            'role': 'assistant',
                            'content': str(response_text),
                            'images': None,
                        },
                        'done_reason': 'stop',
                        'done': True,
                        'total_duration': total_time,
                        'load_duration': 0,
                        'prompt_eval_count': prompt_tokens,
                        'prompt_eval_duration': prompt_eval_time,
                        'eval_count': completion_tokens,
                        'eval_duration': eval_time,
                    }
            except Exception as e:
                logger.error(f'Ollama chat error: {e!s}', exc_info=True)
                raise HTTPException(status_code=500, detail=str(e)) from e
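
        # Illustrative chat request (a sketch; host, port, path prefix, and model
        # name are assumptions about a typical deployment): the query-mode prefix
        # travels inside the last user message, so an Ollama-compatible client
        # such as Open WebUI can select a LightRAG mode per message:
        #
        #     curl http://localhost:9621/api/chat -d '{
        #       "model": "lightrag:latest",
        #       "messages": [{"role": "user", "content": "/local What is LightRAG?"}],
        #       "stream": false
        #     }'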