refactor

commit c9182184cf (parent 005df20558)
20 changed files with 1501 additions and 0 deletions
src/api/__init__.py (new file, empty)
src/api/auth.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from starlette.requests import Request
from starlette.responses import JSONResponse

async def auth_init(request: Request, auth_service, session_manager):
    """Initialize OAuth flow for authentication or data source connection"""
    try:
        data = await request.json()
        provider = data.get("provider")
        purpose = data.get("purpose", "data_source")
        connection_name = data.get("name", f"{provider}_{purpose}")
        redirect_uri = data.get("redirect_uri")

        user = getattr(request.state, 'user', None)
        user_id = user.user_id if user else None

        result = await auth_service.init_oauth(
            provider, purpose, connection_name, redirect_uri, user_id
        )
        return JSONResponse(result)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": f"Failed to initialize OAuth: {str(e)}"}, status_code=500)

async def auth_callback(request: Request, auth_service, session_manager):
    """Handle OAuth callback - exchange authorization code for tokens"""
    try:
        data = await request.json()
        connection_id = data.get("connection_id")
        authorization_code = data.get("authorization_code")
        state = data.get("state")

        result = await auth_service.handle_oauth_callback(
            connection_id, authorization_code, state
        )

        # If this is app auth, set JWT cookie
        if result.get("purpose") == "app_auth" and result.get("jwt_token"):
            response = JSONResponse({
                k: v for k, v in result.items() if k != "jwt_token"
            })
            response.set_cookie(
                key="auth_token",
                value=result["jwt_token"],
                httponly=True,
                secure=False,
                samesite="lax",
                max_age=7 * 24 * 60 * 60  # 7 days
            )
            return response
        else:
            return JSONResponse(result)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": f"Callback failed: {str(e)}"}, status_code=500)

async def auth_me(request: Request, auth_service, session_manager):
    """Get current user information"""
    result = await auth_service.get_user_info(request)
    return JSONResponse(result)

async def auth_logout(request: Request, auth_service, session_manager):
    """Logout user by clearing auth cookie"""
    response = JSONResponse({
        "status": "logged_out",
        "message": "Successfully logged out"
    })

    # Clear the auth cookie
    response.delete_cookie(
        key="auth_token",
        httponly=True,
        secure=False,
        samesite="lax"
    )

    return response
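For orientation, a minimal sketch of how a client might drive the two OAuth endpoints above, assuming the app from src/main.py runs on localhost:8000; the redirect URI and authorization code are placeholders, not part of this commit.

# --- usage sketch (not part of the commit): driving /auth/init and /auth/callback ---
import asyncio
import httpx

BASE = "http://localhost:8000"  # assumed dev address from src/main.py

async def demo_oauth_flow():
    async with httpx.AsyncClient(base_url=BASE) as client:
        # 1) Ask the server for an OAuth config and a connection_id
        init = await client.post("/auth/init", json={
            "provider": "google",
            "purpose": "app_auth",
            "redirect_uri": "http://localhost:3000/oauth/callback",  # placeholder
        })
        init_data = init.json()
        # The browser would now visit oauth_config["authorization_endpoint"]
        # and return with an authorization code.

        # 2) Exchange the code via the callback endpoint
        cb = await client.post("/auth/callback", json={
            "connection_id": init_data["connection_id"],
            "authorization_code": "<code-from-google>",  # placeholder
        })
        print(cb.json())

asyncio.run(demo_oauth_flow())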
src/api/chat.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse

async def chat_endpoint(request: Request, chat_service, session_manager):
    """Handle chat requests"""
    data = await request.json()
    prompt = data.get("prompt", "")
    previous_response_id = data.get("previous_response_id")
    stream = data.get("stream", False)

    user = request.state.user
    user_id = user.user_id

    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    if stream:
        return StreamingResponse(
            await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control"
            }
        )
    else:
        result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False)
        return JSONResponse(result)

async def langflow_endpoint(request: Request, chat_service, session_manager):
    """Handle Langflow chat requests"""
    data = await request.json()
    prompt = data.get("prompt", "")
    previous_response_id = data.get("previous_response_id")
    stream = data.get("stream", False)

    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    try:
        if stream:
            return StreamingResponse(
                await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=True),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "Cache-Control"
                }
            )
        else:
            result = await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=False)
            return JSONResponse(result)

    except Exception as e:
        return JSONResponse({"error": f"Langflow request failed: {str(e)}"}, status_code=500)
src/api/connectors.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from starlette.requests import Request
from starlette.responses import JSONResponse

async def connector_sync(request: Request, connector_service, session_manager):
    """Sync files from a connector connection"""
    data = await request.json()
    connection_id = data.get("connection_id")
    max_files = data.get("max_files")

    if not connection_id:
        return JSONResponse({"error": "connection_id is required"}, status_code=400)

    try:
        print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")

        # Verify user owns this connection
        user = request.state.user
        print(f"[DEBUG] User: {user.user_id}")

        connection_config = await connector_service.connection_manager.get_connection(connection_id)
        print(f"[DEBUG] Got connection config: {connection_config is not None}")

        if not connection_config:
            return JSONResponse({"error": "Connection not found"}, status_code=404)

        if connection_config.user_id != user.user_id:
            return JSONResponse({"error": "Access denied"}, status_code=403)

        print(f"[DEBUG] About to call sync_connector_files")
        task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
        print(f"[DEBUG] Got task_id: {task_id}")

        return JSONResponse({
            "task_id": task_id,
            "status": "sync_started",
            "message": f"Started syncing files from connection {connection_id}"
        },
            status_code=201
        )

    except Exception as e:
        import sys
        import traceback

        error_msg = f"[ERROR] Connector sync failed: {str(e)}"
        print(error_msg, file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        sys.stderr.flush()

        return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500)

async def connector_status(request: Request, connector_service, session_manager):
    """Get connector status for authenticated user"""
    connector_type = request.path_params.get("connector_type", "google_drive")
    user = request.state.user

    # Get connections for this connector type and user
    connections = await connector_service.connection_manager.list_connections(
        user_id=user.user_id,
        connector_type=connector_type
    )

    # Check if there are any active connections
    active_connections = [conn for conn in connections if conn.is_active]
    has_authenticated_connection = len(active_connections) > 0

    return JSONResponse({
        "connector_type": connector_type,
        "authenticated": has_authenticated_connection,
        "status": "connected" if has_authenticated_connection else "not_connected",
        "connections": [
            {
                "connection_id": conn.connection_id,
                "name": conn.name,
                "is_active": conn.is_active,
                "created_at": conn.created_at.isoformat(),
                "last_sync": conn.last_sync.isoformat() if conn.last_sync else None
            }
            for conn in connections
        ]
    })
src/api/search.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from starlette.requests import Request
from starlette.responses import JSONResponse

async def search(request: Request, search_service, session_manager):
    """Search for documents"""
    payload = await request.json()
    query = payload.get("query")
    if not query:
        return JSONResponse({"error": "Query is required"}, status_code=400)

    user = request.state.user
    result = await search_service.search(query, user_id=user.user_id)
    return JSONResponse(result)
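A minimal sketch of calling this endpoint from a client, assuming the app is running locally and an auth_token cookie from the callback flow is at hand; the query and cookie value are placeholders.

# --- usage sketch (not part of the commit): querying /search ---
import asyncio
import httpx

async def demo_search():
    cookies = {"auth_token": "<jwt-from-auth-callback>"}  # placeholder
    async with httpx.AsyncClient(base_url="http://localhost:8000", cookies=cookies) as client:
        resp = await client.post("/search", json={"query": "quarterly revenue table"})
        for hit in resp.json().get("results", []):
            # Each hit carries filename, page, text, and the OpenSearch score
            print(hit["filename"], hit["page"], hit["score"])

asyncio.run(demo_search())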
src/api/upload.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import os
from starlette.requests import Request
from starlette.responses import JSONResponse

async def upload(request: Request, document_service, session_manager):
    """Upload a single file"""
    form = await request.form()
    upload_file = form["file"]
    user = request.state.user

    result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id)
    return JSONResponse(result)

async def upload_path(request: Request, task_service, session_manager):
    """Upload all files from a directory path"""
    payload = await request.json()
    base_dir = payload.get("path")
    if not base_dir or not os.path.isdir(base_dir):
        return JSONResponse({"error": "Invalid path"}, status_code=400)

    file_paths = [os.path.join(root, fn)
                  for root, _, files in os.walk(base_dir)
                  for fn in files]

    if not file_paths:
        return JSONResponse({"error": "No files found in directory"}, status_code=400)

    user = request.state.user
    task_id = await task_service.create_upload_task(user.user_id, file_paths)

    return JSONResponse({
        "task_id": task_id,
        "total_files": len(file_paths),
        "status": "accepted"
    }, status_code=201)

async def upload_context(request: Request, document_service, chat_service, session_manager):
    """Upload a file and add its content as context to the current conversation"""
    form = await request.form()
    upload_file = form["file"]
    filename = upload_file.filename or "uploaded_document"

    # Get optional parameters
    previous_response_id = form.get("previous_response_id")
    endpoint = form.get("endpoint", "langflow")

    # Process document and extract content
    doc_result = await document_service.process_upload_context(upload_file, filename)

    # Send document content as user message to get proper response_id
    response_text, response_id = await chat_service.upload_context_chat(
        doc_result["content"],
        filename,
        previous_response_id=previous_response_id,
        endpoint=endpoint
    )

    response_data = {
        "status": "context_added",
        "filename": doc_result["filename"],
        "pages": doc_result["pages"],
        "content_length": doc_result["content_length"],
        "response_id": response_id,
        "confirmation": response_text
    }

    return JSONResponse(response_data)

async def task_status(request: Request, task_service, session_manager):
    """Get the status of an upload task"""
    task_id = request.path_params.get("task_id")
    user = request.state.user

    task_status_result = task_service.get_task_status(user.user_id, task_id)
    if not task_status_result:
        return JSONResponse({"error": "Task not found"}, status_code=404)

    return JSONResponse(task_status_result)
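A minimal sketch of the bulk-upload flow these handlers expose: submit a directory to /upload_path, then poll /tasks/{task_id}. The directory path, cookie value, and polling interval are illustrative only.

# --- usage sketch (not part of the commit): bulk upload then poll task status ---
import asyncio
import httpx

async def demo_bulk_upload():
    cookies = {"auth_token": "<jwt-from-auth-callback>"}  # placeholder
    async with httpx.AsyncClient(base_url="http://localhost:8000", cookies=cookies) as client:
        resp = await client.post("/upload_path", json={"path": "/data/reports"})  # illustrative path
        task_id = resp.json()["task_id"]

        while True:
            status = (await client.get(f"/tasks/{task_id}")).json()
            print(f'{status["processed_files"]}/{status["total_files"]} processed')
            if status["status"] in ("completed", "failed"):
                break
            await asyncio.sleep(2)

asyncio.run(demo_bulk_upload())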
src/config/__init__.py (new file, empty)
src/config/settings.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import os
from dotenv import load_dotenv
from opensearchpy import AsyncOpenSearch
from opensearchpy._async.http_aiohttp import AIOHttpConnection
from docling.document_converter import DocumentConverter
from agentd.patch import patch_openai_with_mcp
from openai import AsyncOpenAI

load_dotenv()
load_dotenv("../")

# Environment variables
OPENSEARCH_HOST = os.getenv("OPENSEARCH_HOST", "localhost")
OPENSEARCH_PORT = int(os.getenv("OPENSEARCH_PORT", "9200"))
OPENSEARCH_USERNAME = os.getenv("OPENSEARCH_USERNAME", "admin")
OPENSEARCH_PASSWORD = os.getenv("OPENSEARCH_PASSWORD")
LANGFLOW_URL = os.getenv("LANGFLOW_URL", "http://localhost:7860")
FLOW_ID = os.getenv("FLOW_ID")
LANGFLOW_KEY = os.getenv("LANGFLOW_SECRET_KEY")
SESSION_SECRET = os.getenv("SESSION_SECRET", "your-secret-key-change-in-production")
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")

# OpenSearch configuration
INDEX_NAME = "documents"
VECTOR_DIM = 1536
EMBED_MODEL = "text-embedding-3-small"

INDEX_BODY = {
    "settings": {
        "index": {"knn": True},
        "number_of_shards": 1,
        "number_of_replicas": 1
    },
    "mappings": {
        "properties": {
            "document_id": { "type": "keyword" },
            "filename": { "type": "keyword" },
            "mimetype": { "type": "keyword" },
            "page": { "type": "integer" },
            "text": { "type": "text" },
            "chunk_embedding": {
                "type": "knn_vector",
                "dimension": VECTOR_DIM,
                "method": {
                    "name": "disk_ann",
                    "engine": "jvector",
                    "space_type": "l2",
                    "parameters": {
                        "ef_construction": 100,
                        "m": 16
                    }
                }
            },
            "source_url": { "type": "keyword" },
            "connector_type": { "type": "keyword" },
            "owner": { "type": "keyword" },
            "allowed_users": { "type": "keyword" },
            "allowed_groups": { "type": "keyword" },
            "user_permissions": { "type": "object" },
            "group_permissions": { "type": "object" },
            "created_time": { "type": "date" },
            "modified_time": { "type": "date" },
            "indexed_time": { "type": "date" },
            "metadata": { "type": "object" }
        }
    }
}

class AppClients:
    def __init__(self):
        self.opensearch = None
        self.langflow_client = None
        self.patched_async_client = None
        self.converter = None

    def initialize(self):
        # Initialize OpenSearch client
        self.opensearch = AsyncOpenSearch(
            hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT}],
            connection_class=AIOHttpConnection,
            scheme="https",
            use_ssl=True,
            verify_certs=False,
            ssl_assert_fingerprint=None,
            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
            http_compress=True,
        )

        # Initialize Langflow client
        self.langflow_client = AsyncOpenAI(
            base_url=f"{LANGFLOW_URL}/api/v1",
            api_key=LANGFLOW_KEY
        )

        # Initialize patched OpenAI client
        self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())

        # Initialize Docling converter
        self.converter = DocumentConverter()

        return self

# Global clients instance
clients = AppClients()
src/main.py (new file, 234 lines)
@@ -0,0 +1,234 @@
import asyncio
import atexit
import torch
from functools import partial
from starlette.applications import Starlette
from starlette.routing import Route

# Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from utils.gpu_detection import detect_gpu_devices

# Services
from services.document_service import DocumentService
from services.search_service import SearchService
from services.task_service import TaskService
from services.auth_service import AuthService
from services.chat_service import ChatService

# Existing services
from connectors.service import ConnectorService
from session_manager import SessionManager
from auth_middleware import require_auth, optional_auth

# API endpoints
from api import upload, search, chat, auth, connectors

print("CUDA available:", torch.cuda.is_available())
print("CUDA version PyTorch was built with:", torch.version.cuda)

async def wait_for_opensearch():
    """Wait for OpenSearch to be ready with retries"""
    max_retries = 30
    retry_delay = 2

    for attempt in range(max_retries):
        try:
            await clients.opensearch.info()
            print("OpenSearch is ready!")
            return
        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
            else:
                raise Exception("OpenSearch failed to become ready")

async def init_index():
    """Initialize OpenSearch index"""
    await wait_for_opensearch()

    if not await clients.opensearch.indices.exists(index=INDEX_NAME):
        await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
        print(f"Created index '{INDEX_NAME}'")
    else:
        print(f"Index '{INDEX_NAME}' already exists, skipping creation.")

def initialize_services():
    """Initialize all services and their dependencies"""
    # Initialize clients
    clients.initialize()

    # Initialize session manager
    session_manager = SessionManager(SESSION_SECRET)

    # Initialize services
    document_service = DocumentService()
    search_service = SearchService()
    task_service = TaskService(document_service)
    chat_service = ChatService()

    # Set process pool for document service
    document_service.process_pool = task_service.process_pool

    # Initialize connector service
    connector_service = ConnectorService(
        opensearch_client=clients.opensearch,
        patched_async_client=clients.patched_async_client,
        process_pool=task_service.process_pool,
        embed_model="text-embedding-3-small",
        index_name=INDEX_NAME
    )

    # Initialize auth service
    auth_service = AuthService(session_manager, connector_service)

    return {
        'document_service': document_service,
        'search_service': search_service,
        'task_service': task_service,
        'chat_service': chat_service,
        'auth_service': auth_service,
        'connector_service': connector_service,
        'session_manager': session_manager
    }

def create_app():
    """Create and configure the Starlette application"""
    services = initialize_services()

    # Create route handlers with service dependencies injected
    routes = [
        # Upload endpoints
        Route("/upload",
              require_auth(services['session_manager'])(
                  partial(upload.upload,
                          document_service=services['document_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/upload_context",
              require_auth(services['session_manager'])(
                  partial(upload.upload_context,
                          document_service=services['document_service'],
                          chat_service=services['chat_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/upload_path",
              require_auth(services['session_manager'])(
                  partial(upload.upload_path,
                          task_service=services['task_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/tasks/{task_id}",
              require_auth(services['session_manager'])(
                  partial(upload.task_status,
                          task_service=services['task_service'],
                          session_manager=services['session_manager'])
              ), methods=["GET"]),

        # Search endpoint
        Route("/search",
              require_auth(services['session_manager'])(
                  partial(search.search,
                          search_service=services['search_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        # Chat endpoints
        Route("/chat",
              require_auth(services['session_manager'])(
                  partial(chat.chat_endpoint,
                          chat_service=services['chat_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/langflow",
              require_auth(services['session_manager'])(
                  partial(chat.langflow_endpoint,
                          chat_service=services['chat_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        # Authentication endpoints
        Route("/auth/init",
              optional_auth(services['session_manager'])(
                  partial(auth.auth_init,
                          auth_service=services['auth_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/auth/callback",
              partial(auth.auth_callback,
                      auth_service=services['auth_service'],
                      session_manager=services['session_manager']),
              methods=["POST"]),

        Route("/auth/me",
              optional_auth(services['session_manager'])(
                  partial(auth.auth_me,
                          auth_service=services['auth_service'],
                          session_manager=services['session_manager'])
              ), methods=["GET"]),

        Route("/auth/logout",
              require_auth(services['session_manager'])(
                  partial(auth.auth_logout,
                          auth_service=services['auth_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        # Connector endpoints
        Route("/connectors/sync",
              require_auth(services['session_manager'])(
                  partial(connectors.connector_sync,
                          connector_service=services['connector_service'],
                          session_manager=services['session_manager'])
              ), methods=["POST"]),

        Route("/connectors/status/{connector_type}",
              require_auth(services['session_manager'])(
                  partial(connectors.connector_status,
                          connector_service=services['connector_service'],
                          session_manager=services['session_manager'])
              ), methods=["GET"]),
    ]

    app = Starlette(debug=True, routes=routes)
    app.state.services = services  # Store services for cleanup

    return app

async def startup():
    """Application startup tasks"""
    await init_index()
    # Get services from app state if needed for initialization
    # services = app.state.services
    # await services['connector_service'].initialize()

def cleanup():
    """Cleanup on application shutdown"""
    # This will be called on exit to cleanup process pools
    pass

if __name__ == "__main__":
    import uvicorn

    # Register cleanup function
    atexit.register(cleanup)

    # Create app
    app = create_app()

    # Run startup tasks
    asyncio.run(startup())

    # Run the server
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        reload=False,  # Disable reload since we're running from main
    )
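For completeness, a minimal sketch of consuming the streaming /chat route once the server above is running; the prompt, cookie value, and address are placeholders.

# --- usage sketch (not part of the commit): streaming a /chat response ---
import asyncio
import httpx

async def demo_stream_chat():
    cookies = {"auth_token": "<jwt-from-auth-callback>"}  # placeholder
    async with httpx.AsyncClient(base_url="http://localhost:8000", cookies=cookies, timeout=None) as client:
        async with client.stream(
            "POST", "/chat",
            json={"prompt": "Summarize the newest uploaded document", "stream": True},
        ) as resp:
            async for line in resp.aiter_lines():
                if line:
                    print(line)  # events emitted by the StreamingResponse

asyncio.run(demo_stream_chat())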
src/models/__init__.py (new file, empty)
src/models/tasks.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional

class TaskStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class FileTask:
    file_path: str
    status: TaskStatus = TaskStatus.PENDING
    result: Optional[dict] = None
    error: Optional[str] = None
    retry_count: int = 0
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)

@dataclass
class UploadTask:
    task_id: str
    total_files: int
    processed_files: int = 0
    successful_files: int = 0
    failed_files: int = 0
    file_tasks: Dict[str, FileTask] = field(default_factory=dict)
    status: TaskStatus = TaskStatus.PENDING
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
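A small sketch of how these dataclasses fit together (the same bookkeeping that TaskService and DocumentService perform later in this commit); the paths and task_id are illustrative.

# --- usage sketch (not part of the commit): the task dataclasses in isolation ---
from models.tasks import FileTask, TaskStatus, UploadTask

paths = ["/tmp/a.pdf", "/tmp/b.pdf"]  # illustrative paths
task = UploadTask(
    task_id="demo-task",
    total_files=len(paths),
    file_tasks={p: FileTask(file_path=p) for p in paths},
)

# A worker flips one file to RUNNING, then records the outcome and counters:
task.file_tasks["/tmp/a.pdf"].status = TaskStatus.RUNNING
task.file_tasks["/tmp/a.pdf"].status = TaskStatus.COMPLETED
task.successful_files += 1
task.processed_files += 1
print(task.processed_files, "/", task.total_files)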
src/services/__init__.py (new file, empty)
src/services/auth_service.py (new file, 213 lines)
@@ -0,0 +1,213 @@
import os
import uuid
import json
import httpx
import aiofiles
from datetime import datetime, timedelta
from typing import Optional

from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET
from session_manager import SessionManager

class AuthService:
    def __init__(self, session_manager: SessionManager, connector_service=None):
        self.session_manager = session_manager
        self.connector_service = connector_service
        self.used_auth_codes = set()  # Track used authorization codes

    async def init_oauth(self, provider: str, purpose: str, connection_name: str,
                         redirect_uri: str, user_id: str = None) -> dict:
        """Initialize OAuth flow for authentication or data source connection"""
        if provider != "google":
            raise ValueError("Unsupported provider")

        if not redirect_uri:
            raise ValueError("redirect_uri is required")

        if not GOOGLE_OAUTH_CLIENT_ID:
            raise ValueError("Google OAuth client ID not configured")

        # Create connection configuration
        token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
        config = {
            "client_id": GOOGLE_OAUTH_CLIENT_ID,
            "token_file": token_file,
            "provider": provider,
            "purpose": purpose,
            "redirect_uri": redirect_uri
        }

        # Create connection in manager
        connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
        connection_id = await self.connector_service.connection_manager.create_connection(
            connector_type=connector_type,
            name=connection_name,
            config=config,
            user_id=user_id
        )

        # Return OAuth configuration for client-side flow
        scopes = [
            'openid', 'email', 'profile',
            'https://www.googleapis.com/auth/drive.readonly',
            'https://www.googleapis.com/auth/drive.metadata.readonly'
        ]

        oauth_config = {
            "client_id": GOOGLE_OAUTH_CLIENT_ID,
            "scopes": scopes,
            "redirect_uri": redirect_uri,
            "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
            "token_endpoint": "https://oauth2.googleapis.com/token"
        }

        return {
            "connection_id": connection_id,
            "oauth_config": oauth_config
        }

    async def handle_oauth_callback(self, connection_id: str, authorization_code: str,
                                    state: str = None) -> dict:
        """Handle OAuth callback - exchange authorization code for tokens"""
        if not all([connection_id, authorization_code]):
            raise ValueError("Missing required parameters (connection_id, authorization_code)")

        # Check if authorization code has already been used
        if authorization_code in self.used_auth_codes:
            raise ValueError("Authorization code already used")

        # Mark code as used to prevent duplicate requests
        self.used_auth_codes.add(authorization_code)

        try:
            # Get connection config
            connection_config = await self.connector_service.connection_manager.get_connection(connection_id)
            if not connection_config:
                raise ValueError("Connection not found")

            # Exchange authorization code for tokens
            redirect_uri = connection_config.config.get("redirect_uri")
            if not redirect_uri:
                raise ValueError("Redirect URI not found in connection config")

            token_url = "https://oauth2.googleapis.com/token"
            token_payload = {
                "code": authorization_code,
                "client_id": connection_config.config["client_id"],
                "client_secret": GOOGLE_OAUTH_CLIENT_SECRET,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code"
            }

            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_url, data=token_payload)

            if token_response.status_code != 200:
                raise Exception(f"Token exchange failed: {token_response.text}")

            token_data = token_response.json()

            # Store tokens in the token file
            token_file_data = {
                "token": token_data["access_token"],
                "refresh_token": token_data.get("refresh_token"),
                "scopes": [
                    "openid", "email", "profile",
                    "https://www.googleapis.com/auth/drive.readonly",
                    "https://www.googleapis.com/auth/drive.metadata.readonly"
                ]
            }

            # Add expiry if provided
            if token_data.get("expires_in"):
                expiry = datetime.now() + timedelta(seconds=int(token_data["expires_in"]))
                token_file_data["expiry"] = expiry.isoformat()

            # Save tokens to file
            token_file_path = connection_config.config["token_file"]
            async with aiofiles.open(token_file_path, 'w') as f:
                await f.write(json.dumps(token_file_data, indent=2))

            # Route based on purpose
            purpose = connection_config.config.get("purpose", "data_source")

            if purpose == "app_auth":
                return await self._handle_app_auth(connection_id, connection_config, token_data)
            else:
                return await self._handle_data_source_auth(connection_id, connection_config)

        except Exception as e:
            # Remove used code from set if we failed
            self.used_auth_codes.discard(authorization_code)
            raise e

    async def _handle_app_auth(self, connection_id: str, connection_config, token_data: dict) -> dict:
        """Handle app authentication - create user session"""
        jwt_token = await self.session_manager.create_user_session(token_data["access_token"])

        if jwt_token:
            # Get the user info to create a persistent Google Drive connection
            user_info = await self.session_manager.get_user_info_from_token(token_data["access_token"])
            user_id = user_info["id"] if user_info else None

            response_data = {
                "status": "authenticated",
                "purpose": "app_auth",
                "redirect": "/",
                "jwt_token": jwt_token  # Include JWT token in response
            }

            if user_id:
                # Convert the temporary auth connection to a persistent Google Drive connection
                await self.connector_service.connection_manager.update_connection(
                    connection_id=connection_id,
                    connector_type="google_drive",
                    name=f"Google Drive ({user_info.get('email', 'Unknown')})",
                    user_id=user_id,
                    config={
                        **connection_config.config,
                        "purpose": "data_source",
                        "user_email": user_info.get("email")
                    }
                )
                response_data["google_drive_connection_id"] = connection_id
            else:
                # Fallback: delete connection if we can't get user info
                await self.connector_service.connection_manager.delete_connection(connection_id)

            return response_data
        else:
            # Clean up connection if session creation failed
            await self.connector_service.connection_manager.delete_connection(connection_id)
            raise Exception("Failed to create user session")

    async def _handle_data_source_auth(self, connection_id: str, connection_config) -> dict:
        """Handle data source connection - keep the connection for syncing"""
        return {
            "status": "authenticated",
            "connection_id": connection_id,
            "purpose": "data_source",
            "connector_type": connection_config.connector_type
        }

    async def get_user_info(self, request) -> Optional[dict]:
        """Get current user information from request"""
        user = getattr(request.state, 'user', None)

        if user:
            return {
                "authenticated": True,
                "user": {
                    "user_id": user.user_id,
                    "email": user.email,
                    "name": user.name,
                    "picture": user.picture,
                    "provider": user.provider,
                    "last_login": user.last_login.isoformat() if user.last_login else None
                }
            }
        else:
            return {
                "authenticated": False,
                "user": None
            }
src/services/chat_service.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream

class ChatService:

    async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False):
        """Handle chat requests using the patched OpenAI client"""
        if not prompt:
            raise ValueError("Prompt is required")

        if stream:
            return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
        else:
            response_text, response_id = await async_chat(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
            response_data = {"response": response_text}
            if response_id:
                response_data["response_id"] = response_id
            return response_data

    async def langflow_chat(self, prompt: str, previous_response_id: str = None, stream: bool = False):
        """Handle Langflow chat requests"""
        if not prompt:
            raise ValueError("Prompt is required")

        if not LANGFLOW_URL or not FLOW_ID or not LANGFLOW_KEY:
            raise ValueError("LANGFLOW_URL, FLOW_ID, and LANGFLOW_KEY environment variables are required")

        if stream:
            return async_langflow_stream(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id)
        else:
            response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id)
            response_data = {"response": response_text}
            if response_id:
                response_data["response_id"] = response_id
            return response_data

    async def upload_context_chat(self, document_content: str, filename: str,
                                  previous_response_id: str = None, endpoint: str = "langflow"):
        """Send document content as user message to get proper response_id"""
        document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{document_content}\n\nPlease confirm you've received this document and are ready to answer questions about it."

        if endpoint == "langflow":
            response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, document_prompt, previous_response_id=previous_response_id)
        else:  # chat
            response_text, response_id = await async_chat(clients.patched_async_client, document_prompt, previous_response_id=previous_response_id)

        return response_text, response_id
src/services/document_service.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import datetime
import hashlib
import tempfile
import os
import aiofiles
from io import BytesIO
from docling_core.types.io import DocumentStream

from config.settings import clients, INDEX_NAME, EMBED_MODEL
from utils.document_processing import extract_relevant, process_document_sync

class DocumentService:
    def __init__(self, process_pool=None):
        self.process_pool = process_pool

    async def process_file_common(self, file_path: str, file_hash: str = None, owner_user_id: str = None):
        """
        Common processing logic for both upload and upload_path.
        1. Optionally compute SHA256 hash if not provided.
        2. Convert with docling and extract relevant content.
        3. Add embeddings.
        4. Index into OpenSearch.
        """
        if file_hash is None:
            sha256 = hashlib.sha256()
            async with aiofiles.open(file_path, "rb") as f:
                while True:
                    chunk = await f.read(1 << 20)
                    if not chunk:
                        break
                    sha256.update(chunk)
            file_hash = sha256.hexdigest()

        exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash)
        if exists:
            return {"status": "unchanged", "id": file_hash}

        # convert and extract
        result = clients.converter.convert(file_path)
        full_doc = result.document.export_to_dict()
        slim_doc = extract_relevant(full_doc)

        texts = [c["text"] for c in slim_doc["chunks"]]
        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
        embeddings = [d.embedding for d in resp.data]

        # Index each chunk as a separate document
        for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
            chunk_doc = {
                "document_id": file_hash,
                "filename": slim_doc["filename"],
                "mimetype": slim_doc["mimetype"],
                "page": chunk["page"],
                "text": chunk["text"],
                "chunk_embedding": vect,
                "owner": owner_user_id,
                "indexed_time": datetime.datetime.now().isoformat()
            }
            chunk_id = f"{file_hash}_{i}"
            await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
        return {"status": "indexed", "id": file_hash}

    async def process_upload_file(self, upload_file, owner_user_id: str = None):
        """Process an uploaded file from form data"""
        sha256 = hashlib.sha256()
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            while True:
                chunk = await upload_file.read(1 << 20)
                if not chunk:
                    break
                sha256.update(chunk)
                tmp.write(chunk)
            tmp.flush()

            file_hash = sha256.hexdigest()
            exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash)
            if exists:
                return {"status": "unchanged", "id": file_hash}

            result = await self.process_file_common(tmp.name, file_hash, owner_user_id=owner_user_id)
            return result

        finally:
            tmp.close()
            os.remove(tmp.name)

    async def process_upload_context(self, upload_file, filename: str = None):
        """Process uploaded file and return content for context"""
        import io

        if not filename:
            filename = upload_file.filename or "uploaded_document"

        # Stream file content into BytesIO
        content = io.BytesIO()
        while True:
            chunk = await upload_file.read(1 << 20)  # 1MB chunks
            if not chunk:
                break
            content.write(chunk)
        content.seek(0)  # Reset to beginning for reading

        # Create DocumentStream and process with docling
        doc_stream = DocumentStream(name=filename, stream=content)
        result = clients.converter.convert(doc_stream)
        full_doc = result.document.export_to_dict()
        slim_doc = extract_relevant(full_doc)

        # Extract all text content
        all_text = []
        for chunk in slim_doc["chunks"]:
            all_text.append(f"Page {chunk['page']}:\n{chunk['text']}")

        full_content = "\n\n".join(all_text)

        return {
            "filename": filename,
            "content": full_content,
            "pages": len(slim_doc["chunks"]),
            "content_length": len(full_content)
        }

    async def process_single_file_task(self, upload_task, file_path: str):
        """Process a single file and update task tracking - used by task service"""
        from models.tasks import TaskStatus
        import time
        import asyncio

        file_task = upload_task.file_tasks[file_path]
        file_task.status = TaskStatus.RUNNING
        file_task.updated_at = time.time()

        try:
            # Check if file already exists in index
            loop = asyncio.get_event_loop()

            # Run CPU-intensive docling processing in separate process
            slim_doc = await loop.run_in_executor(self.process_pool, process_document_sync, file_path)

            # Check if already indexed
            exists = await clients.opensearch.exists(index=INDEX_NAME, id=slim_doc["id"])
            if exists:
                result = {"status": "unchanged", "id": slim_doc["id"]}
            else:
                # Generate embeddings and index (I/O bound, keep in main process)
                texts = [c["text"] for c in slim_doc["chunks"]]
                resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
                embeddings = [d.embedding for d in resp.data]

                # Index each chunk
                for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
                    chunk_doc = {
                        "document_id": slim_doc["id"],
                        "filename": slim_doc["filename"],
                        "mimetype": slim_doc["mimetype"],
                        "page": chunk["page"],
                        "text": chunk["text"],
                        "chunk_embedding": vect
                    }
                    chunk_id = f"{slim_doc['id']}_{i}"
                    await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)

                result = {"status": "indexed", "id": slim_doc["id"]}

            result["path"] = file_path
            file_task.status = TaskStatus.COMPLETED
            file_task.result = result
            upload_task.successful_files += 1

        except Exception as e:
            print(f"[ERROR] Failed to process file {file_path}: {e}")
            import traceback
            traceback.print_exc()
            file_task.status = TaskStatus.FAILED
            file_task.error = str(e)
            upload_task.failed_files += 1
        finally:
            file_task.updated_at = time.time()
            upload_task.processed_files += 1
            upload_task.updated_at = time.time()

            if upload_task.processed_files >= upload_task.total_files:
                upload_task.status = TaskStatus.COMPLETED
src/services/search_service.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from typing import Any, Dict, Optional
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL

class SearchService:

    @tool
    async def search_tool(self, query: str, user_id: str = None) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus
            user_id (str): user ID for access control (optional)

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        # Embed the query
        resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
        query_embedding = resp.data[0].embedding

        # Base query structure
        search_body = {
            "query": {
                "bool": {
                    "must": [
                        {
                            "knn": {
                                "chunk_embedding": {
                                    "vector": query_embedding,
                                    "k": 10
                                }
                            }
                        }
                    ]
                }
            },
            "_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"],
            "size": 10
        }

        # Require authentication - no anonymous access to search
        if not user_id:
            return {"results": [], "error": "Authentication required"}

        # Authenticated user access control
        # User can access documents if:
        # 1. They own the document (owner field matches user_id)
        # 2. They're in allowed_users list
        # 3. Document has no ACL (public documents)
        # TODO: Add group access control later
        should_clauses = [
            {"term": {"owner": user_id}},
            {"term": {"allowed_users": user_id}},
            {"bool": {"must_not": {"exists": {"field": "owner"}}}}  # Public docs
        ]

        search_body["query"]["bool"]["should"] = should_clauses
        search_body["query"]["bool"]["minimum_should_match"] = 1

        results = await clients.opensearch.search(index=INDEX_NAME, body=search_body)

        # Transform results
        chunks = []
        for hit in results["hits"]["hits"]:
            chunks.append({
                "filename": hit["_source"]["filename"],
                "mimetype": hit["_source"]["mimetype"],
                "page": hit["_source"]["page"],
                "text": hit["_source"]["text"],
                "score": hit["_score"],
                "source_url": hit["_source"].get("source_url"),
                "owner": hit["_source"].get("owner")
            })
        return {"results": chunks}

    async def search(self, query: str, user_id: str = None) -> Dict[str, Any]:
        """Public search method for API endpoints"""
        return await self.search_tool(query, user_id)
src/services/task_service.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import asyncio
import uuid
import time
import random
from typing import Dict
from concurrent.futures import ProcessPoolExecutor

from models.tasks import TaskStatus, UploadTask, FileTask
from utils.gpu_detection import get_worker_count

class TaskService:
    def __init__(self, document_service=None):
        self.document_service = document_service
        self.task_store: Dict[str, Dict[str, UploadTask]] = {}  # user_id -> {task_id -> UploadTask}
        self.background_tasks = set()

        # Initialize process pool
        max_workers = get_worker_count()
        self.process_pool = ProcessPoolExecutor(max_workers=max_workers)
        print(f"Process pool initialized with {max_workers} workers")

    async def exponential_backoff_delay(self, retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
        """Apply exponential backoff with jitter"""
        delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
        await asyncio.sleep(delay)

    async def create_upload_task(self, user_id: str, file_paths: list) -> str:
        """Create a new upload task for bulk file processing"""
        task_id = str(uuid.uuid4())
        upload_task = UploadTask(
            task_id=task_id,
            total_files=len(file_paths),
            file_tasks={path: FileTask(file_path=path) for path in file_paths}
        )

        if user_id not in self.task_store:
            self.task_store[user_id] = {}
        self.task_store[user_id][task_id] = upload_task

        # Start background processing
        background_task = asyncio.create_task(self.background_upload_processor(user_id, task_id))
        self.background_tasks.add(background_task)
        background_task.add_done_callback(self.background_tasks.discard)

        return task_id

    async def background_upload_processor(self, user_id: str, task_id: str) -> None:
        """Background task to process all files in an upload job with concurrency control"""
        try:
            upload_task = self.task_store[user_id][task_id]
            upload_task.status = TaskStatus.RUNNING
            upload_task.updated_at = time.time()

            # Process files with limited concurrency to avoid overwhelming the system
            max_workers = get_worker_count()
            semaphore = asyncio.Semaphore(max_workers * 2)  # Allow 2x process pool size for async I/O

            async def process_with_semaphore(file_path: str):
                async with semaphore:
                    await self.document_service.process_single_file_task(upload_task, file_path)

            tasks = [
                process_with_semaphore(file_path)
                for file_path in upload_task.file_tasks.keys()
            ]

            await asyncio.gather(*tasks, return_exceptions=True)

        except Exception as e:
            print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
            import traceback
            traceback.print_exc()
            if user_id in self.task_store and task_id in self.task_store[user_id]:
                self.task_store[user_id][task_id].status = TaskStatus.FAILED
                self.task_store[user_id][task_id].updated_at = time.time()

    def get_task_status(self, user_id: str, task_id: str) -> dict:
        """Get the status of a specific upload task"""
        if (not task_id or
            user_id not in self.task_store or
            task_id not in self.task_store[user_id]):
            return None

        upload_task = self.task_store[user_id][task_id]

        file_statuses = {}
        for file_path, file_task in upload_task.file_tasks.items():
            file_statuses[file_path] = {
                "status": file_task.status.value,
                "result": file_task.result,
                "error": file_task.error,
                "retry_count": file_task.retry_count,
                "created_at": file_task.created_at,
                "updated_at": file_task.updated_at
            }

        return {
            "task_id": upload_task.task_id,
            "status": upload_task.status.value,
            "total_files": upload_task.total_files,
            "processed_files": upload_task.processed_files,
            "successful_files": upload_task.successful_files,
            "failed_files": upload_task.failed_files,
            "created_at": upload_task.created_at,
            "updated_at": upload_task.updated_at,
            "files": file_statuses
        }

    def shutdown(self):
        """Cleanup process pool"""
        if hasattr(self, 'process_pool'):
            self.process_pool.shutdown(wait=True)
src/utils/__init__.py (new file, empty)
src/utils/document_processing.py (new file, 149 lines)
@@ -0,0 +1,149 @@
import hashlib
import os
from collections import defaultdict
from docling.document_converter import DocumentConverter
from .gpu_detection import detect_gpu_devices

# Global converter cache for worker processes
_worker_converter = None

def get_worker_converter():
    """Get or create a DocumentConverter instance for this worker process"""
    global _worker_converter
    if _worker_converter is None:
        from docling.document_converter import DocumentConverter

        # Configure GPU settings for this worker
        has_gpu_devices, _ = detect_gpu_devices()
        if not has_gpu_devices:
            # Force CPU-only mode in subprocess
            os.environ['USE_CPU_ONLY'] = 'true'
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
            os.environ['TRANSFORMERS_OFFLINE'] = '0'
            os.environ['TORCH_USE_CUDA_DSA'] = '0'

            # Try to disable CUDA in torch if available
            try:
                import torch
                torch.cuda.is_available = lambda: False
            except ImportError:
                pass
        else:
            # GPU mode - let libraries use GPU if available
            os.environ.pop('USE_CPU_ONLY', None)
            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'  # Still disable progress bars

        print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
        _worker_converter = DocumentConverter()
        print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")

    return _worker_converter

def extract_relevant(doc_dict: dict) -> dict:
    """
    Given the full export_to_dict() result:
      - Grabs origin metadata (hash, filename, mimetype)
      - Finds every text fragment in `texts`, groups them by page_no
      - Flattens tables in `tables` into tab-separated text, grouping by row
      - Concatenates each page's fragments and each table into its own chunk
    Returns a slimmed dict ready for indexing, with each chunk under "text".
    """
    origin = doc_dict.get("origin", {})
    chunks = []

    # 1) process free-text fragments
    page_texts = defaultdict(list)
    for txt in doc_dict.get("texts", []):
        prov = txt.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None
        if page_no is not None:
            page_texts[page_no].append(txt.get("text", "").strip())

    for page in sorted(page_texts):
        chunks.append({
            "page": page,
            "type": "text",
            "text": "\n".join(page_texts[page])
        })

    # 2) process tables
    for t_idx, table in enumerate(doc_dict.get("tables", [])):
        prov = table.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None

        # group cells by their row index
        rows = defaultdict(list)
        for cell in table.get("data").get("table_cells", []):
            r = cell.get("start_row_offset_idx")
            c = cell.get("start_col_offset_idx")
            text = cell.get("text", "").strip()
            rows[r].append((c, text))

        # build a tab-separated line for each row, in order
        flat_rows = []
        for r in sorted(rows):
            cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])]
            flat_rows.append("\t".join(cells))

        chunks.append({
            "page": page_no,
            "type": "table",
            "table_index": t_idx,
            "text": "\n".join(flat_rows)
        })

    return {
        "id": origin.get("binary_hash"),
        "filename": origin.get("filename"),
        "mimetype": origin.get("mimetype"),
        "chunks": chunks
    }

def process_document_sync(file_path: str):
    """Synchronous document processing function for multiprocessing"""
    from collections import defaultdict

    # Get the cached converter for this worker
    converter = get_worker_converter()

    # Compute file hash
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while True:
            chunk = f.read(1 << 20)
            if not chunk:
                break
            sha256.update(chunk)
    file_hash = sha256.hexdigest()

    # Convert with docling
    result = converter.convert(file_path)
    full_doc = result.document.export_to_dict()

    # Extract relevant content (same logic as extract_relevant)
    origin = full_doc.get("origin", {})
    texts = full_doc.get("texts", [])

    page_texts = defaultdict(list)
    for txt in texts:
        prov = txt.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None
        if page_no is not None:
            page_texts[page_no].append(txt.get("text", "").strip())

    chunks = []
    for page in sorted(page_texts):
        joined = "\n".join(page_texts[page])
        chunks.append({
            "page": page,
            "text": joined
        })

    return {
        "id": file_hash,
        "filename": origin.get("filename"),
        "mimetype": origin.get("mimetype"),
        "chunks": chunks,
        "file_path": file_path
    }
src/utils/gpu_detection.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import multiprocessing
import os

def detect_gpu_devices():
    """Detect if GPU devices are actually available"""
    try:
        import torch
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return True, torch.cuda.device_count()
    except ImportError:
        pass

    try:
        import subprocess
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        if result.returncode == 0:
            return True, "detected"
    except (subprocess.SubprocessError, FileNotFoundError):
        pass

    return False, 0

def get_worker_count():
    """Get optimal worker count based on GPU availability"""
    has_gpu_devices, gpu_count = detect_gpu_devices()

    if has_gpu_devices:
        default_workers = min(4, multiprocessing.cpu_count() // 2)
        print(f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)")
    else:
        default_workers = multiprocessing.cpu_count()
        print(f"CPU-only mode enabled - using full concurrency ({default_workers} workers)")

    return int(os.getenv("MAX_WORKERS", default_workers))