From c9182184cf9b4d01bcecbdc3df6be22352445bb9 Mon Sep 17 00:00:00 2001 From: phact Date: Wed, 30 Jul 2025 11:18:19 -0400 Subject: [PATCH] refactor --- src/api/__init__.py | 0 src/api/auth.py | 80 +++++++++++ src/api/chat.py | 59 ++++++++ src/api/connectors.py | 81 +++++++++++ src/api/search.py | 13 ++ src/api/upload.py | 78 +++++++++++ src/config/__init__.py | 0 src/config/settings.py | 105 ++++++++++++++ src/main.py | 234 +++++++++++++++++++++++++++++++ src/models/__init__.py | 0 src/models/tasks.py | 32 +++++ src/services/__init__.py | 0 src/services/auth_service.py | 213 ++++++++++++++++++++++++++++ src/services/chat_service.py | 47 +++++++ src/services/document_service.py | 184 ++++++++++++++++++++++++ src/services/search_service.py | 80 +++++++++++ src/services/task_service.py | 112 +++++++++++++++ src/utils/__init__.py | 0 src/utils/document_processing.py | 149 ++++++++++++++++++++ src/utils/gpu_detection.py | 34 +++++ 20 files changed, 1501 insertions(+) create mode 100644 src/api/__init__.py create mode 100644 src/api/auth.py create mode 100644 src/api/chat.py create mode 100644 src/api/connectors.py create mode 100644 src/api/search.py create mode 100644 src/api/upload.py create mode 100644 src/config/__init__.py create mode 100644 src/config/settings.py create mode 100644 src/main.py create mode 100644 src/models/__init__.py create mode 100644 src/models/tasks.py create mode 100644 src/services/__init__.py create mode 100644 src/services/auth_service.py create mode 100644 src/services/chat_service.py create mode 100644 src/services/document_service.py create mode 100644 src/services/search_service.py create mode 100644 src/services/task_service.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/document_processing.py create mode 100644 src/utils/gpu_detection.py diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/api/auth.py b/src/api/auth.py new file mode 100644 index 00000000..5c7c7275 --- /dev/null +++ b/src/api/auth.py @@ -0,0 +1,80 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse + +async def auth_init(request: Request, auth_service, session_manager): + """Initialize OAuth flow for authentication or data source connection""" + try: + data = await request.json() + provider = data.get("provider") + purpose = data.get("purpose", "data_source") + connection_name = data.get("name", f"{provider}_{purpose}") + redirect_uri = data.get("redirect_uri") + + user = getattr(request.state, 'user', None) + user_id = user.user_id if user else None + + result = await auth_service.init_oauth( + provider, purpose, connection_name, redirect_uri, user_id + ) + return JSONResponse(result) + + except Exception as e: + import traceback + traceback.print_exc() + return JSONResponse({"error": f"Failed to initialize OAuth: {str(e)}"}, status_code=500) + +async def auth_callback(request: Request, auth_service, session_manager): + """Handle OAuth callback - exchange authorization code for tokens""" + try: + data = await request.json() + connection_id = data.get("connection_id") + authorization_code = data.get("authorization_code") + state = data.get("state") + + result = await auth_service.handle_oauth_callback( + connection_id, authorization_code, state + ) + + # If this is app auth, set JWT cookie + if result.get("purpose") == "app_auth" and result.get("jwt_token"): + response = JSONResponse({ + k: v for k, v in result.items() if k != "jwt_token" + }) + response.set_cookie( + 
key="auth_token", + value=result["jwt_token"], + httponly=True, + secure=False, + samesite="lax", + max_age=7 * 24 * 60 * 60 # 7 days + ) + return response + else: + return JSONResponse(result) + + except Exception as e: + import traceback + traceback.print_exc() + return JSONResponse({"error": f"Callback failed: {str(e)}"}, status_code=500) + +async def auth_me(request: Request, auth_service, session_manager): + """Get current user information""" + result = await auth_service.get_user_info(request) + return JSONResponse(result) + +async def auth_logout(request: Request, auth_service, session_manager): + """Logout user by clearing auth cookie""" + response = JSONResponse({ + "status": "logged_out", + "message": "Successfully logged out" + }) + + # Clear the auth cookie + response.delete_cookie( + key="auth_token", + httponly=True, + secure=False, + samesite="lax" + ) + + return response \ No newline at end of file diff --git a/src/api/chat.py b/src/api/chat.py new file mode 100644 index 00000000..6dd8a790 --- /dev/null +++ b/src/api/chat.py @@ -0,0 +1,59 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse, StreamingResponse + +async def chat_endpoint(request: Request, chat_service, session_manager): + """Handle chat requests""" + data = await request.json() + prompt = data.get("prompt", "") + previous_response_id = data.get("previous_response_id") + stream = data.get("stream", False) + + user = request.state.user + user_id = user.user_id + + if not prompt: + return JSONResponse({"error": "Prompt is required"}, status_code=400) + + if stream: + return StreamingResponse( + await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + else: + result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False) + return JSONResponse(result) + +async def langflow_endpoint(request: Request, chat_service, session_manager): + """Handle Langflow chat requests""" + data = await request.json() + prompt = data.get("prompt", "") + previous_response_id = data.get("previous_response_id") + stream = data.get("stream", False) + + if not prompt: + return JSONResponse({"error": "Prompt is required"}, status_code=400) + + try: + if stream: + return StreamingResponse( + await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=True), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + else: + result = await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=False) + return JSONResponse(result) + + except Exception as e: + return JSONResponse({"error": f"Langflow request failed: {str(e)}"}, status_code=500) \ No newline at end of file diff --git a/src/api/connectors.py b/src/api/connectors.py new file mode 100644 index 00000000..20c71f08 --- /dev/null +++ b/src/api/connectors.py @@ -0,0 +1,81 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse + +async def connector_sync(request: Request, connector_service, session_manager): + """Sync files from a connector connection""" + data = await request.json() + connection_id = data.get("connection_id") + 
max_files = data.get("max_files") + + if not connection_id: + return JSONResponse({"error": "connection_id is required"}, status_code=400) + + try: + print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}") + + # Verify user owns this connection + user = request.state.user + print(f"[DEBUG] User: {user.user_id}") + + connection_config = await connector_service.connection_manager.get_connection(connection_id) + print(f"[DEBUG] Got connection config: {connection_config is not None}") + + if not connection_config: + return JSONResponse({"error": "Connection not found"}, status_code=404) + + if connection_config.user_id != user.user_id: + return JSONResponse({"error": "Access denied"}, status_code=403) + + print(f"[DEBUG] About to call sync_connector_files") + task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files) + print(f"[DEBUG] Got task_id: {task_id}") + + return JSONResponse({ + "task_id": task_id, + "status": "sync_started", + "message": f"Started syncing files from connection {connection_id}" + }, + status_code=201 + ) + + except Exception as e: + import sys + import traceback + + error_msg = f"[ERROR] Connector sync failed: {str(e)}" + print(error_msg, file=sys.stderr, flush=True) + traceback.print_exc(file=sys.stderr) + sys.stderr.flush() + + return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500) + +async def connector_status(request: Request, connector_service, session_manager): + """Get connector status for authenticated user""" + connector_type = request.path_params.get("connector_type", "google_drive") + user = request.state.user + + # Get connections for this connector type and user + connections = await connector_service.connection_manager.list_connections( + user_id=user.user_id, + connector_type=connector_type + ) + + # Check if there are any active connections + active_connections = [conn for conn in connections if conn.is_active] + has_authenticated_connection = len(active_connections) > 0 + + return JSONResponse({ + "connector_type": connector_type, + "authenticated": has_authenticated_connection, + "status": "connected" if has_authenticated_connection else "not_connected", + "connections": [ + { + "connection_id": conn.connection_id, + "name": conn.name, + "is_active": conn.is_active, + "created_at": conn.created_at.isoformat(), + "last_sync": conn.last_sync.isoformat() if conn.last_sync else None + } + for conn in connections + ] + }) \ No newline at end of file diff --git a/src/api/search.py b/src/api/search.py new file mode 100644 index 00000000..efa71bb4 --- /dev/null +++ b/src/api/search.py @@ -0,0 +1,13 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse + +async def search(request: Request, search_service, session_manager): + """Search for documents""" + payload = await request.json() + query = payload.get("query") + if not query: + return JSONResponse({"error": "Query is required"}, status_code=400) + + user = request.state.user + result = await search_service.search(query, user_id=user.user_id) + return JSONResponse(result) \ No newline at end of file diff --git a/src/api/upload.py b/src/api/upload.py new file mode 100644 index 00000000..f9f9224d --- /dev/null +++ b/src/api/upload.py @@ -0,0 +1,78 @@ +import os +from starlette.requests import Request +from starlette.responses import JSONResponse + +async def upload(request: Request, document_service, session_manager): + """Upload a single file""" + form = await request.form() + 
upload_file = form["file"] + user = request.state.user + + result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id) + return JSONResponse(result) + +async def upload_path(request: Request, task_service, session_manager): + """Upload all files from a directory path""" + payload = await request.json() + base_dir = payload.get("path") + if not base_dir or not os.path.isdir(base_dir): + return JSONResponse({"error": "Invalid path"}, status_code=400) + + file_paths = [os.path.join(root, fn) + for root, _, files in os.walk(base_dir) + for fn in files] + + if not file_paths: + return JSONResponse({"error": "No files found in directory"}, status_code=400) + + user = request.state.user + task_id = await task_service.create_upload_task(user.user_id, file_paths) + + return JSONResponse({ + "task_id": task_id, + "total_files": len(file_paths), + "status": "accepted" + }, status_code=201) + +async def upload_context(request: Request, document_service, chat_service, session_manager): + """Upload a file and add its content as context to the current conversation""" + form = await request.form() + upload_file = form["file"] + filename = upload_file.filename or "uploaded_document" + + # Get optional parameters + previous_response_id = form.get("previous_response_id") + endpoint = form.get("endpoint", "langflow") + + # Process document and extract content + doc_result = await document_service.process_upload_context(upload_file, filename) + + # Send document content as user message to get proper response_id + response_text, response_id = await chat_service.upload_context_chat( + doc_result["content"], + filename, + previous_response_id=previous_response_id, + endpoint=endpoint + ) + + response_data = { + "status": "context_added", + "filename": doc_result["filename"], + "pages": doc_result["pages"], + "content_length": doc_result["content_length"], + "response_id": response_id, + "confirmation": response_text + } + + return JSONResponse(response_data) + +async def task_status(request: Request, task_service, session_manager): + """Get the status of an upload task""" + task_id = request.path_params.get("task_id") + user = request.state.user + + task_status_result = task_service.get_task_status(user.user_id, task_id) + if not task_status_result: + return JSONResponse({"error": "Task not found"}, status_code=404) + + return JSONResponse(task_status_result) \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 00000000..e1cc15c5 --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,105 @@ +import os +from dotenv import load_dotenv +from opensearchpy import AsyncOpenSearch +from opensearchpy._async.http_aiohttp import AIOHttpConnection +from docling.document_converter import DocumentConverter +from agentd.patch import patch_openai_with_mcp +from openai import AsyncOpenAI + +load_dotenv() +load_dotenv("../") + +# Environment variables +OPENSEARCH_HOST = os.getenv("OPENSEARCH_HOST", "localhost") +OPENSEARCH_PORT = int(os.getenv("OPENSEARCH_PORT", "9200")) +OPENSEARCH_USERNAME = os.getenv("OPENSEARCH_USERNAME", "admin") +OPENSEARCH_PASSWORD = os.getenv("OPENSEARCH_PASSWORD") +LANGFLOW_URL = os.getenv("LANGFLOW_URL", "http://localhost:7860") +FLOW_ID = os.getenv("FLOW_ID") +LANGFLOW_KEY = os.getenv("LANGFLOW_SECRET_KEY") +SESSION_SECRET = os.getenv("SESSION_SECRET", "your-secret-key-change-in-production") 
+GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID") +GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET") + +# OpenSearch configuration +INDEX_NAME = "documents" +VECTOR_DIM = 1536 +EMBED_MODEL = "text-embedding-3-small" + +INDEX_BODY = { + "settings": { + "index": {"knn": True}, + "number_of_shards": 1, + "number_of_replicas": 1 + }, + "mappings": { + "properties": { + "document_id": { "type": "keyword" }, + "filename": { "type": "keyword" }, + "mimetype": { "type": "keyword" }, + "page": { "type": "integer" }, + "text": { "type": "text" }, + "chunk_embedding": { + "type": "knn_vector", + "dimension": VECTOR_DIM, + "method": { + "name": "disk_ann", + "engine": "jvector", + "space_type": "l2", + "parameters": { + "ef_construction": 100, + "m": 16 + } + } + }, + "source_url": { "type": "keyword" }, + "connector_type": { "type": "keyword" }, + "owner": { "type": "keyword" }, + "allowed_users": { "type": "keyword" }, + "allowed_groups": { "type": "keyword" }, + "user_permissions": { "type": "object" }, + "group_permissions": { "type": "object" }, + "created_time": { "type": "date" }, + "modified_time": { "type": "date" }, + "indexed_time": { "type": "date" }, + "metadata": { "type": "object" } + } + } +} + +class AppClients: + def __init__(self): + self.opensearch = None + self.langflow_client = None + self.patched_async_client = None + self.converter = None + + def initialize(self): + # Initialize OpenSearch client + self.opensearch = AsyncOpenSearch( + hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT}], + connection_class=AIOHttpConnection, + scheme="https", + use_ssl=True, + verify_certs=False, + ssl_assert_fingerprint=None, + http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), + http_compress=True, + ) + + # Initialize Langflow client + self.langflow_client = AsyncOpenAI( + base_url=f"{LANGFLOW_URL}/api/v1", + api_key=LANGFLOW_KEY + ) + + # Initialize patched OpenAI client + self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI()) + + # Initialize Docling converter + self.converter = DocumentConverter() + + return self + +# Global clients instance +clients = AppClients() \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 00000000..7e3e23db --- /dev/null +++ b/src/main.py @@ -0,0 +1,234 @@ +import asyncio +import atexit +import torch +from functools import partial +from starlette.applications import Starlette +from starlette.routing import Route + +# Configuration and setup +from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET +from utils.gpu_detection import detect_gpu_devices + +# Services +from services.document_service import DocumentService +from services.search_service import SearchService +from services.task_service import TaskService +from services.auth_service import AuthService +from services.chat_service import ChatService + +# Existing services +from connectors.service import ConnectorService +from session_manager import SessionManager +from auth_middleware import require_auth, optional_auth + +# API endpoints +from api import upload, search, chat, auth, connectors + +print("CUDA available:", torch.cuda.is_available()) +print("CUDA version PyTorch was built with:", torch.version.cuda) + +async def wait_for_opensearch(): + """Wait for OpenSearch to be ready with retries""" + max_retries = 30 + retry_delay = 2 + + for attempt in range(max_retries): + try: + await clients.opensearch.info() + print("OpenSearch is ready!") + return + except Exception as e: + print(f"Attempt 
{attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})") + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + else: + raise Exception("OpenSearch failed to become ready") + +async def init_index(): + """Initialize OpenSearch index""" + await wait_for_opensearch() + + if not await clients.opensearch.indices.exists(index=INDEX_NAME): + await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY) + print(f"Created index '{INDEX_NAME}'") + else: + print(f"Index '{INDEX_NAME}' already exists, skipping creation.") + +def initialize_services(): + """Initialize all services and their dependencies""" + # Initialize clients + clients.initialize() + + # Initialize session manager + session_manager = SessionManager(SESSION_SECRET) + + # Initialize services + document_service = DocumentService() + search_service = SearchService() + task_service = TaskService(document_service) + chat_service = ChatService() + + # Set process pool for document service + document_service.process_pool = task_service.process_pool + + # Initialize connector service + connector_service = ConnectorService( + opensearch_client=clients.opensearch, + patched_async_client=clients.patched_async_client, + process_pool=task_service.process_pool, + embed_model="text-embedding-3-small", + index_name=INDEX_NAME + ) + + # Initialize auth service + auth_service = AuthService(session_manager, connector_service) + + return { + 'document_service': document_service, + 'search_service': search_service, + 'task_service': task_service, + 'chat_service': chat_service, + 'auth_service': auth_service, + 'connector_service': connector_service, + 'session_manager': session_manager + } + +def create_app(): + """Create and configure the Starlette application""" + services = initialize_services() + + # Create route handlers with service dependencies injected + routes = [ + # Upload endpoints + Route("/upload", + require_auth(services['session_manager'])( + partial(upload.upload, + document_service=services['document_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/upload_context", + require_auth(services['session_manager'])( + partial(upload.upload_context, + document_service=services['document_service'], + chat_service=services['chat_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/upload_path", + require_auth(services['session_manager'])( + partial(upload.upload_path, + task_service=services['task_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/tasks/{task_id}", + require_auth(services['session_manager'])( + partial(upload.task_status, + task_service=services['task_service'], + session_manager=services['session_manager']) + ), methods=["GET"]), + + # Search endpoint + Route("/search", + require_auth(services['session_manager'])( + partial(search.search, + search_service=services['search_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + # Chat endpoints + Route("/chat", + require_auth(services['session_manager'])( + partial(chat.chat_endpoint, + chat_service=services['chat_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/langflow", + require_auth(services['session_manager'])( + partial(chat.langflow_endpoint, + chat_service=services['chat_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + # Authentication endpoints + Route("/auth/init", + 
optional_auth(services['session_manager'])( + partial(auth.auth_init, + auth_service=services['auth_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/auth/callback", + partial(auth.auth_callback, + auth_service=services['auth_service'], + session_manager=services['session_manager']), + methods=["POST"]), + + Route("/auth/me", + optional_auth(services['session_manager'])( + partial(auth.auth_me, + auth_service=services['auth_service'], + session_manager=services['session_manager']) + ), methods=["GET"]), + + Route("/auth/logout", + require_auth(services['session_manager'])( + partial(auth.auth_logout, + auth_service=services['auth_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + # Connector endpoints + Route("/connectors/sync", + require_auth(services['session_manager'])( + partial(connectors.connector_sync, + connector_service=services['connector_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/connectors/status/{connector_type}", + require_auth(services['session_manager'])( + partial(connectors.connector_status, + connector_service=services['connector_service'], + session_manager=services['session_manager']) + ), methods=["GET"]), + ] + + app = Starlette(debug=True, routes=routes) + app.state.services = services # Store services for cleanup + + return app + +async def startup(): + """Application startup tasks""" + await init_index() + # Get services from app state if needed for initialization + # services = app.state.services + # await services['connector_service'].initialize() + +def cleanup(): + """Cleanup on application shutdown""" + # This will be called on exit to cleanup process pools + pass + +if __name__ == "__main__": + import uvicorn + + # Register cleanup function + atexit.register(cleanup) + + # Create app + app = create_app() + + # Run startup tasks + asyncio.run(startup()) + + # Run the server + uvicorn.run( + app, + host="0.0.0.0", + port=8000, + reload=False, # Disable reload since we're running from main + ) \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/tasks.py b/src/models/tasks.py new file mode 100644 index 00000000..3f9dfb97 --- /dev/null +++ b/src/models/tasks.py @@ -0,0 +1,32 @@ +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Optional + +class TaskStatus(Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + +@dataclass +class FileTask: + file_path: str + status: TaskStatus = TaskStatus.PENDING + result: Optional[dict] = None + error: Optional[str] = None + retry_count: int = 0 + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + +@dataclass +class UploadTask: + task_id: str + total_files: int + processed_files: int = 0 + successful_files: int = 0 + failed_files: int = 0 + file_tasks: Dict[str, FileTask] = field(default_factory=dict) + status: TaskStatus = TaskStatus.PENDING + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) \ No newline at end of file diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/services/auth_service.py b/src/services/auth_service.py new file mode 100644 index 00000000..a95d5538 --- /dev/null +++ 
b/src/services/auth_service.py @@ -0,0 +1,213 @@ +import os +import uuid +import json +import httpx +import aiofiles +from datetime import datetime, timedelta +from typing import Optional + +from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET +from session_manager import SessionManager + +class AuthService: + def __init__(self, session_manager: SessionManager, connector_service=None): + self.session_manager = session_manager + self.connector_service = connector_service + self.used_auth_codes = set() # Track used authorization codes + + async def init_oauth(self, provider: str, purpose: str, connection_name: str, + redirect_uri: str, user_id: str = None) -> dict: + """Initialize OAuth flow for authentication or data source connection""" + if provider != "google": + raise ValueError("Unsupported provider") + + if not redirect_uri: + raise ValueError("redirect_uri is required") + + if not GOOGLE_OAUTH_CLIENT_ID: + raise ValueError("Google OAuth client ID not configured") + + # Create connection configuration + token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json" + config = { + "client_id": GOOGLE_OAUTH_CLIENT_ID, + "token_file": token_file, + "provider": provider, + "purpose": purpose, + "redirect_uri": redirect_uri + } + + # Create connection in manager + connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth" + connection_id = await self.connector_service.connection_manager.create_connection( + connector_type=connector_type, + name=connection_name, + config=config, + user_id=user_id + ) + + # Return OAuth configuration for client-side flow + scopes = [ + 'openid', 'email', 'profile', + 'https://www.googleapis.com/auth/drive.readonly', + 'https://www.googleapis.com/auth/drive.metadata.readonly' + ] + + oauth_config = { + "client_id": GOOGLE_OAUTH_CLIENT_ID, + "scopes": scopes, + "redirect_uri": redirect_uri, + "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", + "token_endpoint": "https://oauth2.googleapis.com/token" + } + + return { + "connection_id": connection_id, + "oauth_config": oauth_config + } + + async def handle_oauth_callback(self, connection_id: str, authorization_code: str, + state: str = None) -> dict: + """Handle OAuth callback - exchange authorization code for tokens""" + if not all([connection_id, authorization_code]): + raise ValueError("Missing required parameters (connection_id, authorization_code)") + + # Check if authorization code has already been used + if authorization_code in self.used_auth_codes: + raise ValueError("Authorization code already used") + + # Mark code as used to prevent duplicate requests + self.used_auth_codes.add(authorization_code) + + try: + # Get connection config + connection_config = await self.connector_service.connection_manager.get_connection(connection_id) + if not connection_config: + raise ValueError("Connection not found") + + # Exchange authorization code for tokens + redirect_uri = connection_config.config.get("redirect_uri") + if not redirect_uri: + raise ValueError("Redirect URI not found in connection config") + + token_url = "https://oauth2.googleapis.com/token" + token_payload = { + "code": authorization_code, + "client_id": connection_config.config["client_id"], + "client_secret": GOOGLE_OAUTH_CLIENT_SECRET, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code" + } + + async with httpx.AsyncClient() as client: + token_response = await client.post(token_url, data=token_payload) + + if token_response.status_code != 200: + 
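+                    # Bubble up Google's error body so a failed code-for-token exchange is easy to diagnose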
raise Exception(f"Token exchange failed: {token_response.text}") + + token_data = token_response.json() + + # Store tokens in the token file + token_file_data = { + "token": token_data["access_token"], + "refresh_token": token_data.get("refresh_token"), + "scopes": [ + "openid", "email", "profile", + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/drive.metadata.readonly" + ] + } + + # Add expiry if provided + if token_data.get("expires_in"): + expiry = datetime.now() + timedelta(seconds=int(token_data["expires_in"])) + token_file_data["expiry"] = expiry.isoformat() + + # Save tokens to file + token_file_path = connection_config.config["token_file"] + async with aiofiles.open(token_file_path, 'w') as f: + await f.write(json.dumps(token_file_data, indent=2)) + + # Route based on purpose + purpose = connection_config.config.get("purpose", "data_source") + + if purpose == "app_auth": + return await self._handle_app_auth(connection_id, connection_config, token_data) + else: + return await self._handle_data_source_auth(connection_id, connection_config) + + except Exception as e: + # Remove used code from set if we failed + self.used_auth_codes.discard(authorization_code) + raise e + + async def _handle_app_auth(self, connection_id: str, connection_config, token_data: dict) -> dict: + """Handle app authentication - create user session""" + jwt_token = await self.session_manager.create_user_session(token_data["access_token"]) + + if jwt_token: + # Get the user info to create a persistent Google Drive connection + user_info = await self.session_manager.get_user_info_from_token(token_data["access_token"]) + user_id = user_info["id"] if user_info else None + + response_data = { + "status": "authenticated", + "purpose": "app_auth", + "redirect": "/", + "jwt_token": jwt_token # Include JWT token in response + } + + if user_id: + # Convert the temporary auth connection to a persistent Google Drive connection + await self.connector_service.connection_manager.update_connection( + connection_id=connection_id, + connector_type="google_drive", + name=f"Google Drive ({user_info.get('email', 'Unknown')})", + user_id=user_id, + config={ + **connection_config.config, + "purpose": "data_source", + "user_email": user_info.get("email") + } + ) + response_data["google_drive_connection_id"] = connection_id + else: + # Fallback: delete connection if we can't get user info + await self.connector_service.connection_manager.delete_connection(connection_id) + + return response_data + else: + # Clean up connection if session creation failed + await self.connector_service.connection_manager.delete_connection(connection_id) + raise Exception("Failed to create user session") + + async def _handle_data_source_auth(self, connection_id: str, connection_config) -> dict: + """Handle data source connection - keep the connection for syncing""" + return { + "status": "authenticated", + "connection_id": connection_id, + "purpose": "data_source", + "connector_type": connection_config.connector_type + } + + async def get_user_info(self, request) -> Optional[dict]: + """Get current user information from request""" + user = getattr(request.state, 'user', None) + + if user: + return { + "authenticated": True, + "user": { + "user_id": user.user_id, + "email": user.email, + "name": user.name, + "picture": user.picture, + "provider": user.provider, + "last_login": user.last_login.isoformat() if user.last_login else None + } + } + else: + return { + "authenticated": False, + "user": None + } \ No newline 
at end of file diff --git a/src/services/chat_service.py b/src/services/chat_service.py new file mode 100644 index 00000000..23105b6f --- /dev/null +++ b/src/services/chat_service.py @@ -0,0 +1,47 @@ +from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY +from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream + +class ChatService: + + async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False): + """Handle chat requests using the patched OpenAI client""" + if not prompt: + raise ValueError("Prompt is required") + + if stream: + return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id) + else: + response_text, response_id = await async_chat(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id) + response_data = {"response": response_text} + if response_id: + response_data["response_id"] = response_id + return response_data + + async def langflow_chat(self, prompt: str, previous_response_id: str = None, stream: bool = False): + """Handle Langflow chat requests""" + if not prompt: + raise ValueError("Prompt is required") + + if not LANGFLOW_URL or not FLOW_ID or not LANGFLOW_KEY: + raise ValueError("LANGFLOW_URL, FLOW_ID, and LANGFLOW_KEY environment variables are required") + + if stream: + return async_langflow_stream(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id) + else: + response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id) + response_data = {"response": response_text} + if response_id: + response_data["response_id"] = response_id + return response_data + + async def upload_context_chat(self, document_content: str, filename: str, + previous_response_id: str = None, endpoint: str = "langflow"): + """Send document content as user message to get proper response_id""" + document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{document_content}\n\nPlease confirm you've received this document and are ready to answer questions about it." + + if endpoint == "langflow": + response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, document_prompt, previous_response_id=previous_response_id) + else: # chat + response_text, response_id = await async_chat(clients.patched_async_client, document_prompt, previous_response_id=previous_response_id) + + return response_text, response_id \ No newline at end of file diff --git a/src/services/document_service.py b/src/services/document_service.py new file mode 100644 index 00000000..1d467dcf --- /dev/null +++ b/src/services/document_service.py @@ -0,0 +1,184 @@ +import datetime +import hashlib +import tempfile +import os +import aiofiles +from io import BytesIO +from docling_core.types.io import DocumentStream + +from config.settings import clients, INDEX_NAME, EMBED_MODEL +from utils.document_processing import extract_relevant, process_document_sync + +class DocumentService: + def __init__(self, process_pool=None): + self.process_pool = process_pool + + async def process_file_common(self, file_path: str, file_hash: str = None, owner_user_id: str = None): + """ + Common processing logic for both upload and upload_path. + 1. Optionally compute SHA256 hash if not provided. + 2. Convert with docling and extract relevant content. + 3. Add embeddings. + 4. Index into OpenSearch. 
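+        Returns {"status": "indexed" | "unchanged", "id": <sha256 hexdigest>}.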
+ """ + if file_hash is None: + sha256 = hashlib.sha256() + async with aiofiles.open(file_path, "rb") as f: + while True: + chunk = await f.read(1 << 20) + if not chunk: + break + sha256.update(chunk) + file_hash = sha256.hexdigest() + + exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash) + if exists: + return {"status": "unchanged", "id": file_hash} + + # convert and extract + result = clients.converter.convert(file_path) + full_doc = result.document.export_to_dict() + slim_doc = extract_relevant(full_doc) + + texts = [c["text"] for c in slim_doc["chunks"]] + resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts) + embeddings = [d.embedding for d in resp.data] + + # Index each chunk as a separate document + for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)): + chunk_doc = { + "document_id": file_hash, + "filename": slim_doc["filename"], + "mimetype": slim_doc["mimetype"], + "page": chunk["page"], + "text": chunk["text"], + "chunk_embedding": vect, + "owner": owner_user_id, + "indexed_time": datetime.datetime.now().isoformat() + } + chunk_id = f"{file_hash}_{i}" + await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc) + return {"status": "indexed", "id": file_hash} + + async def process_upload_file(self, upload_file, owner_user_id: str = None): + """Process an uploaded file from form data""" + sha256 = hashlib.sha256() + tmp = tempfile.NamedTemporaryFile(delete=False) + try: + while True: + chunk = await upload_file.read(1 << 20) + if not chunk: + break + sha256.update(chunk) + tmp.write(chunk) + tmp.flush() + + file_hash = sha256.hexdigest() + exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash) + if exists: + return {"status": "unchanged", "id": file_hash} + + result = await self.process_file_common(tmp.name, file_hash, owner_user_id=owner_user_id) + return result + + finally: + tmp.close() + os.remove(tmp.name) + + async def process_upload_context(self, upload_file, filename: str = None): + """Process uploaded file and return content for context""" + import io + + if not filename: + filename = upload_file.filename or "uploaded_document" + + # Stream file content into BytesIO + content = io.BytesIO() + while True: + chunk = await upload_file.read(1 << 20) # 1MB chunks + if not chunk: + break + content.write(chunk) + content.seek(0) # Reset to beginning for reading + + # Create DocumentStream and process with docling + doc_stream = DocumentStream(name=filename, stream=content) + result = clients.converter.convert(doc_stream) + full_doc = result.document.export_to_dict() + slim_doc = extract_relevant(full_doc) + + # Extract all text content + all_text = [] + for chunk in slim_doc["chunks"]: + all_text.append(f"Page {chunk['page']}:\n{chunk['text']}") + + full_content = "\n\n".join(all_text) + + return { + "filename": filename, + "content": full_content, + "pages": len(slim_doc["chunks"]), + "content_length": len(full_content) + } + + async def process_single_file_task(self, upload_task, file_path: str): + """Process a single file and update task tracking - used by task service""" + from models.tasks import TaskStatus + import time + import asyncio + + file_task = upload_task.file_tasks[file_path] + file_task.status = TaskStatus.RUNNING + file_task.updated_at = time.time() + + try: + # Check if file already exists in index + loop = asyncio.get_event_loop() + + # Run CPU-intensive docling processing in separate process + slim_doc = await 
loop.run_in_executor(self.process_pool, process_document_sync, file_path) + + # Check if already indexed + exists = await clients.opensearch.exists(index=INDEX_NAME, id=slim_doc["id"]) + if exists: + result = {"status": "unchanged", "id": slim_doc["id"]} + else: + # Generate embeddings and index (I/O bound, keep in main process) + texts = [c["text"] for c in slim_doc["chunks"]] + resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts) + embeddings = [d.embedding for d in resp.data] + + # Index each chunk + for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)): + chunk_doc = { + "document_id": slim_doc["id"], + "filename": slim_doc["filename"], + "mimetype": slim_doc["mimetype"], + "page": chunk["page"], + "text": chunk["text"], + "chunk_embedding": vect + } + chunk_id = f"{slim_doc['id']}_{i}" + await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc) + + result = {"status": "indexed", "id": slim_doc["id"]} + + result["path"] = file_path + file_task.status = TaskStatus.COMPLETED + file_task.result = result + upload_task.successful_files += 1 + + except Exception as e: + print(f"[ERROR] Failed to process file {file_path}: {e}") + import traceback + traceback.print_exc() + file_task.status = TaskStatus.FAILED + file_task.error = str(e) + upload_task.failed_files += 1 + finally: + file_task.updated_at = time.time() + upload_task.processed_files += 1 + upload_task.updated_at = time.time() + + if upload_task.processed_files >= upload_task.total_files: + upload_task.status = TaskStatus.COMPLETED \ No newline at end of file diff --git a/src/services/search_service.py b/src/services/search_service.py new file mode 100644 index 00000000..9e824cfa --- /dev/null +++ b/src/services/search_service.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, Optional +from agentd.tool_decorator import tool +from config.settings import clients, INDEX_NAME, EMBED_MODEL + +class SearchService: + + @tool + async def search_tool(self, query: str, user_id: str = None) -> Dict[str, Any]: + """ + Use this tool to search for documents relevant to the query. + + Args: + query (str): query string to search the corpus + user_id (str): user ID for access control (optional) + + Returns: + dict (str, Any): {"results": [chunks]} on success + """ + # Embed the query + resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query]) + query_embedding = resp.data[0].embedding + + # Base query structure + search_body = { + "query": { + "bool": { + "must": [ + { + "knn": { + "chunk_embedding": { + "vector": query_embedding, + "k": 10 + } + } + } + ] + } + }, + "_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"], + "size": 10 + } + + # Require authentication - no anonymous access to search + if not user_id: + return {"results": [], "error": "Authentication required"} + + # Authenticated user access control + # User can access documents if: + # 1. They own the document (owner field matches user_id) + # 2. They're in allowed_users list + # 3. 
Document has no ACL (public documents) + # TODO: Add group access control later + should_clauses = [ + {"term": {"owner": user_id}}, + {"term": {"allowed_users": user_id}}, + {"bool": {"must_not": {"exists": {"field": "owner"}}}} # Public docs + ] + + search_body["query"]["bool"]["should"] = should_clauses + search_body["query"]["bool"]["minimum_should_match"] = 1 + + results = await clients.opensearch.search(index=INDEX_NAME, body=search_body) + + # Transform results + chunks = [] + for hit in results["hits"]["hits"]: + chunks.append({ + "filename": hit["_source"]["filename"], + "mimetype": hit["_source"]["mimetype"], + "page": hit["_source"]["page"], + "text": hit["_source"]["text"], + "score": hit["_score"], + "source_url": hit["_source"].get("source_url"), + "owner": hit["_source"].get("owner") + }) + return {"results": chunks} + + async def search(self, query: str, user_id: str = None) -> Dict[str, Any]: + """Public search method for API endpoints""" + return await self.search_tool(query, user_id) \ No newline at end of file diff --git a/src/services/task_service.py b/src/services/task_service.py new file mode 100644 index 00000000..6c7b345b --- /dev/null +++ b/src/services/task_service.py @@ -0,0 +1,112 @@ +import asyncio +import uuid +import time +import random +from typing import Dict +from concurrent.futures import ProcessPoolExecutor + +from models.tasks import TaskStatus, UploadTask, FileTask +from utils.gpu_detection import get_worker_count + +class TaskService: + def __init__(self, document_service=None): + self.document_service = document_service + self.task_store: Dict[str, Dict[str, UploadTask]] = {} # user_id -> {task_id -> UploadTask} + self.background_tasks = set() + + # Initialize process pool + max_workers = get_worker_count() + self.process_pool = ProcessPoolExecutor(max_workers=max_workers) + print(f"Process pool initialized with {max_workers} workers") + + async def exponential_backoff_delay(self, retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None: + """Apply exponential backoff with jitter""" + delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay) + await asyncio.sleep(delay) + + async def create_upload_task(self, user_id: str, file_paths: list) -> str: + """Create a new upload task for bulk file processing""" + task_id = str(uuid.uuid4()) + upload_task = UploadTask( + task_id=task_id, + total_files=len(file_paths), + file_tasks={path: FileTask(file_path=path) for path in file_paths} + ) + + if user_id not in self.task_store: + self.task_store[user_id] = {} + self.task_store[user_id][task_id] = upload_task + + # Start background processing + background_task = asyncio.create_task(self.background_upload_processor(user_id, task_id)) + self.background_tasks.add(background_task) + background_task.add_done_callback(self.background_tasks.discard) + + return task_id + + async def background_upload_processor(self, user_id: str, task_id: str) -> None: + """Background task to process all files in an upload job with concurrency control""" + try: + upload_task = self.task_store[user_id][task_id] + upload_task.status = TaskStatus.RUNNING + upload_task.updated_at = time.time() + + # Process files with limited concurrency to avoid overwhelming the system + max_workers = get_worker_count() + semaphore = asyncio.Semaphore(max_workers * 2) # Allow 2x process pool size for async I/O + + async def process_with_semaphore(file_path: str): + async with semaphore: + await self.document_service.process_single_file_task(upload_task, 
file_path) + + tasks = [ + process_with_semaphore(file_path) + for file_path in upload_task.file_tasks.keys() + ] + + await asyncio.gather(*tasks, return_exceptions=True) + + except Exception as e: + print(f"[ERROR] Background upload processor failed for task {task_id}: {e}") + import traceback + traceback.print_exc() + if user_id in self.task_store and task_id in self.task_store[user_id]: + self.task_store[user_id][task_id].status = TaskStatus.FAILED + self.task_store[user_id][task_id].updated_at = time.time() + + def get_task_status(self, user_id: str, task_id: str) -> dict: + """Get the status of a specific upload task""" + if (not task_id or + user_id not in self.task_store or + task_id not in self.task_store[user_id]): + return None + + upload_task = self.task_store[user_id][task_id] + + file_statuses = {} + for file_path, file_task in upload_task.file_tasks.items(): + file_statuses[file_path] = { + "status": file_task.status.value, + "result": file_task.result, + "error": file_task.error, + "retry_count": file_task.retry_count, + "created_at": file_task.created_at, + "updated_at": file_task.updated_at + } + + return { + "task_id": upload_task.task_id, + "status": upload_task.status.value, + "total_files": upload_task.total_files, + "processed_files": upload_task.processed_files, + "successful_files": upload_task.successful_files, + "failed_files": upload_task.failed_files, + "created_at": upload_task.created_at, + "updated_at": upload_task.updated_at, + "files": file_statuses + } + + def shutdown(self): + """Cleanup process pool""" + if hasattr(self, 'process_pool'): + self.process_pool.shutdown(wait=True) \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/utils/document_processing.py b/src/utils/document_processing.py new file mode 100644 index 00000000..d978aad2 --- /dev/null +++ b/src/utils/document_processing.py @@ -0,0 +1,149 @@ +import hashlib +import os +from collections import defaultdict +from docling.document_converter import DocumentConverter +from .gpu_detection import detect_gpu_devices + +# Global converter cache for worker processes +_worker_converter = None + +def get_worker_converter(): + """Get or create a DocumentConverter instance for this worker process""" + global _worker_converter + if _worker_converter is None: + from docling.document_converter import DocumentConverter + + # Configure GPU settings for this worker + has_gpu_devices, _ = detect_gpu_devices() + if not has_gpu_devices: + # Force CPU-only mode in subprocess + os.environ['USE_CPU_ONLY'] = 'true' + os.environ['CUDA_VISIBLE_DEVICES'] = '' + os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' + os.environ['TRANSFORMERS_OFFLINE'] = '0' + os.environ['TORCH_USE_CUDA_DSA'] = '0' + + # Try to disable CUDA in torch if available + try: + import torch + torch.cuda.is_available = lambda: False + except ImportError: + pass + else: + # GPU mode - let libraries use GPU if available + os.environ.pop('USE_CPU_ONLY', None) + os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # Still disable progress bars + + print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})") + _worker_converter = DocumentConverter() + print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})") + + return _worker_converter + +def extract_relevant(doc_dict: dict) -> dict: + """ + Given the full export_to_dict() result: + - Grabs origin metadata (hash, filename, mimetype) + - Finds every text fragment in `texts`, 
groups them by page_no + - Flattens tables in `tables` into tab-separated text, grouping by row + - Concatenates each page's fragments and each table into its own chunk + Returns a slimmed dict ready for indexing, with each chunk under "text". + """ + origin = doc_dict.get("origin", {}) + chunks = [] + + # 1) process free-text fragments + page_texts = defaultdict(list) + for txt in doc_dict.get("texts", []): + prov = txt.get("prov", []) + page_no = prov[0].get("page_no") if prov else None + if page_no is not None: + page_texts[page_no].append(txt.get("text", "").strip()) + + for page in sorted(page_texts): + chunks.append({ + "page": page, + "type": "text", + "text": "\n".join(page_texts[page]) + }) + + # 2) process tables + for t_idx, table in enumerate(doc_dict.get("tables", [])): + prov = table.get("prov", []) + page_no = prov[0].get("page_no") if prov else None + + # group cells by their row index + rows = defaultdict(list) + for cell in table.get("data").get("table_cells", []): + r = cell.get("start_row_offset_idx") + c = cell.get("start_col_offset_idx") + text = cell.get("text", "").strip() + rows[r].append((c, text)) + + # build a tab‑separated line for each row, in order + flat_rows = [] + for r in sorted(rows): + cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])] + flat_rows.append("\t".join(cells)) + + chunks.append({ + "page": page_no, + "type": "table", + "table_index": t_idx, + "text": "\n".join(flat_rows) + }) + + return { + "id": origin.get("binary_hash"), + "filename": origin.get("filename"), + "mimetype": origin.get("mimetype"), + "chunks": chunks + } + +def process_document_sync(file_path: str): + """Synchronous document processing function for multiprocessing""" + from collections import defaultdict + + # Get the cached converter for this worker + converter = get_worker_converter() + + # Compute file hash + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + while True: + chunk = f.read(1 << 20) + if not chunk: + break + sha256.update(chunk) + file_hash = sha256.hexdigest() + + # Convert with docling + result = converter.convert(file_path) + full_doc = result.document.export_to_dict() + + # Extract relevant content (same logic as extract_relevant) + origin = full_doc.get("origin", {}) + texts = full_doc.get("texts", []) + + page_texts = defaultdict(list) + for txt in texts: + prov = txt.get("prov", []) + page_no = prov[0].get("page_no") if prov else None + if page_no is not None: + page_texts[page_no].append(txt.get("text", "").strip()) + + chunks = [] + for page in sorted(page_texts): + joined = "\n".join(page_texts[page]) + chunks.append({ + "page": page, + "text": joined + }) + + return { + "id": file_hash, + "filename": origin.get("filename"), + "mimetype": origin.get("mimetype"), + "chunks": chunks, + "file_path": file_path + } \ No newline at end of file diff --git a/src/utils/gpu_detection.py b/src/utils/gpu_detection.py new file mode 100644 index 00000000..ed9d8e81 --- /dev/null +++ b/src/utils/gpu_detection.py @@ -0,0 +1,34 @@ +import multiprocessing +import os + +def detect_gpu_devices(): + """Detect if GPU devices are actually available""" + try: + import torch + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + return True, torch.cuda.device_count() + except ImportError: + pass + + try: + import subprocess + result = subprocess.run(['nvidia-smi'], capture_output=True, text=True) + if result.returncode == 0: + return True, "detected" + except (subprocess.SubprocessError, FileNotFoundError): + pass + + return False, 0 + 
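+# MAX_WORKERS (env) overrides the computed default: min(4, cores // 2) when a GPU is detected, one worker per core otherwise.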
+def get_worker_count():
+    """Get optimal worker count based on GPU availability"""
+    has_gpu_devices, gpu_count = detect_gpu_devices()
+
+    if has_gpu_devices:
+        default_workers = min(4, multiprocessing.cpu_count() // 2)
+        print(f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)")
+    else:
+        default_workers = multiprocessing.cpu_count()
+        print(f"CPU-only mode enabled - using full concurrency ({default_workers} workers)")
+
+    return int(os.getenv("MAX_WORKERS", default_workers))
\ No newline at end of file
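
For reviewers, a minimal end-to-end sketch (not part of the patch) of how a client might exercise the new routes once a JWT has been obtained through the /auth/init -> /auth/callback flow. The host and port match the uvicorn settings in src/main.py; the cookie value, file name, and query below are placeholders.

import asyncio
import httpx

BASE_URL = "http://localhost:8000"                      # uvicorn host/port from src/main.py
COOKIES = {"auth_token": "<jwt from /auth/callback>"}   # placeholder JWT

async def main() -> None:
    # No request timeout: docling conversion of a large upload can take a while.
    async with httpx.AsyncClient(base_url=BASE_URL, cookies=COOKIES, timeout=None) as client:
        # Index a local file; the multipart field must be named "file" (see api/upload.py).
        with open("report.pdf", "rb") as fh:
            resp = await client.post("/upload", files={"file": ("report.pdf", fh)})
        print(resp.json())  # {"status": "indexed" | "unchanged", "id": "<sha256>"}

        # Vector search over the indexed chunks (see api/search.py).
        resp = await client.post("/search", json={"query": "quarterly revenue"})
        print(resp.json().get("results", []))

asyncio.run(main())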