feat: Decouple user management from history service by removing the User model, using string user_ids, and renaming history tables.
This commit is contained in:
parent
854dc67c12
commit
d924577cc5
13 changed files with 1520 additions and 2067 deletions
|
|
@ -15,6 +15,7 @@ load_dotenv(dotenv_path=".env", override=False)
|
|||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # Username
|
||||
user_id: str # User ID
|
||||
exp: datetime # Expiration time
|
||||
role: str = "user" # User role, default is regular user
|
||||
metadata: dict = {} # Additional metadata
|
||||
|
|
@ -30,8 +31,13 @@ class AuthHandler:
|
|||
auth_accounts = global_args.auth_accounts
|
||||
if auth_accounts:
|
||||
for account in auth_accounts.split(","):
|
||||
username, password = account.split(":", 1)
|
||||
self.accounts[username] = password
|
||||
parts = account.split(":")
|
||||
if len(parts) == 3:
|
||||
username, password, user_id = parts
|
||||
else:
|
||||
username, password = parts
|
||||
user_id = username # Default user_id to username if not provided
|
||||
self.accounts[username] = {"password": password, "user_id": user_id}
|
||||
|
||||
def create_token(
|
||||
self,
|
||||
|
|
@ -63,9 +69,14 @@ class AuthHandler:
|
|||
|
||||
expire = datetime.utcnow() + timedelta(hours=expire_hours)
|
||||
|
||||
# Get user_id from accounts or use username
|
||||
user_id = username
|
||||
if username in self.accounts and isinstance(self.accounts[username], dict):
|
||||
user_id = self.accounts[username].get("user_id", username)
|
||||
|
||||
# Create payload
|
||||
payload = TokenPayload(
|
||||
sub=username, exp=expire, role=role, metadata=metadata or {}
|
||||
sub=username, user_id=user_id, exp=expire, role=role, metadata=metadata or {}
|
||||
)
|
||||
|
||||
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
|
||||
|
|
@ -96,6 +107,7 @@ class AuthHandler:
|
|||
# Return complete payload instead of just username
|
||||
return {
|
||||
"username": payload["sub"],
|
||||
"user_id": payload.get("user_id", payload["sub"]),
|
||||
"role": payload.get("role", "user"),
|
||||
"metadata": payload.get("metadata", {}),
|
||||
"exp": expire_time,
|
||||
|
|
|
|||
|
|
@ -1161,7 +1161,8 @@ def create_app(args):
|
|||
"webui_description": webui_description,
|
||||
}
|
||||
username = form_data.username
|
||||
if auth_handler.accounts.get(username) != form_data.password:
|
||||
account = auth_handler.accounts.get(username)
|
||||
if not account or account["password"] != form_data.password:
|
||||
raise HTTPException(status_code=401, detail="Incorrect credentials")
|
||||
|
||||
# Regular user login
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, Security, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
import sys
|
||||
import os
|
||||
from lightrag.api.auth import auth_handler
|
||||
|
||||
# Ensure service module is in path (similar to query_routes.py)
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
|
|
@ -17,7 +15,6 @@ try:
|
|||
from app.core.database import get_db
|
||||
from app.services.history_manager import HistoryManager
|
||||
from app.models.schemas import SessionResponse, SessionCreate, ChatMessageResponse
|
||||
from app.models.models import User
|
||||
except ImportError:
|
||||
# Fallback if service not found (shouldn't happen if setup is correct)
|
||||
get_db = None
|
||||
|
|
@ -25,79 +22,45 @@ except ImportError:
|
|||
SessionResponse = None
|
||||
SessionCreate = None
|
||||
ChatMessageResponse = None
|
||||
User = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login", auto_error=False)
|
||||
|
||||
def check_dependencies():
|
||||
if not HistoryManager:
|
||||
raise HTTPException(status_code=503, detail="History service not available")
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Security(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
check_dependencies()
|
||||
|
||||
if not token:
|
||||
# If no token provided, try to use default user if configured or allowed
|
||||
# For now, we'll return the default user for backward compatibility if needed,
|
||||
# but ideally we should require auth.
|
||||
# Let's check if we have a default user
|
||||
user = db.query(User).filter(User.username == "default_user").first()
|
||||
if not user:
|
||||
user = User(username="default_user", email="default@example.com")
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
try:
|
||||
user_data = auth_handler.validate_token(token)
|
||||
username = user_data["username"]
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
# Create user if not exists (auto-registration on first login)
|
||||
# In a real app you might want to fetch email from token metadata or require explicit registration
|
||||
user = User(username=username, email=f"{username}@example.com")
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
async def get_current_user_id(
|
||||
x_user_id: Optional[str] = Header(None, alias="X-User-ID")
|
||||
) -> str:
|
||||
# Prefer X-User-ID, default to default_user
|
||||
uid = x_user_id
|
||||
if not uid:
|
||||
# Fallback to default user if no header provided (for backward compatibility or dev)
|
||||
# Or raise error if strict
|
||||
return "default_user"
|
||||
return uid
|
||||
|
||||
@router.get("/sessions", response_model=List[SessionResponse], tags=["History"])
|
||||
def list_sessions(
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user_id: str = Depends(get_current_user_id)
|
||||
):
|
||||
check_dependencies()
|
||||
manager = HistoryManager(db)
|
||||
sessions = manager.list_sessions(user_id=current_user.id, skip=skip, limit=limit)
|
||||
sessions = manager.list_sessions(user_id=current_user_id, skip=skip, limit=limit)
|
||||
return sessions
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse, tags=["History"])
|
||||
def create_session(
|
||||
session_in: SessionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user_id: str = Depends(get_current_user_id)
|
||||
):
|
||||
check_dependencies()
|
||||
manager = HistoryManager(db)
|
||||
return manager.create_session(user_id=current_user.id, title=session_in.title)
|
||||
return manager.create_session(user_id=current_user_id, title=session_in.title)
|
||||
|
||||
@router.get("/sessions/{session_id}/history", response_model=List[ChatMessageResponse], tags=["History"])
|
||||
def get_session_history(
|
||||
|
|
|
|||
|
|
@ -3,17 +3,19 @@ This module contains all query-related routes for the LightRAG API.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from lightrag.base import QueryParam
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from lightrag.base import QueryParam
|
||||
from lightrag.utils import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import sys
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
|
||||
# Add the project root to sys.path to allow importing 'service'
|
||||
# Assuming this file is at lightrag/api/routers/query_routes.py
|
||||
|
|
@ -28,13 +30,17 @@ if service_dir not in sys.path:
|
|||
try:
|
||||
from app.core.database import SessionLocal
|
||||
from app.services.history_manager import HistoryManager
|
||||
from app.models.models import User
|
||||
from app.models.schemas import ChatMessageResponse
|
||||
except ImportError as e:
|
||||
# Fallback or handle error if service module is not found
|
||||
print(f"Warning: Could not import service module. History logging will be disabled. Error: {e}")
|
||||
logger.error(f"Warning: Could not import service module. History logging will be disabled. Error: {e}")
|
||||
print(f"CRITICAL ERROR: Could not import service module: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
SessionLocal = None
|
||||
HistoryManager = None
|
||||
User = None
|
||||
QueryRequest = None
|
||||
QueryResponse = None
|
||||
ChatMessageResponse = None
|
||||
|
||||
|
||||
router = APIRouter(tags=["query"])
|
||||
|
|
@ -354,7 +360,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
},
|
||||
},
|
||||
)
|
||||
async def query_text(request: QueryRequest):
|
||||
async def query_text(
|
||||
request: QueryRequest,
|
||||
x_user_id: Optional[str] = Header(None, alias="X-User-ID")
|
||||
):
|
||||
"""
|
||||
Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored.
|
||||
|
||||
|
|
@ -441,6 +450,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
param.stream = False
|
||||
|
||||
# Unified approach: always use aquery_llm for both cases
|
||||
start_time = time.time()
|
||||
result = await rag.aquery_llm(request.query, param=param)
|
||||
|
||||
# Extract LLM response and references from unified result
|
||||
|
|
@ -484,67 +494,81 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
final_response = QueryResponse(response=response_content, references=None)
|
||||
|
||||
# --- LOGGING START ---
|
||||
logger.info(f"DEBUG: SessionLocal={SessionLocal}, HistoryManager={HistoryManager}")
|
||||
if SessionLocal and HistoryManager:
|
||||
try:
|
||||
logger.info("DEBUG: Entering logging block")
|
||||
db = SessionLocal()
|
||||
manager = HistoryManager(db)
|
||||
|
||||
# 1. Get or Create User (Default)
|
||||
user = db.query(User).filter(User.username == "default_user").first()
|
||||
if not user:
|
||||
user = User(username="default_user", email="default@example.com")
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
# 1. Get User ID from Header (or default)
|
||||
current_user_id = x_user_id or "default_user"
|
||||
|
||||
# 2. Handle Session
|
||||
session_uuid = None
|
||||
if request.session_id:
|
||||
try:
|
||||
session_uuid = uuid.UUID(request.session_id)
|
||||
temp_uuid = uuid.UUID(request.session_id)
|
||||
# Verify session exists
|
||||
if not manager.get_session(session_uuid):
|
||||
# If provided ID doesn't exist, create it with that ID if possible or just create new
|
||||
# For simplicity, let's create a new one if it doesn't exist but we can't force ID easily with current manager
|
||||
# Let's just create a new session if not found or use the provided one if we trust it.
|
||||
# Actually, manager.create_session generates ID.
|
||||
# Let's just create a new session if the provided one is invalid/not found,
|
||||
# OR we can just create a new session if session_id is NOT provided.
|
||||
# If session_id IS provided, we assume it exists.
|
||||
pass
|
||||
if manager.get_session(temp_uuid):
|
||||
session_uuid = temp_uuid
|
||||
else:
|
||||
logger.warning(f"Session {request.session_id} not found. Creating new session.")
|
||||
except ValueError:
|
||||
pass
|
||||
logger.warning(f"Invalid session ID format: {request.session_id}")
|
||||
|
||||
if not session_uuid:
|
||||
# Create new session
|
||||
session = manager.create_session(user_id=user.id, title=request.query[:50])
|
||||
session = manager.create_session(user_id=current_user_id, title=request.query[:50])
|
||||
session_uuid = session.id
|
||||
|
||||
# Calculate processing time
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
# Calculate token counts
|
||||
try:
|
||||
import tiktoken
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
query_tokens = len(enc.encode(request.query))
|
||||
response_tokens = len(enc.encode(response_content))
|
||||
except ImportError:
|
||||
# Fallback approximation
|
||||
query_tokens = len(request.query) // 4
|
||||
response_tokens = len(response_content) // 4
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating tokens: {e}")
|
||||
query_tokens = len(request.query) // 4
|
||||
response_tokens = len(response_content) // 4
|
||||
|
||||
# 3. Log User Message
|
||||
manager.save_message(
|
||||
session_id=session_uuid,
|
||||
role="user",
|
||||
content=request.query
|
||||
content=request.query,
|
||||
token_count=query_tokens,
|
||||
processing_time=None # User message processing time is negligible/not applicable in this context
|
||||
)
|
||||
|
||||
# 4. Log Assistant Message
|
||||
ai_msg = manager.save_message(
|
||||
session_id=session_uuid,
|
||||
role="assistant",
|
||||
content=response_content
|
||||
content=response_content,
|
||||
token_count=response_tokens,
|
||||
processing_time=processing_time
|
||||
)
|
||||
|
||||
# 5. Log Citations
|
||||
if references:
|
||||
# Convert references to dict format expected by save_citations
|
||||
# references is a list of ReferenceItem (pydantic) or dicts?
|
||||
# In the code above: references = data.get("references", []) which are dicts.
|
||||
# Then enriched_references are also dicts.
|
||||
manager.save_citations(ai_msg.id, references)
|
||||
|
||||
db.close()
|
||||
except Exception as log_exc:
|
||||
print(f"Error logging history: {log_exc}")
|
||||
print(f"Error logging history: {log_exc}", file=sys.stderr)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"Error logging history: {log_exc}", exc_info=True)
|
||||
# Don't fail the request if logging fails
|
||||
# --- LOGGING END ---
|
||||
|
||||
|
|
@ -633,7 +657,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
},
|
||||
},
|
||||
)
|
||||
async def query_text_stream(request: QueryRequest):
|
||||
async def query_text_stream(
|
||||
request: QueryRequest,
|
||||
x_user_id: Optional[str] = Header(None, alias="X-User-ID")
|
||||
):
|
||||
"""
|
||||
Advanced RAG query endpoint with flexible streaming response.
|
||||
|
||||
|
|
@ -848,26 +875,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
|||
db = SessionLocal()
|
||||
manager = HistoryManager(db)
|
||||
|
||||
# 1. Get or Create User
|
||||
user = db.query(User).filter(User.username == "default_user").first()
|
||||
if not user:
|
||||
user = User(username="default_user", email="default@example.com")
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
# 1. Get User ID
|
||||
current_user_id = x_user_id or "default_user"
|
||||
|
||||
# 2. Handle Session
|
||||
session_uuid = None
|
||||
if request.session_id:
|
||||
try:
|
||||
session_uuid = uuid.UUID(request.session_id)
|
||||
temp_uuid = uuid.UUID(request.session_id)
|
||||
if manager.get_session(temp_uuid):
|
||||
session_uuid = temp_uuid
|
||||
else:
|
||||
logger.warning(f"Session {request.session_id} not found. Creating new session.")
|
||||
except ValueError:
|
||||
pass
|
||||
logger.warning(f"Invalid session ID format: {request.session_id}")
|
||||
|
||||
if not session_uuid or not manager.get_session(session_uuid):
|
||||
session = manager.create_session(user_id=user.id, title=request.query[:50])
|
||||
if not session_uuid:
|
||||
session = manager.create_session(user_id=current_user_id, title=request.query[:50])
|
||||
session_uuid = session.id
|
||||
|
||||
|
||||
# 3. Log User Message
|
||||
manager.save_message(
|
||||
session_id=session_uuid,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,17 @@ import { navigationService } from '@/services/navigation'
|
|||
import { useSettingsStore } from '@/stores/settings'
|
||||
import axios, { AxiosError } from 'axios'
|
||||
|
||||
const getUserIdFromToken = (token: string): string | null => {
|
||||
try {
|
||||
const parts = token.split('.')
|
||||
if (parts.length !== 3) return null
|
||||
const payload = JSON.parse(atob(parts[1]))
|
||||
return payload.user_id || payload.sub || null
|
||||
} catch (e) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
// Types
|
||||
export type LightragNodeType = {
|
||||
id: string
|
||||
|
|
@ -299,6 +310,10 @@ axiosInstance.interceptors.request.use((config) => {
|
|||
// Always include token if it exists, regardless of path
|
||||
if (token) {
|
||||
config.headers['Authorization'] = `Bearer ${token}`
|
||||
const userId = getUserIdFromToken(token)
|
||||
if (userId) {
|
||||
config.headers['X-User-ID'] = userId
|
||||
}
|
||||
}
|
||||
if (apiKey) {
|
||||
config.headers['X-API-Key'] = apiKey
|
||||
|
|
@ -418,6 +433,10 @@ export const queryTextStream = async (
|
|||
}
|
||||
if (token) {
|
||||
headers['Authorization'] = `Bearer ${token}`
|
||||
const userId = getUserIdFromToken(token)
|
||||
if (userId) {
|
||||
headers['X-User-ID'] = userId
|
||||
}
|
||||
}
|
||||
if (apiKey) {
|
||||
headers['X-API-Key'] = apiKey
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import { create } from 'zustand'
|
||||
import { createSelectors } from '@/lib/utils'
|
||||
import { checkHealth, LightragStatus } from '@/api/lightrag'
|
||||
import { useSettingsStore } from './settings'
|
||||
import { healthCheckInterval } from '@/lib/constants'
|
||||
import { createSelectors } from '@/lib/utils'
|
||||
import { create } from 'zustand'
|
||||
import { useSettingsStore } from './settings'
|
||||
|
||||
interface BackendState {
|
||||
health: boolean
|
||||
|
|
@ -26,18 +26,26 @@ interface BackendState {
|
|||
}
|
||||
|
||||
interface AuthState {
|
||||
isAuthenticated: boolean;
|
||||
isGuestMode: boolean; // Add guest mode flag
|
||||
coreVersion: string | null;
|
||||
apiVersion: string | null;
|
||||
username: string | null; // login username
|
||||
webuiTitle: string | null; // Custom title
|
||||
webuiDescription: string | null; // Title description
|
||||
isAuthenticated: boolean
|
||||
isGuestMode: boolean // Add guest mode flag
|
||||
coreVersion: string | null
|
||||
apiVersion: string | null
|
||||
username: string | null // login username
|
||||
userId: string | null // user id
|
||||
webuiTitle: string | null // Custom title
|
||||
webuiDescription: string | null // Title description
|
||||
|
||||
login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void;
|
||||
logout: () => void;
|
||||
setVersion: (coreVersion: string | null, apiVersion: string | null) => void;
|
||||
setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void;
|
||||
login: (
|
||||
token: string,
|
||||
isGuest?: boolean,
|
||||
coreVersion?: string | null,
|
||||
apiVersion?: string | null,
|
||||
webuiTitle?: string | null,
|
||||
webuiDescription?: string | null
|
||||
) => void
|
||||
logout: () => void
|
||||
setVersion: (coreVersion: string | null, apiVersion: string | null) => void
|
||||
setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void
|
||||
}
|
||||
|
||||
const useBackendStateStoreBase = create<BackendState>()((set, get) => ({
|
||||
|
|
@ -56,18 +64,17 @@ const useBackendStateStoreBase = create<BackendState>()((set, get) => ({
|
|||
if (health.status === 'healthy') {
|
||||
// Update version information if health check returns it
|
||||
if (health.core_version || health.api_version) {
|
||||
useAuthStore.getState().setVersion(
|
||||
health.core_version || null,
|
||||
health.api_version || null
|
||||
);
|
||||
useAuthStore.getState().setVersion(health.core_version || null, health.api_version || null)
|
||||
}
|
||||
|
||||
// Update custom title information if health check returns it
|
||||
if ('webui_title' in health || 'webui_description' in health) {
|
||||
useAuthStore.getState().setCustomTitle(
|
||||
'webui_title' in health ? (health.webui_title ?? null) : null,
|
||||
'webui_description' in health ? (health.webui_description ?? null) : null
|
||||
);
|
||||
useAuthStore
|
||||
.getState()
|
||||
.setCustomTitle(
|
||||
'webui_title' in health ? (health.webui_title ?? null) : null,
|
||||
'webui_description' in health ? (health.webui_description ?? null) : null
|
||||
)
|
||||
}
|
||||
|
||||
// Extract and store backend max graph nodes limit
|
||||
|
|
@ -156,36 +163,51 @@ const useBackendState = createSelectors(useBackendStateStoreBase)
|
|||
|
||||
export { useBackendState }
|
||||
|
||||
const parseTokenPayload = (token: string): { sub?: string; role?: string } => {
|
||||
const parseTokenPayload = (token: string): { sub?: string; role?: string; user_id?: string } => {
|
||||
try {
|
||||
// JWT tokens are in the format: header.payload.signature
|
||||
const parts = token.split('.');
|
||||
if (parts.length !== 3) return {};
|
||||
const payload = JSON.parse(atob(parts[1]));
|
||||
return payload;
|
||||
const parts = token.split('.')
|
||||
if (parts.length !== 3) return {}
|
||||
const payload = JSON.parse(atob(parts[1]))
|
||||
return payload
|
||||
} catch (e) {
|
||||
console.error('Error parsing token payload:', e);
|
||||
return {};
|
||||
console.error('Error parsing token payload:', e)
|
||||
return {}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const getUsernameFromToken = (token: string): string | null => {
|
||||
const payload = parseTokenPayload(token);
|
||||
return payload.sub || null;
|
||||
};
|
||||
const payload = parseTokenPayload(token)
|
||||
return payload.sub || null
|
||||
}
|
||||
|
||||
const getUserIdFromToken = (token: string): string | null => {
|
||||
const payload = parseTokenPayload(token)
|
||||
return payload.user_id || payload.sub || null
|
||||
}
|
||||
|
||||
const isGuestToken = (token: string): boolean => {
|
||||
const payload = parseTokenPayload(token);
|
||||
return payload.role === 'guest';
|
||||
};
|
||||
const payload = parseTokenPayload(token)
|
||||
return payload.role === 'guest'
|
||||
}
|
||||
|
||||
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => {
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
const username = token ? getUsernameFromToken(token) : null;
|
||||
const initAuthState = (): {
|
||||
isAuthenticated: boolean
|
||||
isGuestMode: boolean
|
||||
coreVersion: string | null
|
||||
apiVersion: string | null
|
||||
username: string | null
|
||||
userId: string | null
|
||||
webuiTitle: string | null
|
||||
webuiDescription: string | null
|
||||
} => {
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN')
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION')
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION')
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE')
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION')
|
||||
const username = token ? getUsernameFromToken(token) : null
|
||||
const userId = token ? getUserIdFromToken(token) : null
|
||||
|
||||
if (!token) {
|
||||
return {
|
||||
|
|
@ -194,9 +216,10 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
|||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
username: null,
|
||||
userId: null,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
};
|
||||
webuiDescription: webuiDescription
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -205,14 +228,15 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
|||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
username: username,
|
||||
userId: userId,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
};
|
||||
};
|
||||
webuiDescription: webuiDescription
|
||||
}
|
||||
}
|
||||
|
||||
export const useAuthStore = create<AuthState>(set => {
|
||||
export const useAuthStore = create<AuthState>((set) => {
|
||||
// Get initial state from localStorage
|
||||
const initialState = initAuthState();
|
||||
const initialState = initAuthState()
|
||||
|
||||
return {
|
||||
isAuthenticated: initialState.isAuthenticated,
|
||||
|
|
@ -220,97 +244,109 @@ export const useAuthStore = create<AuthState>(set => {
|
|||
coreVersion: initialState.coreVersion,
|
||||
apiVersion: initialState.apiVersion,
|
||||
username: initialState.username,
|
||||
userId: initialState.userId,
|
||||
webuiTitle: initialState.webuiTitle,
|
||||
webuiDescription: initialState.webuiDescription,
|
||||
|
||||
login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => {
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', token);
|
||||
login: (
|
||||
token,
|
||||
isGuest = false,
|
||||
coreVersion = null,
|
||||
apiVersion = null,
|
||||
webuiTitle = null,
|
||||
webuiDescription = null
|
||||
) => {
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', token)
|
||||
|
||||
if (coreVersion) {
|
||||
localStorage.setItem('LIGHTRAG-CORE-VERSION', coreVersion);
|
||||
localStorage.setItem('LIGHTRAG-CORE-VERSION', coreVersion)
|
||||
}
|
||||
if (apiVersion) {
|
||||
localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion);
|
||||
localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion)
|
||||
}
|
||||
|
||||
if (webuiTitle) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle)
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE')
|
||||
}
|
||||
|
||||
if (webuiDescription) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription)
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION')
|
||||
}
|
||||
|
||||
const username = getUsernameFromToken(token);
|
||||
const username = getUsernameFromToken(token)
|
||||
const userId = getUserIdFromToken(token)
|
||||
set({
|
||||
isAuthenticated: true,
|
||||
isGuestMode: isGuest,
|
||||
username: username,
|
||||
userId: userId,
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
});
|
||||
webuiDescription: webuiDescription
|
||||
})
|
||||
},
|
||||
|
||||
logout: () => {
|
||||
localStorage.removeItem('LIGHTRAG-API-TOKEN');
|
||||
localStorage.removeItem('LIGHTRAG-API-TOKEN')
|
||||
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION')
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION')
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE')
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION')
|
||||
|
||||
set({
|
||||
isAuthenticated: false,
|
||||
isGuestMode: false,
|
||||
username: null,
|
||||
userId: null,
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
});
|
||||
webuiDescription: webuiDescription
|
||||
})
|
||||
},
|
||||
|
||||
setVersion: (coreVersion, apiVersion) => {
|
||||
// Update localStorage
|
||||
if (coreVersion) {
|
||||
localStorage.setItem('LIGHTRAG-CORE-VERSION', coreVersion);
|
||||
localStorage.setItem('LIGHTRAG-CORE-VERSION', coreVersion)
|
||||
}
|
||||
if (apiVersion) {
|
||||
localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion);
|
||||
localStorage.setItem('LIGHTRAG-API-VERSION', apiVersion)
|
||||
}
|
||||
|
||||
// Update state
|
||||
set({
|
||||
coreVersion: coreVersion,
|
||||
apiVersion: apiVersion
|
||||
});
|
||||
})
|
||||
},
|
||||
|
||||
setCustomTitle: (webuiTitle, webuiDescription) => {
|
||||
// Update localStorage
|
||||
if (webuiTitle) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle);
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-TITLE', webuiTitle)
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE');
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-TITLE')
|
||||
}
|
||||
|
||||
if (webuiDescription) {
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription);
|
||||
localStorage.setItem('LIGHTRAG-WEBUI-DESCRIPTION', webuiDescription)
|
||||
} else {
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
localStorage.removeItem('LIGHTRAG-WEBUI-DESCRIPTION')
|
||||
}
|
||||
|
||||
// Update state
|
||||
set({
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription
|
||||
});
|
||||
})
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ dependencies = [
|
|||
"psycopg2-binary",
|
||||
"openai",
|
||||
"httpx",
|
||||
"redis",
|
||||
"pydantic-settings",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
@ -99,6 +101,8 @@ api = [
|
|||
"python-docx>=0.8.11,<2.0.0", # DOCX processing
|
||||
"python-pptx>=0.6.21,<2.0.0", # PPTX processing
|
||||
"sqlalchemy>=2.0.0,<3.0.0",
|
||||
"pydantic_settings",
|
||||
"neo4j>=5.0.0,<7.0.0",
|
||||
]
|
||||
|
||||
# Advanced document processing engine (optional)
|
||||
|
|
|
|||
0
service/app/__init__.py
Normal file
0
service/app/__init__.py
Normal file
|
|
@ -32,17 +32,13 @@ def create_session(session_in: SessionCreate, db: Session = Depends(get_db)):
|
|||
# Actually, let's just create a user on the fly for this session if we don't have auth.
|
||||
|
||||
manager = HistoryManager(db)
|
||||
# Check if we have any user, if not create one.
|
||||
from app.models.models import User
|
||||
user = db.query(User).first()
|
||||
if not user:
|
||||
user = User(username="default_user", email="default@example.com")
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
# User logic removed
|
||||
# Using a fixed UUID for demonstration purposes. In a real application,
|
||||
# this would come from an authenticated user.
|
||||
fixed_user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
session = manager.create_session(
|
||||
user_id=user.id,
|
||||
user_id=fixed_user_id,
|
||||
title=session_in.title,
|
||||
rag_config=session_in.rag_config
|
||||
)
|
||||
|
|
@ -51,13 +47,13 @@ def create_session(session_in: SessionCreate, db: Session = Depends(get_db)):
|
|||
@router.get("/sessions", response_model=List[SessionResponse])
|
||||
def list_sessions(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
manager = HistoryManager(db)
|
||||
# Again, need user_id. Using the default user strategy.
|
||||
from app.models.models import User
|
||||
user = db.query(User).first()
|
||||
if not user:
|
||||
return []
|
||||
# User logic removed
|
||||
pass
|
||||
# Using a fixed UUID for demonstration purposes. In a real application,
|
||||
# this would come from an authenticated user.
|
||||
fixed_user_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
sessions = manager.list_sessions(user_id=user.id, skip=skip, limit=limit)
|
||||
sessions = manager.list_sessions(user_id=fixed_user_id, skip=skip, limit=limit)
|
||||
return sessions
|
||||
|
||||
@router.get("/sessions/{session_id}/history")
|
||||
|
|
|
|||
|
|
@ -5,37 +5,24 @@ from sqlalchemy.orm import relationship
|
|||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
username = Column(String(50), unique=True, nullable=False)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
full_name = Column(String(100), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
sessions = relationship("ChatSession", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_sessions"
|
||||
__tablename__ = "lightrag_chat_sessions_history"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=True)
|
||||
rag_config = Column(JSON, default={})
|
||||
summary = Column(Text, nullable=True)
|
||||
last_message_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
user = relationship("User", back_populates="sessions")
|
||||
messages = relationship("ChatMessage", back_populates="session", cascade="all, delete-orphan")
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_messages"
|
||||
__tablename__ = "lightrag_chat_messages_history"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id = Column(UUID(as_uuid=True), ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False)
|
||||
session_id = Column(UUID(as_uuid=True), ForeignKey("lightrag_chat_sessions_history.id", ondelete="CASCADE"), nullable=False)
|
||||
role = Column(String(20), nullable=False) # user, assistant, system
|
||||
content = Column(Text, nullable=False)
|
||||
token_count = Column(Integer, nullable=True)
|
||||
|
|
@ -46,10 +33,10 @@ class ChatMessage(Base):
|
|||
citations = relationship("MessageCitation", back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
class MessageCitation(Base):
|
||||
__tablename__ = "message_citations"
|
||||
__tablename__ = "lightrag_message_citations_history"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
message_id = Column(UUID(as_uuid=True), ForeignKey("chat_messages.id", ondelete="CASCADE"), nullable=False)
|
||||
message_id = Column(UUID(as_uuid=True), ForeignKey("lightrag_chat_messages_history.id", ondelete="CASCADE"), nullable=False)
|
||||
source_doc_id = Column(String(255), nullable=False, index=True)
|
||||
file_path = Column(Text, nullable=False)
|
||||
chunk_content = Column(Text, nullable=True)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class HistoryManager:
|
|||
|
||||
return list(reversed(context))
|
||||
|
||||
def create_session(self, user_id: uuid.UUID, title: str = None, rag_config: dict = None) -> ChatSession:
|
||||
def create_session(self, user_id: str, title: str = None, rag_config: dict = None) -> ChatSession:
|
||||
session = ChatSession(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
|
|
@ -48,7 +48,7 @@ class HistoryManager:
|
|||
def get_session(self, session_id: uuid.UUID) -> Optional[ChatSession]:
|
||||
return self.db.query(ChatSession).filter(ChatSession.id == session_id).first()
|
||||
|
||||
def list_sessions(self, user_id: uuid.UUID, skip: int = 0, limit: int = 100) -> List[ChatSession]:
|
||||
def list_sessions(self, user_id: str, skip: int = 0, limit: int = 100) -> List[ChatSession]:
|
||||
return (
|
||||
self.db.query(ChatSession)
|
||||
.filter(ChatSession.user_id == user_id)
|
||||
|
|
@ -80,11 +80,12 @@ class HistoryManager:
|
|||
|
||||
def save_citations(self, message_id: uuid.UUID, citations: List[Dict]):
|
||||
for cit in citations:
|
||||
content = "\n".join(cit.get("content", []))
|
||||
citation = MessageCitation(
|
||||
message_id=message_id,
|
||||
source_doc_id=cit.get("source_doc_id", "unknown"),
|
||||
source_doc_id=cit.get("reference_id", "unknown"),
|
||||
file_path=cit.get("file_path", "unknown"),
|
||||
chunk_content=cit.get("chunk_content"),
|
||||
chunk_content=content,
|
||||
relevance_score=cit.get("relevance_score")
|
||||
)
|
||||
self.db.add(citation)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from dotenv import load_dotenv
|
|||
load_dotenv(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.env"))
|
||||
|
||||
from app.core.database import engine, Base, SessionLocal
|
||||
from app.models.models import User, ChatSession, ChatMessage, MessageCitation # Import models to register them
|
||||
from app.models.models import ChatSession, ChatMessage, MessageCitation # Import models to register them
|
||||
from app.core.config import settings
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
|
@ -22,39 +22,9 @@ def init_db():
|
|||
logger.info("Tables created successfully!")
|
||||
|
||||
# Create default users from AUTH_ACCOUNTS
|
||||
if settings.AUTH_ACCOUNTS:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
accounts = settings.AUTH_ACCOUNTS.split(',')
|
||||
for account in accounts:
|
||||
if ':' in account:
|
||||
username, password = account.split(':', 1)
|
||||
username = username.strip()
|
||||
# Check if user exists
|
||||
existing_user = db.query(User).filter(User.username == username).first()
|
||||
if not existing_user:
|
||||
logger.info(f"Creating default user: {username}")
|
||||
# Note: In a real app, password should be hashed.
|
||||
# For now, we are just creating the user record.
|
||||
# The User model doesn't have a password field in the provided schema,
|
||||
# so we might need to add it or just store the user for now.
|
||||
# Looking at models.py, User has: username, email, full_name. No password.
|
||||
# I will use username as email for now if email is required.
|
||||
new_user = User(
|
||||
username=username,
|
||||
email=f"{username}@example.com",
|
||||
full_name=username
|
||||
)
|
||||
db.add(new_user)
|
||||
else:
|
||||
logger.info(f"User {username} already exists.")
|
||||
db.commit()
|
||||
logger.info("Default users processed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default users: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
# User table removed, so we don't need to create users anymore.
|
||||
# Logic kept as comment or removed.
|
||||
logger.info("Database initialized.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tables: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue