* feat: Implement multi-tenant architecture with tenant and knowledge base models
  - Added data models for tenants, knowledge bases, and related configurations.
  - Introduced role and permission management for users in the multi-tenant system.
  - Created a service layer for managing tenants and knowledge bases, including CRUD operations.
  - Developed a tenant-aware instance manager for LightRAG with caching and isolation features.
  - Added a migration script to transition existing workspace-based deployments to the new multi-tenant architecture.

* chore: ignore lightrag/api/webui/assets/ directory

* chore: stop tracking lightrag/api/webui/assets (ignore in .gitignore)

* feat: Initialize LightRAG Multi-Tenant Stack with PostgreSQL
  - Added README.md for project overview, setup instructions, and architecture details.
  - Created docker-compose.yml to define services: PostgreSQL, Redis, LightRAG API, and Web UI.
  - Introduced env.example for environment variable configuration.
  - Implemented init-postgres.sql for PostgreSQL schema initialization with multi-tenant support.
  - Added reproduce_issue.py for testing default tenant access via API.

* feat: Enhance TenantSelector and update related components for improved multi-tenant support

* feat: Enhance testing capabilities and update documentation
  - Updated Makefile to include new test commands for various modes (compatibility, isolation, multi-tenant, security, coverage, and dry-run).
  - Modified API health check endpoint in Makefile to reflect new port configuration.
  - Updated QUICK_START.md and README.md to reflect changes in service URLs and ports.
  - Added environment variables for testing modes in env.example.
  - Introduced run_all_tests.sh script to automate testing across different modes.
  - Created conftest.py for pytest configuration, including database fixtures and mock services.
  - Implemented database helper functions for streamlined database operations in tests.
  - Added test collection hooks to skip tests based on the current MULTITENANT_MODE.

* feat: Implement multi-tenant support with demo mode enabled by default
  - Added multi-tenant configuration to the environment and Docker setup.
  - Created pre-configured demo tenants (acme-corp and techstart) for testing.
  - Updated API endpoints to support tenant-specific data access.
  - Enhanced Makefile commands for better service management and database operations.
  - Introduced user-tenant membership system with role-based access control.
  - Added comprehensive documentation for multi-tenant setup and usage.
  - Fixed issues with document visibility in multi-tenant environments.
  - Implemented necessary database migrations for user memberships and legacy support.

* feat(audit): Add final audit report for multi-tenant implementation
  - Documented overall assessment, architecture overview, test results, security findings, and recommendations.
  - Included detailed findings on critical security issues and architectural concerns.

  fix(security): Implement security fixes based on audit findings
  - Removed global RAG fallback and enforced strict tenant context.
  - Configured super-admin access and required user authentication for tenant access.
  - Cleared localStorage on logout and improved error handling in WebUI.

  chore(logs): Create task logs for audit and security fixes implementation
  - Documented actions, decisions, and next steps for both audit and security fixes.
  - Summarized test results and remaining recommendations.

  chore(scripts): Enhance development stack management scripts
  - Added scripts for cleaning, starting, and stopping the development stack.
  - Improved output messages and ensured graceful shutdown of services.

  feat(starter): Initialize PostgreSQL with AGE extension support
  - Created initialization scripts for PostgreSQL extensions including uuid-ossp, vector, and AGE.
  - Ensured successful installation and verification of extensions.

* feat: Implement auto-select for first tenant and KB on initial load in WebUI
  - Removed WEBUI_INITIAL_STATE_FIX.md as the issue is resolved.
  - Added useTenantInitialization hook to automatically select the first available tenant and KB on app load.
  - Integrated the new hook into the Root component of the WebUI.
  - Updated RetrievalTesting component to ensure a KB is selected before allowing user interaction.
  - Created end-to-end tests for multi-tenant isolation and real service interactions.
  - Added scripts for starting, stopping, and cleaning the development stack.
  - Enhanced API and tenant routes to support tenant-specific pipeline status initialization.
  - Updated constants for backend URL to reflect the correct port.
  - Improved error handling and logging in various components.

* feat: Add multi-tenant support with enhanced E2E testing scripts and client functionality

* update client

* Add integration and unit tests for multi-tenant API, models, security, and storage
  - Implement integration tests for tenant and knowledge base management endpoints in `test_tenant_api_routes.py`.
  - Create unit tests for tenant isolation, model validation, and role permissions in `test_tenant_models.py`.
  - Add security tests to enforce role-based permissions and context validation in `test_tenant_security.py`.
  - Develop tests for tenant-aware storage operations and context isolation in `test_tenant_storage_phase3.py`.

* feat(e2e): Implement OpenAI model support and database reset functionality

* Add comprehensive test suite for gpt-5-nano compatibility
  - Introduced tests for parameter normalization, embeddings, and entity extraction.
  - Implemented direct API testing for gpt-5-nano.
  - Validated .env configuration loading and OpenAI API connectivity.
  - Analyzed reasoning token overhead with various token limits.
  - Documented test procedures and expected outcomes in README files.
  - Ensured all tests pass for production readiness.

* kg(postgres_impl): ensure AGE extension is loaded in session and configure graph initialization

* dev: add hybrid dev helper scripts, Makefile, docker-compose.dev-db and local development docs

* feat(dev): add dev helper scripts and local development documentation for hybrid setup

* feat(multi-tenant): add detailed specifications and logs for multi-tenant improvements, including UX, backend handling, and ingestion pipeline

* feat(migration): add generated tenant/kb columns, indexes, triggers; drop unused tables; update schema and docs

* test(backward-compat): adapt tests to new StorageNameSpace/TenantService APIs (use concrete dummy storages)

* chore: multi-tenant and UX updates - docs, webui, storage, tenant service adjustments

* tests: stabilize integration tests + skip external services; fix multi-tenant API behavior and idempotency
  - gpt5_nano_compatibility: add pytest-asyncio markers, skip when OPENAI key missing, prevent module-level asyncio.run collection, add conftest
  - Ollama tests: add server availability check and skip markers; avoid pytest collection warnings by renaming helper classes
  - Graph storage tests: rename interactive test functions to avoid pytest collection
  - Document & Tenant routes: support external_ids for idempotency; ensure HTTPExceptions are re-raised
  - LightRAG core: support external_ids in apipeline_enqueue_documents and idempotent logic
  - Tests updated to match API changes (tenant routes & document routes)
  - Add logs and scripts for inspection and audit
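The file below implements the Ollama-compatible API router that the WebUI and Ollama clients talk to. A query-mode prefix in the last chat message (for example "/local ", "/hybrid ", "/context", or "/bypass ") selects how LightRAG handles the request. A minimal client sketch follows; the /api mount prefix, the localhost:9621 address, and the model name are assumptions for illustration, and tenant/knowledge-base resolution is handled by the get_tenant_context_optional dependency, which is not shown in this file:

import requests

BASE_URL = "http://localhost:9621/api"  # assumed mount point and port; adjust to your deployment

payload = {
    "model": "lightrag:latest",  # placeholder; the server advertises its own model tag via /api/tags
    "messages": [
        # "/local " selects SearchMode.local; "/bypass " sends the prompt straight to the underlying LLM.
        {"role": "user", "content": "/local What entities appear in the onboarding documents?"}
    ],
    "stream": False,
}

resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["message"]["content"])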
757 lines · 33 KiB · Python
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Type
from enum import Enum
import time
import json
import re
import asyncio

from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import TiktokenTokenizer, logger
from lightrag.api.utils_api import get_combined_auth_dependency
from lightrag.api.dependencies import get_tenant_context_optional
from lightrag.models.tenant import TenantContext
from lightrag.tenant_rag_manager import TenantRAGManager


# Query mode is selected according to the query prefix (bypass is not a LightRAG query mode)
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"
    bypass = "bypass"
    context = "context"


class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class OllamaChatRequest(BaseModel):
    model: str
    messages: List[OllamaMessage]
    stream: bool = True
    options: Optional[Dict[str, Any]] = None
    system: Optional[str] = None


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool


class OllamaGenerateRequest(BaseModel):
    model: str
    prompt: str
    system: Optional[str] = None
    stream: bool = False
    options: Optional[Dict[str, Any]] = None


class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool
    context: Optional[List[int]]
    total_duration: Optional[int]
    load_duration: Optional[int]
    prompt_eval_count: Optional[int]
    prompt_eval_duration: Optional[int]
    eval_count: Optional[int]
    eval_duration: Optional[int]


class OllamaVersionResponse(BaseModel):
    version: str


class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: List[str]
    parameter_size: str
    quantization_level: str


class OllamaModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    details: OllamaModelDetails


class OllamaTagResponse(BaseModel):
    models: List[OllamaModel]


class OllamaRunningModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: List[str]
    parameter_size: str
    quantization_level: str


class OllamaRunningModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    details: OllamaRunningModelDetails
    expires_at: str
    size_vram: int


class OllamaPsResponse(BaseModel):
    models: List[OllamaRunningModel]


async def parse_request_body(
    request: Request, model_class: Type[BaseModel]
) -> BaseModel:
    """
    Parse request body based on Content-Type header.
    Supports both application/json and application/octet-stream.

    Args:
        request: The FastAPI Request object
        model_class: The Pydantic model class to parse the request into

    Returns:
        An instance of the provided model_class
    """
    content_type = request.headers.get("content-type", "").lower()

    try:
        if content_type.startswith("application/json"):
            # FastAPI already handles JSON parsing for us
            body = await request.json()
        elif content_type.startswith("application/octet-stream"):
            # Manually parse octet-stream as JSON
            body_bytes = await request.body()
            body = json.loads(body_bytes.decode("utf-8"))
        else:
            # Try to parse as JSON for any other content type
            body_bytes = await request.body()
            body = json.loads(body_bytes.decode("utf-8"))

        # Create an instance of the model
        return model_class(**body)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Error parsing request body: {str(e)}"
        )


def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text using tiktoken"""
    tokens = TiktokenTokenizer().encode(text)
    return len(tokens)


def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]:
    """Parse the query prefix to determine the search mode.

    Returns a tuple of (cleaned_query, search_mode, only_need_context, user_prompt).

    Examples:
    - "/local[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.local, False, "use mermaid format for diagrams")
    - "/[use mermaid format for diagrams] query string" -> (cleaned_query, SearchMode.mix, False, "use mermaid format for diagrams")
    - "/local query string" -> (cleaned_query, SearchMode.local, False, None)
    """
    # Initialize user_prompt as None
    user_prompt = None

    # First check if there's a bracket format for the user prompt
    bracket_pattern = r"^/([a-z]*)\[(.*?)\](.*)"
    bracket_match = re.match(bracket_pattern, query)

    if bracket_match:
        mode_prefix = bracket_match.group(1)
        user_prompt = bracket_match.group(2)
        remaining_query = bracket_match.group(3).lstrip()

        # Reconstruct the query without the bracket part
        query = f"/{mode_prefix} {remaining_query}".strip()

    # Unified handling of mode and only_need_context determination
    mode_map = {
        "/local ": (SearchMode.local, False),
        "/global ": (
            SearchMode.global_,
            False,
        ),  # global_ is used because 'global' is a Python keyword
        "/naive ": (SearchMode.naive, False),
        "/hybrid ": (SearchMode.hybrid, False),
        "/mix ": (SearchMode.mix, False),
        "/bypass ": (SearchMode.bypass, False),
        "/context": (
            SearchMode.mix,
            True,
        ),
        "/localcontext": (SearchMode.local, True),
        "/globalcontext": (SearchMode.global_, True),
        "/hybridcontext": (SearchMode.hybrid, True),
        "/naivecontext": (SearchMode.naive, True),
        "/mixcontext": (SearchMode.mix, True),
    }

    for prefix, (mode, only_need_context) in mode_map.items():
        if query.startswith(prefix):
            # Strip the prefix and any leading spaces
            cleaned_query = query[len(prefix) :].lstrip()
            return cleaned_query, mode, only_need_context, user_prompt

    return query, SearchMode.mix, False, user_prompt
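
# Illustrative examples of parse_query_mode behaviour (comments only, not executed):
#   "/local[use mermaid] what is LightRAG?" -> ("what is LightRAG?", SearchMode.local, False, "use mermaid")
#   "/context what changed last week?"      -> ("what changed last week?", SearchMode.mix, True, None)
#   "plain question"                        -> ("plain question", SearchMode.mix, False, None)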


class OllamaAPI:
    def __init__(
        self,
        rag: LightRAG,
        top_k: int = 60,
        api_key: Optional[str] = None,
        rag_manager: Optional[TenantRAGManager] = None,
    ):
        self.rag = rag
        self.rag_manager = rag_manager
        self.ollama_server_infos = rag.ollama_server_infos
        self.top_k = top_k
        self.api_key = api_key
        self.router = APIRouter(tags=["ollama"])
        self.setup_routes()

    def setup_routes(self):
        # Create combined auth dependency for Ollama API routes
        combined_auth = get_combined_auth_dependency(self.api_key)

        # Create get_tenant_rag dependency for tenant-aware operations
        async def get_tenant_rag(
            tenant_context: Optional[TenantContext] = Depends(
                get_tenant_context_optional
            ),
        ) -> LightRAG:
            """Dependency to get the tenant-specific RAG instance for Ollama operations"""
            if (
                self.rag_manager
                and tenant_context
                and tenant_context.tenant_id
                and tenant_context.kb_id
            ):
                return await self.rag_manager.get_rag_instance(
                    tenant_context.tenant_id,
                    tenant_context.kb_id,
                    tenant_context.user_id,
                )
            # Fall back to the default RAG instance when no tenant context is available
            return self.rag
@self.router.get("/version", dependencies=[Depends(combined_auth)])
|
|
async def get_version():
|
|
"""Get Ollama version information"""
|
|
return OllamaVersionResponse(version="0.9.3")
|
|
|
|
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
|
async def get_tags():
|
|
"""Return available models acting as an Ollama server"""
|
|
return OllamaTagResponse(
|
|
models=[
|
|
{
|
|
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
|
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
|
"details": {
|
|
"parent_model": "",
|
|
"format": "gguf",
|
|
"family": self.ollama_server_infos.LIGHTRAG_NAME,
|
|
"families": [self.ollama_server_infos.LIGHTRAG_NAME],
|
|
"parameter_size": "13B",
|
|
"quantization_level": "Q4_0",
|
|
},
|
|
}
|
|
]
|
|
)
|
|
|
|
@self.router.get("/ps", dependencies=[Depends(combined_auth)])
|
|
async def get_running_models():
|
|
"""List Running Models - returns currently running models"""
|
|
return OllamaPsResponse(
|
|
models=[
|
|
{
|
|
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
|
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
|
"details": {
|
|
"parent_model": "",
|
|
"format": "gguf",
|
|
"family": "llama",
|
|
"families": ["llama"],
|
|
"parameter_size": "7.2B",
|
|
"quantization_level": "Q4_0",
|
|
},
|
|
"expires_at": "2050-12-31T14:38:31.83753-07:00",
|
|
"size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
|
|
}
|
|
]
|
|
)
|
|
|
|

        @self.router.post(
            "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
        )
        async def generate(
            raw_request: Request,
            tenant_rag: LightRAG = Depends(get_tenant_rag),
        ):
            """Handle generate completion requests acting as an Ollama model (tenant-scoped).
            For compatibility purposes, the request is not processed by LightRAG;
            it is handled directly by the underlying LLM.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaGenerateRequest)

                query = request.prompt
                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(query)

                if request.system:
                    tenant_rag.llm_model_kwargs["system_prompt"] = request.system

                if request.stream:
                    response = await tenant_rag.llm_model_func(
                        query, stream=True, **tenant_rag.llm_model_kwargs
                    )

                    async def stream_generator():
                        try:
                            first_chunk_time = None
                            last_chunk_time = time.time_ns()
                            total_response = ""

                            # Ensure response is an async generator
                            if isinstance(response, str):
                                # If it's a string, send it in two parts
                                first_chunk_time = start_time
                                last_chunk_time = time.time_ns()
                                total_response = response

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "response": response,
                                    "done": False,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "response": "",
                                    "done": True,
                                    "done_reason": "stop",
                                    "context": [],
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            else:
                                try:
                                    async for chunk in response:
                                        if chunk:
                                            if first_chunk_time is None:
                                                first_chunk_time = time.time_ns()

                                            last_chunk_time = time.time_ns()

                                            total_response += chunk
                                            data = {
                                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                                "response": chunk,
                                                "done": False,
                                            }
                                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                except (asyncio.CancelledError, Exception) as e:
                                    error_msg = str(e)
                                    if isinstance(e, asyncio.CancelledError):
                                        error_msg = "Stream was cancelled by server"
                                    else:
                                        error_msg = f"Provider error: {error_msg}"

                                    logger.error(f"Stream error: {error_msg}")

                                    # Send error message to client
                                    error_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "response": f"\n\nError: {error_msg}",
                                        "error": f"\n\nError: {error_msg}",
                                        "done": False,
                                    }
                                    yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                                    # Send final message to close the stream
                                    final_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "response": "",
                                        "done": True,
                                    }
                                    yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                                    return

                                if first_chunk_time is None:
                                    first_chunk_time = start_time
                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "response": "",
                                    "done": True,
                                    "done_reason": "stop",
                                    "context": [],
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                return

                        except Exception as e:
                            trace_exception(e)
                            raise

                    return StreamingResponse(
                        stream_generator(),
                        media_type="application/x-ndjson",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Content-Type": "application/x-ndjson",
                            "X-Accel-Buffering": "no",  # Ensure proper handling of streaming responses in Nginx proxy
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()
                    response_text = await tenant_rag.llm_model_func(
                        query, stream=False, **tenant_rag.llm_model_kwargs
                    )
                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = "No response generated"

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "response": str(response_text),
                        "done": True,
                        "done_reason": "stop",
                        "context": [],
                        "total_duration": total_time,
                        "load_duration": 0,
                        "prompt_eval_count": prompt_tokens,
                        "prompt_eval_duration": prompt_eval_time,
                        "eval_count": completion_tokens,
                        "eval_duration": eval_time,
                    }
            except Exception as e:
                trace_exception(e)
                raise HTTPException(status_code=500, detail=str(e))

        @self.router.post(
            "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
        )
        async def chat(
            raw_request: Request,
            tenant_rag: LightRAG = Depends(get_tenant_rag),
        ):
            """Process chat completion requests by acting as an Ollama model (tenant-scoped).
            Routes user queries through LightRAG by selecting the query mode based on the query prefix.
            Detects and forwards Open WebUI session-related requests (metadata generation tasks) directly to the LLM.
            Supports both application/json and application/octet-stream Content-Types.
            """
            try:
                # Parse the request body manually
                request = await parse_request_body(raw_request, OllamaChatRequest)

                # Get all messages
                messages = request.messages
                if not messages:
                    raise HTTPException(status_code=400, detail="No messages provided")

                # Get the last message as the query and previous messages as history
                query = messages[-1].content
                # Convert OllamaMessage objects to dictionaries
                conversation_history = [
                    {"role": msg.role, "content": msg.content} for msg in messages[:-1]
                ]

                # Check for a query prefix
                cleaned_query, mode, only_need_context, user_prompt = parse_query_mode(
                    query
                )

                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(cleaned_query)

                param_dict = {
                    "mode": mode.value,
                    "stream": request.stream,
                    "only_need_context": only_need_context,
                    "conversation_history": conversation_history,
                    "top_k": self.top_k,
                }

                # Add user_prompt to param_dict
                if user_prompt is not None:
                    param_dict["user_prompt"] = user_prompt

                query_param = QueryParam(**param_dict)

                if request.stream:
                    # Determine whether the request is prefixed with "/bypass"
                    if mode == SearchMode.bypass:
                        if request.system:
                            tenant_rag.llm_model_kwargs["system_prompt"] = request.system
                        response = await tenant_rag.llm_model_func(
                            cleaned_query,
                            stream=True,
                            history_messages=conversation_history,
                            **tenant_rag.llm_model_kwargs,
                        )
                    else:
                        response = await tenant_rag.aquery(
                            cleaned_query, param=query_param
                        )

                    async def stream_generator():
                        try:
                            first_chunk_time = None
                            last_chunk_time = time.time_ns()
                            total_response = ""

                            # Ensure response is an async generator
                            if isinstance(response, str):
                                # If it's a string, send it in two parts
                                first_chunk_time = start_time
                                last_chunk_time = time.time_ns()
                                total_response = response

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "message": {
                                        "role": "assistant",
                                        "content": response,
                                        "images": None,
                                    },
                                    "done": False,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "message": {
                                        "role": "assistant",
                                        "content": "",
                                        "images": None,
                                    },
                                    "done_reason": "stop",
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            else:
                                try:
                                    async for chunk in response:
                                        if chunk:
                                            if first_chunk_time is None:
                                                first_chunk_time = time.time_ns()

                                            last_chunk_time = time.time_ns()

                                            total_response += chunk
                                            data = {
                                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                                "message": {
                                                    "role": "assistant",
                                                    "content": chunk,
                                                    "images": None,
                                                },
                                                "done": False,
                                            }
                                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                except (asyncio.CancelledError, Exception) as e:
                                    error_msg = str(e)
                                    if isinstance(e, asyncio.CancelledError):
                                        error_msg = "Stream was cancelled by server"
                                    else:
                                        error_msg = f"Provider error: {error_msg}"

                                    logger.error(f"Stream error: {error_msg}")

                                    # Send error message to client
                                    error_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "message": {
                                            "role": "assistant",
                                            "content": f"\n\nError: {error_msg}",
                                            "images": None,
                                        },
                                        "error": f"\n\nError: {error_msg}",
                                        "done": False,
                                    }
                                    yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                                    # Send final message to close the stream
                                    final_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "message": {
                                            "role": "assistant",
                                            "content": "",
                                            "images": None,
                                        },
                                        "done": True,
                                    }
                                    yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                                    return

                                if first_chunk_time is None:
                                    first_chunk_time = start_time
                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "message": {
                                        "role": "assistant",
                                        "content": "",
                                        "images": None,
                                    },
                                    "done_reason": "stop",
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                        except Exception as e:
                            trace_exception(e)
                            raise

                    return StreamingResponse(
                        stream_generator(),
                        media_type="application/x-ndjson",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Content-Type": "application/x-ndjson",
                            "X-Accel-Buffering": "no",  # Ensure proper handling of streaming responses in Nginx proxy
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()

                    # Determine whether the request is prefixed with "/bypass" or comes from
                    # Open WebUI's session title and session keyword generation tasks
                    match_result = re.search(
                        r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
                    )
                    if match_result or mode == SearchMode.bypass:
                        if request.system:
                            tenant_rag.llm_model_kwargs["system_prompt"] = request.system

                        response_text = await tenant_rag.llm_model_func(
                            cleaned_query,
                            stream=False,
                            history_messages=conversation_history,
                            **tenant_rag.llm_model_kwargs,
                        )
                    else:
                        response_text = await tenant_rag.aquery(
                            cleaned_query, param=query_param
                        )

                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = "No response generated"

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "message": {
                            "role": "assistant",
                            "content": str(response_text),
                            "images": None,
                        },
                        "done_reason": "stop",
                        "done": True,
                        "total_duration": total_time,
                        "load_duration": 0,
                        "prompt_eval_count": prompt_tokens,
                        "prompt_eval_duration": prompt_eval_time,
                        "eval_count": completion_tokens,
                        "eval_duration": eval_time,
                    }
            except Exception as e:
                trace_exception(e)
                raise HTTPException(status_code=500, detail=str(e))
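

# Usage sketch (illustrative): how this router might be mounted on a FastAPI app.
# The /api prefix and the wiring below are assumptions based on Ollama's endpoint
# layout and on the OllamaAPI constructor above, not this project's actual server code.
def _example_create_app(rag: LightRAG, rag_manager: Optional[TenantRAGManager] = None):
    from fastapi import FastAPI

    app = FastAPI()
    # A single LightRAG instance serves requests by default; a TenantRAGManager enables tenant-scoped resolution.
    ollama_api = OllamaAPI(rag, top_k=60, api_key=None, rag_manager=rag_manager)
    # Ollama clients expect the endpoints under /api (e.g. /api/chat, /api/generate, /api/tags).
    app.include_router(ollama_api.router, prefix="/api")
    return app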