phact 2025-07-30 11:18:19 -04:00
parent 005df20558
commit c9182184cf
20 changed files with 1501 additions and 0 deletions

0
src/api/__init__.py Normal file

80
src/api/auth.py Normal file

@@ -0,0 +1,80 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
async def auth_init(request: Request, auth_service, session_manager):
"""Initialize OAuth flow for authentication or data source connection"""
try:
data = await request.json()
provider = data.get("provider")
purpose = data.get("purpose", "data_source")
connection_name = data.get("name", f"{provider}_{purpose}")
redirect_uri = data.get("redirect_uri")
user = getattr(request.state, 'user', None)
user_id = user.user_id if user else None
result = await auth_service.init_oauth(
provider, purpose, connection_name, redirect_uri, user_id
)
return JSONResponse(result)
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse({"error": f"Failed to initialize OAuth: {str(e)}"}, status_code=500)
async def auth_callback(request: Request, auth_service, session_manager):
"""Handle OAuth callback - exchange authorization code for tokens"""
try:
data = await request.json()
connection_id = data.get("connection_id")
authorization_code = data.get("authorization_code")
state = data.get("state")
result = await auth_service.handle_oauth_callback(
connection_id, authorization_code, state
)
# If this is app auth, set JWT cookie
if result.get("purpose") == "app_auth" and result.get("jwt_token"):
response = JSONResponse({
k: v for k, v in result.items() if k != "jwt_token"
})
response.set_cookie(
key="auth_token",
value=result["jwt_token"],
httponly=True,
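# NOTE: secure=False keeps the cookie usable over plain HTTP in development;
# it should be True behind HTTPS in production.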
secure=False,
samesite="lax",
max_age=7 * 24 * 60 * 60 # 7 days
)
return response
else:
return JSONResponse(result)
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse({"error": f"Callback failed: {str(e)}"}, status_code=500)
async def auth_me(request: Request, auth_service, session_manager):
"""Get current user information"""
result = await auth_service.get_user_info(request)
return JSONResponse(result)
async def auth_logout(request: Request, auth_service, session_manager):
"""Logout user by clearing auth cookie"""
response = JSONResponse({
"status": "logged_out",
"message": "Successfully logged out"
})
# Clear the auth cookie
response.delete_cookie(
key="auth_token",
httponly=True,
secure=False,
samesite="lax"
)
return response
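
For reference, a rough client-side sketch of how these three endpoints chain together (not part of this commit; assumes httpx, a dev server at http://localhost:8000, and a frontend redirect URI of your choosing):

import httpx

async def example_oauth_flow():
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Step 1: ask the backend for a connection_id plus the OAuth configuration.
        init = await client.post("/auth/init", json={
            "provider": "google",
            "purpose": "app_auth",  # or "data_source" for a Drive-sync connection
            "redirect_uri": "http://localhost:3000/oauth/callback",  # assumed frontend URI
        })
        init.raise_for_status()
        data = init.json()
        connection_id = data["connection_id"]
        oauth_config = data["oauth_config"]  # client_id, scopes, endpoints

        # Step 2 (out of band): send the user to oauth_config["authorization_endpoint"]
        # and receive the authorization code on the redirect_uri.

        # Step 3: exchange the code; for app_auth the JWT arrives as an httponly cookie.
        cb = await client.post("/auth/callback", json={
            "connection_id": connection_id,
            "authorization_code": "<code from the redirect>",
            "state": "<state from the redirect>",
        })
        cb.raise_for_status()
        return cb.json()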

59
src/api/chat.py Normal file

@@ -0,0 +1,59 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
async def chat_endpoint(request: Request, chat_service, session_manager):
"""Handle chat requests"""
data = await request.json()
prompt = data.get("prompt", "")
previous_response_id = data.get("previous_response_id")
stream = data.get("stream", False)
user = request.state.user
user_id = user.user_id
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
if stream:
return StreamingResponse(
await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
else:
result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False)
return JSONResponse(result)
async def langflow_endpoint(request: Request, chat_service, session_manager):
"""Handle Langflow chat requests"""
data = await request.json()
prompt = data.get("prompt", "")
previous_response_id = data.get("previous_response_id")
stream = data.get("stream", False)
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
try:
if stream:
return StreamingResponse(
await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=True),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
else:
result = await chat_service.langflow_chat(prompt, previous_response_id=previous_response_id, stream=False)
return JSONResponse(result)
except Exception as e:
return JSONResponse({"error": f"Langflow request failed: {str(e)}"}, status_code=500)
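
For streaming responses, a minimal consumer might look like the following (a sketch, not part of this commit; assumes httpx, a server at http://localhost:8000, and that the auth_token cookie is already set; the exact event framing depends on what chat_service yields):

import httpx

async def stream_chat(prompt, previous_response_id=None):
    payload = {"prompt": prompt, "stream": True}
    if previous_response_id:
        payload["previous_response_id"] = previous_response_id
    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=None) as client:
        async with client.stream("POST", "/chat", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line:  # skip blank SSE separator lines
                    print(line)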

81
src/api/connectors.py Normal file

@@ -0,0 +1,81 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
async def connector_sync(request: Request, connector_service, session_manager):
"""Sync files from a connector connection"""
data = await request.json()
connection_id = data.get("connection_id")
max_files = data.get("max_files")
if not connection_id:
return JSONResponse({"error": "connection_id is required"}, status_code=400)
try:
print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")
# Verify user owns this connection
user = request.state.user
print(f"[DEBUG] User: {user.user_id}")
connection_config = await connector_service.connection_manager.get_connection(connection_id)
print(f"[DEBUG] Got connection config: {connection_config is not None}")
if not connection_config:
return JSONResponse({"error": "Connection not found"}, status_code=404)
if connection_config.user_id != user.user_id:
return JSONResponse({"error": "Access denied"}, status_code=403)
print(f"[DEBUG] About to call sync_connector_files")
task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
print(f"[DEBUG] Got task_id: {task_id}")
return JSONResponse({
"task_id": task_id,
"status": "sync_started",
"message": f"Started syncing files from connection {connection_id}"
},
status_code=201
)
except Exception as e:
import sys
import traceback
error_msg = f"[ERROR] Connector sync failed: {str(e)}"
print(error_msg, file=sys.stderr, flush=True)
traceback.print_exc(file=sys.stderr)
sys.stderr.flush()
return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500)
async def connector_status(request: Request, connector_service, session_manager):
"""Get connector status for authenticated user"""
connector_type = request.path_params.get("connector_type", "google_drive")
user = request.state.user
# Get connections for this connector type and user
connections = await connector_service.connection_manager.list_connections(
user_id=user.user_id,
connector_type=connector_type
)
# Check if there are any active connections
active_connections = [conn for conn in connections if conn.is_active]
has_authenticated_connection = len(active_connections) > 0
return JSONResponse({
"connector_type": connector_type,
"authenticated": has_authenticated_connection,
"status": "connected" if has_authenticated_connection else "not_connected",
"connections": [
{
"connection_id": conn.connection_id,
"name": conn.name,
"is_active": conn.is_active,
"created_at": conn.created_at.isoformat(),
"last_sync": conn.last_sync.isoformat() if conn.last_sync else None
}
for conn in connections
]
})
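
A small usage sketch for these two endpoints (not part of this commit; assumes httpx, cookie auth already in place, and an existing authenticated connection):

import httpx

async def sync_google_drive(connection_id, max_files=25):
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Kick off a background sync; the handler above answers 201 with a task_id.
        started = await client.post("/connectors/sync", json={
            "connection_id": connection_id,
            "max_files": max_files,  # optional cap on files per sync
        })
        started.raise_for_status()
        print(started.json())  # {"task_id": ..., "status": "sync_started", ...}

        # Check overall connector state for the current user.
        status = await client.get("/connectors/status/google_drive")
        print(status.json()["connections"])  # per-connection is_active / last_sync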

13
src/api/search.py Normal file

@@ -0,0 +1,13 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
async def search(request: Request, search_service, session_manager):
"""Search for documents"""
payload = await request.json()
query = payload.get("query")
if not query:
return JSONResponse({"error": "Query is required"}, status_code=400)
user = request.state.user
result = await search_service.search(query, user_id=user.user_id)
return JSONResponse(result)

78
src/api/upload.py Normal file

@@ -0,0 +1,78 @@
import os
from starlette.requests import Request
from starlette.responses import JSONResponse
async def upload(request: Request, document_service, session_manager):
"""Upload a single file"""
form = await request.form()
upload_file = form["file"]
user = request.state.user
result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id)
return JSONResponse(result)
async def upload_path(request: Request, task_service, session_manager):
"""Upload all files from a directory path"""
payload = await request.json()
base_dir = payload.get("path")
if not base_dir or not os.path.isdir(base_dir):
return JSONResponse({"error": "Invalid path"}, status_code=400)
file_paths = [os.path.join(root, fn)
for root, _, files in os.walk(base_dir)
for fn in files]
if not file_paths:
return JSONResponse({"error": "No files found in directory"}, status_code=400)
user = request.state.user
task_id = await task_service.create_upload_task(user.user_id, file_paths)
return JSONResponse({
"task_id": task_id,
"total_files": len(file_paths),
"status": "accepted"
}, status_code=201)
async def upload_context(request: Request, document_service, chat_service, session_manager):
"""Upload a file and add its content as context to the current conversation"""
form = await request.form()
upload_file = form["file"]
filename = upload_file.filename or "uploaded_document"
# Get optional parameters
previous_response_id = form.get("previous_response_id")
endpoint = form.get("endpoint", "langflow")
# Process document and extract content
doc_result = await document_service.process_upload_context(upload_file, filename)
# Send document content as user message to get proper response_id
response_text, response_id = await chat_service.upload_context_chat(
doc_result["content"],
filename,
previous_response_id=previous_response_id,
endpoint=endpoint
)
response_data = {
"status": "context_added",
"filename": doc_result["filename"],
"pages": doc_result["pages"],
"content_length": doc_result["content_length"],
"response_id": response_id,
"confirmation": response_text
}
return JSONResponse(response_data)
async def task_status(request: Request, task_service, session_manager):
"""Get the status of an upload task"""
task_id = request.path_params.get("task_id")
user = request.state.user
task_status_result = task_service.get_task_status(user.user_id, task_id)
if not task_status_result:
return JSONResponse({"error": "Task not found"}, status_code=404)
return JSONResponse(task_status_result)
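
Taken together, a client-side sketch of the upload flow (not part of this commit; assumes httpx, cookie auth already in place, a hypothetical example.pdf, and a directory that exists on the server):

import asyncio
import httpx

async def upload_and_wait(path_on_server):
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Single-file upload: multipart form data under the field name "file".
        with open("example.pdf", "rb") as fh:
            single = await client.post("/upload", files={"file": ("example.pdf", fh)})
        print(single.json())  # {"status": "indexed" | "unchanged", "id": <sha256>}

        # Bulk upload walks a directory path on the *server*, then runs in the background.
        task = await client.post("/upload_path", json={"path": path_on_server})
        task_id = task.json()["task_id"]

        # Poll /tasks/{task_id} until the background job settles.
        while True:
            status = (await client.get(f"/tasks/{task_id}")).json()
            if status["status"] in ("completed", "failed"):
                return status
            await asyncio.sleep(2)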

0
src/config/__init__.py Normal file

105
src/config/settings.py Normal file

@@ -0,0 +1,105 @@
import os
from dotenv import load_dotenv
from opensearchpy import AsyncOpenSearch
from opensearchpy._async.http_aiohttp import AIOHttpConnection
from docling.document_converter import DocumentConverter
from agentd.patch import patch_openai_with_mcp
from openai import AsyncOpenAI
load_dotenv()
load_dotenv("../.env")  # also pick up a .env one directory up, if present
# Environment variables
OPENSEARCH_HOST = os.getenv("OPENSEARCH_HOST", "localhost")
OPENSEARCH_PORT = int(os.getenv("OPENSEARCH_PORT", "9200"))
OPENSEARCH_USERNAME = os.getenv("OPENSEARCH_USERNAME", "admin")
OPENSEARCH_PASSWORD = os.getenv("OPENSEARCH_PASSWORD")
LANGFLOW_URL = os.getenv("LANGFLOW_URL", "http://localhost:7860")
FLOW_ID = os.getenv("FLOW_ID")
LANGFLOW_KEY = os.getenv("LANGFLOW_SECRET_KEY")
SESSION_SECRET = os.getenv("SESSION_SECRET", "your-secret-key-change-in-production")
GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# OpenSearch configuration
INDEX_NAME = "documents"
VECTOR_DIM = 1536
EMBED_MODEL = "text-embedding-3-small"
INDEX_BODY = {
"settings": {
"index": {"knn": True},
"number_of_shards": 1,
"number_of_replicas": 1
},
"mappings": {
"properties": {
"document_id": { "type": "keyword" },
"filename": { "type": "keyword" },
"mimetype": { "type": "keyword" },
"page": { "type": "integer" },
"text": { "type": "text" },
"chunk_embedding": {
"type": "knn_vector",
"dimension": VECTOR_DIM,
"method": {
"name": "disk_ann",
"engine": "jvector",
"space_type": "l2",
"parameters": {
"ef_construction": 100,
"m": 16
}
}
},
"source_url": { "type": "keyword" },
"connector_type": { "type": "keyword" },
"owner": { "type": "keyword" },
"allowed_users": { "type": "keyword" },
"allowed_groups": { "type": "keyword" },
"user_permissions": { "type": "object" },
"group_permissions": { "type": "object" },
"created_time": { "type": "date" },
"modified_time": { "type": "date" },
"indexed_time": { "type": "date" },
"metadata": { "type": "object" }
}
}
}
class AppClients:
def __init__(self):
self.opensearch = None
self.langflow_client = None
self.patched_async_client = None
self.converter = None
def initialize(self):
# Initialize OpenSearch client
self.opensearch = AsyncOpenSearch(
hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_PORT}],
connection_class=AIOHttpConnection,
scheme="https",
use_ssl=True,
verify_certs=False,
ssl_assert_fingerprint=None,
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
http_compress=True,
)
# Initialize Langflow client
self.langflow_client = AsyncOpenAI(
base_url=f"{LANGFLOW_URL}/api/v1",
api_key=LANGFLOW_KEY
)
# Initialize patched OpenAI client
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
# Initialize Docling converter
self.converter = DocumentConverter()
return self
# Global clients instance
clients = AppClients()
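
As a quick startup sanity check, something like the following could verify the settings that have no default above (a sketch; OPENAI_API_KEY is an assumption based on the default AsyncOpenAI() constructor, which reads it from the environment):

import os

REQUIRED = [
    "OPENSEARCH_PASSWORD",
    "FLOW_ID",
    "LANGFLOW_SECRET_KEY",
    "GOOGLE_OAUTH_CLIENT_ID",
    "GOOGLE_OAUTH_CLIENT_SECRET",
    "OPENAI_API_KEY",  # not read in settings.py, but required by AsyncOpenAI()
]

missing = [name for name in REQUIRED if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")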

234
src/main.py Normal file

@@ -0,0 +1,234 @@
import asyncio
import atexit
import torch
from functools import partial
from starlette.applications import Starlette
from starlette.routing import Route
# Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from utils.gpu_detection import detect_gpu_devices
# Services
from services.document_service import DocumentService
from services.search_service import SearchService
from services.task_service import TaskService
from services.auth_service import AuthService
from services.chat_service import ChatService
# Existing services
from connectors.service import ConnectorService
from session_manager import SessionManager
from auth_middleware import require_auth, optional_auth
# API endpoints
from api import upload, search, chat, auth, connectors
print("CUDA available:", torch.cuda.is_available())
print("CUDA version PyTorch was built with:", torch.version.cuda)
async def wait_for_opensearch():
"""Wait for OpenSearch to be ready with retries"""
max_retries = 30
retry_delay = 2
for attempt in range(max_retries):
try:
await clients.opensearch.info()
print("OpenSearch is ready!")
return
except Exception as e:
print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
else:
raise Exception("OpenSearch failed to become ready")
async def init_index():
"""Initialize OpenSearch index"""
await wait_for_opensearch()
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
print(f"Created index '{INDEX_NAME}'")
else:
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
def initialize_services():
"""Initialize all services and their dependencies"""
# Initialize clients
clients.initialize()
# Initialize session manager
session_manager = SessionManager(SESSION_SECRET)
# Initialize services
document_service = DocumentService()
search_service = SearchService()
task_service = TaskService(document_service)
chat_service = ChatService()
# Set process pool for document service
document_service.process_pool = task_service.process_pool
# Initialize connector service
connector_service = ConnectorService(
opensearch_client=clients.opensearch,
patched_async_client=clients.patched_async_client,
process_pool=task_service.process_pool,
embed_model="text-embedding-3-small",
index_name=INDEX_NAME
)
# Initialize auth service
auth_service = AuthService(session_manager, connector_service)
return {
'document_service': document_service,
'search_service': search_service,
'task_service': task_service,
'chat_service': chat_service,
'auth_service': auth_service,
'connector_service': connector_service,
'session_manager': session_manager
}
def create_app():
"""Create and configure the Starlette application"""
services = initialize_services()
# Create route handlers with service dependencies injected
routes = [
# Upload endpoints
Route("/upload",
require_auth(services['session_manager'])(
partial(upload.upload,
document_service=services['document_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/upload_context",
require_auth(services['session_manager'])(
partial(upload.upload_context,
document_service=services['document_service'],
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/upload_path",
require_auth(services['session_manager'])(
partial(upload.upload_path,
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/tasks/{task_id}",
require_auth(services['session_manager'])(
partial(upload.task_status,
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
# Search endpoint
Route("/search",
require_auth(services['session_manager'])(
partial(search.search,
search_service=services['search_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Chat endpoints
Route("/chat",
require_auth(services['session_manager'])(
partial(chat.chat_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/langflow",
require_auth(services['session_manager'])(
partial(chat.langflow_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Authentication endpoints
Route("/auth/init",
optional_auth(services['session_manager'])(
partial(auth.auth_init,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/auth/callback",
partial(auth.auth_callback,
auth_service=services['auth_service'],
session_manager=services['session_manager']),
methods=["POST"]),
Route("/auth/me",
optional_auth(services['session_manager'])(
partial(auth.auth_me,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/auth/logout",
require_auth(services['session_manager'])(
partial(auth.auth_logout,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Connector endpoints
Route("/connectors/sync",
require_auth(services['session_manager'])(
partial(connectors.connector_sync,
connector_service=services['connector_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/connectors/status/{connector_type}",
require_auth(services['session_manager'])(
partial(connectors.connector_status,
connector_service=services['connector_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
]
app = Starlette(debug=True, routes=routes)
app.state.services = services # Store services for cleanup
return app
async def startup():
"""Application startup tasks"""
await init_index()
# Get services from app state if needed for initialization
# services = app.state.services
# await services['connector_service'].initialize()
def cleanup():
"""Cleanup on application shutdown"""
# This will be called on exit to cleanup process pools
pass
if __name__ == "__main__":
import uvicorn
# Register cleanup function
atexit.register(cleanup)
# Create app
app = create_app()
# Run startup tasks
asyncio.run(startup())
# Run the server
uvicorn.run(
app,
host="0.0.0.0",
port=8000,
reload=False, # Disable reload since we're running from main
)
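
A toy illustration of the route-wiring pattern used above (not part of this commit): functools.partial binds the service dependencies, and the require_auth/optional_auth decorators wrap the result so that only request remains for Starlette to supply at call time.

from functools import partial

def fake_require_auth(session_manager):
    """Stand-in for auth_middleware.require_auth, just to show the call shape."""
    def decorate(handler):
        async def wrapper(request):
            # The real decorator would validate the auth_token cookie via session_manager
            # and attach request.state.user before delegating.
            return await handler(request)
        return wrapper
    return decorate

async def handler(request, search_service, session_manager):
    return {"ok": True}

bound = partial(handler, search_service="svc", session_manager="mgr")  # only `request` left
wrapped = fake_require_auth("mgr")(bound)  # this is the callable handed to Route()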

0
src/models/__init__.py Normal file

32
src/models/tasks.py Normal file

@@ -0,0 +1,32 @@
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
class TaskStatus(Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class FileTask:
file_path: str
status: TaskStatus = TaskStatus.PENDING
result: Optional[dict] = None
error: Optional[str] = None
retry_count: int = 0
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
@dataclass
class UploadTask:
task_id: str
total_files: int
processed_files: int = 0
successful_files: int = 0
failed_files: int = 0
file_tasks: Dict[str, FileTask] = field(default_factory=dict)
status: TaskStatus = TaskStatus.PENDING
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
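
A brief usage sketch (assuming this module is importable as models.tasks): TaskService builds one UploadTask per bulk upload and one FileTask per file, so progress falls out of the counters directly.

from models.tasks import FileTask, TaskStatus, UploadTask

task = UploadTask(
    task_id="demo",
    total_files=2,
    file_tasks={p: FileTask(file_path=p) for p in ("/data/a.pdf", "/data/b.pdf")},
)
task.file_tasks["/data/a.pdf"].status = TaskStatus.COMPLETED
task.processed_files += 1
progress = task.processed_files / task.total_files  # 0.5, i.e. half the files handled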

0
src/services/__init__.py Normal file

213
src/services/auth_service.py Normal file

@@ -0,0 +1,213 @@
import os
import uuid
import json
import httpx
import aiofiles
from datetime import datetime, timedelta
from typing import Optional
from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET
from session_manager import SessionManager
class AuthService:
def __init__(self, session_manager: SessionManager, connector_service=None):
self.session_manager = session_manager
self.connector_service = connector_service
self.used_auth_codes = set() # Track used authorization codes
async def init_oauth(self, provider: str, purpose: str, connection_name: str,
redirect_uri: str, user_id: str = None) -> dict:
"""Initialize OAuth flow for authentication or data source connection"""
if provider != "google":
raise ValueError("Unsupported provider")
if not redirect_uri:
raise ValueError("redirect_uri is required")
if not GOOGLE_OAUTH_CLIENT_ID:
raise ValueError("Google OAuth client ID not configured")
# Create connection configuration
token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
config = {
"client_id": GOOGLE_OAUTH_CLIENT_ID,
"token_file": token_file,
"provider": provider,
"purpose": purpose,
"redirect_uri": redirect_uri
}
# Create connection in manager
connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
connection_id = await self.connector_service.connection_manager.create_connection(
connector_type=connector_type,
name=connection_name,
config=config,
user_id=user_id
)
# Return OAuth configuration for client-side flow
scopes = [
'openid', 'email', 'profile',
'https://www.googleapis.com/auth/drive.readonly',
'https://www.googleapis.com/auth/drive.metadata.readonly'
]
oauth_config = {
"client_id": GOOGLE_OAUTH_CLIENT_ID,
"scopes": scopes,
"redirect_uri": redirect_uri,
"authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
"token_endpoint": "https://oauth2.googleapis.com/token"
}
return {
"connection_id": connection_id,
"oauth_config": oauth_config
}
async def handle_oauth_callback(self, connection_id: str, authorization_code: str,
state: str = None) -> dict:
"""Handle OAuth callback - exchange authorization code for tokens"""
if not all([connection_id, authorization_code]):
raise ValueError("Missing required parameters (connection_id, authorization_code)")
# Check if authorization code has already been used
if authorization_code in self.used_auth_codes:
raise ValueError("Authorization code already used")
# Mark code as used to prevent duplicate requests
self.used_auth_codes.add(authorization_code)
try:
# Get connection config
connection_config = await self.connector_service.connection_manager.get_connection(connection_id)
if not connection_config:
raise ValueError("Connection not found")
# Exchange authorization code for tokens
redirect_uri = connection_config.config.get("redirect_uri")
if not redirect_uri:
raise ValueError("Redirect URI not found in connection config")
token_url = "https://oauth2.googleapis.com/token"
token_payload = {
"code": authorization_code,
"client_id": connection_config.config["client_id"],
"client_secret": GOOGLE_OAUTH_CLIENT_SECRET,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code"
}
async with httpx.AsyncClient() as client:
token_response = await client.post(token_url, data=token_payload)
if token_response.status_code != 200:
raise Exception(f"Token exchange failed: {token_response.text}")
token_data = token_response.json()
# Store tokens in the token file
token_file_data = {
"token": token_data["access_token"],
"refresh_token": token_data.get("refresh_token"),
"scopes": [
"openid", "email", "profile",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly"
]
}
# Add expiry if provided
if token_data.get("expires_in"):
expiry = datetime.now() + timedelta(seconds=int(token_data["expires_in"]))
token_file_data["expiry"] = expiry.isoformat()
# Save tokens to file
token_file_path = connection_config.config["token_file"]
async with aiofiles.open(token_file_path, 'w') as f:
await f.write(json.dumps(token_file_data, indent=2))
# Route based on purpose
purpose = connection_config.config.get("purpose", "data_source")
if purpose == "app_auth":
return await self._handle_app_auth(connection_id, connection_config, token_data)
else:
return await self._handle_data_source_auth(connection_id, connection_config)
except Exception as e:
# Remove used code from set if we failed
self.used_auth_codes.discard(authorization_code)
raise e
async def _handle_app_auth(self, connection_id: str, connection_config, token_data: dict) -> dict:
"""Handle app authentication - create user session"""
jwt_token = await self.session_manager.create_user_session(token_data["access_token"])
if jwt_token:
# Get the user info to create a persistent Google Drive connection
user_info = await self.session_manager.get_user_info_from_token(token_data["access_token"])
user_id = user_info["id"] if user_info else None
response_data = {
"status": "authenticated",
"purpose": "app_auth",
"redirect": "/",
"jwt_token": jwt_token # Include JWT token in response
}
if user_id:
# Convert the temporary auth connection to a persistent Google Drive connection
await self.connector_service.connection_manager.update_connection(
connection_id=connection_id,
connector_type="google_drive",
name=f"Google Drive ({user_info.get('email', 'Unknown')})",
user_id=user_id,
config={
**connection_config.config,
"purpose": "data_source",
"user_email": user_info.get("email")
}
)
response_data["google_drive_connection_id"] = connection_id
else:
# Fallback: delete connection if we can't get user info
await self.connector_service.connection_manager.delete_connection(connection_id)
return response_data
else:
# Clean up connection if session creation failed
await self.connector_service.connection_manager.delete_connection(connection_id)
raise Exception("Failed to create user session")
async def _handle_data_source_auth(self, connection_id: str, connection_config) -> dict:
"""Handle data source connection - keep the connection for syncing"""
return {
"status": "authenticated",
"connection_id": connection_id,
"purpose": "data_source",
"connector_type": connection_config.connector_type
}
async def get_user_info(self, request) -> Optional[dict]:
"""Get current user information from request"""
user = getattr(request.state, 'user', None)
if user:
return {
"authenticated": True,
"user": {
"user_id": user.user_id,
"email": user.email,
"name": user.name,
"picture": user.picture,
"provider": user.provider,
"last_login": user.last_login.isoformat() if user.last_login else None
}
}
else:
return {
"authenticated": False,
"user": None
}
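
The step that happens outside this service, between init_oauth and handle_oauth_callback, is the browser redirect to Google. A sketch of how a frontend might build that URL from the returned oauth_config (parameter names follow standard Google OAuth 2.0 authorization requests; this helper is illustrative, not part of the commit):

from urllib.parse import urlencode

def build_authorization_url(oauth_config, state):
    params = {
        "client_id": oauth_config["client_id"],
        "redirect_uri": oauth_config["redirect_uri"],
        "response_type": "code",
        "scope": " ".join(oauth_config["scopes"]),
        "state": state,  # echoed back so the callback can be correlated
        "access_type": "offline",  # ask Google to issue a refresh_token
        "prompt": "consent",
    }
    return f"{oauth_config['authorization_endpoint']}?{urlencode(params)}"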

47
src/services/chat_service.py Normal file

@@ -0,0 +1,47 @@
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream
class ChatService:
async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False):
"""Handle chat requests using the patched OpenAI client"""
if not prompt:
raise ValueError("Prompt is required")
if stream:
return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
else:
response_text, response_id = await async_chat(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
response_data = {"response": response_text}
if response_id:
response_data["response_id"] = response_id
return response_data
async def langflow_chat(self, prompt: str, previous_response_id: str = None, stream: bool = False):
"""Handle Langflow chat requests"""
if not prompt:
raise ValueError("Prompt is required")
if not LANGFLOW_URL or not FLOW_ID or not LANGFLOW_KEY:
raise ValueError("LANGFLOW_URL, FLOW_ID, and LANGFLOW_KEY environment variables are required")
if stream:
return async_langflow_stream(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id)
else:
response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, prompt, previous_response_id=previous_response_id)
response_data = {"response": response_text}
if response_id:
response_data["response_id"] = response_id
return response_data
async def upload_context_chat(self, document_content: str, filename: str,
previous_response_id: str = None, endpoint: str = "langflow"):
"""Send document content as user message to get proper response_id"""
document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{document_content}\n\nPlease confirm you've received this document and are ready to answer questions about it."
if endpoint == "langflow":
response_text, response_id = await async_langflow(clients.langflow_client, FLOW_ID, document_prompt, previous_response_id=previous_response_id)
else: # chat
response_text, response_id = await async_chat(clients.patched_async_client, document_prompt, previous_response_id=previous_response_id)
return response_text, response_id

184
src/services/document_service.py Normal file

@@ -0,0 +1,184 @@
import datetime
import hashlib
import tempfile
import os
import aiofiles
from io import BytesIO
from docling_core.types.io import DocumentStream
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from utils.document_processing import extract_relevant, process_document_sync
class DocumentService:
def __init__(self, process_pool=None):
self.process_pool = process_pool
async def process_file_common(self, file_path: str, file_hash: str = None, owner_user_id: str = None):
"""
Common processing logic for both upload and upload_path.
1. Optionally compute SHA256 hash if not provided.
2. Convert with docling and extract relevant content.
3. Add embeddings.
4. Index into OpenSearch.
"""
if file_hash is None:
sha256 = hashlib.sha256()
async with aiofiles.open(file_path, "rb") as f:
while True:
chunk = await f.read(1 << 20)
if not chunk:
break
sha256.update(chunk)
file_hash = sha256.hexdigest()
exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash)
if exists:
return {"status": "unchanged", "id": file_hash}
# convert and extract
result = clients.converter.convert(file_path)
full_doc = result.document.export_to_dict()
slim_doc = extract_relevant(full_doc)
texts = [c["text"] for c in slim_doc["chunks"]]
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
embeddings = [d.embedding for d in resp.data]
# Index each chunk as a separate document
for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
chunk_doc = {
"document_id": file_hash,
"filename": slim_doc["filename"],
"mimetype": slim_doc["mimetype"],
"page": chunk["page"],
"text": chunk["text"],
"chunk_embedding": vect,
"owner": owner_user_id,
"indexed_time": datetime.datetime.now().isoformat()
}
chunk_id = f"{file_hash}_{i}"
await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
return {"status": "indexed", "id": file_hash}
async def process_upload_file(self, upload_file, owner_user_id: str = None):
"""Process an uploaded file from form data"""
sha256 = hashlib.sha256()
tmp = tempfile.NamedTemporaryFile(delete=False)
try:
while True:
chunk = await upload_file.read(1 << 20)
if not chunk:
break
sha256.update(chunk)
tmp.write(chunk)
tmp.flush()
file_hash = sha256.hexdigest()
exists = await clients.opensearch.exists(index=INDEX_NAME, id=file_hash)
if exists:
return {"status": "unchanged", "id": file_hash}
result = await self.process_file_common(tmp.name, file_hash, owner_user_id=owner_user_id)
return result
finally:
tmp.close()
os.remove(tmp.name)
async def process_upload_context(self, upload_file, filename: str = None):
"""Process uploaded file and return content for context"""
if not filename:
filename = upload_file.filename or "uploaded_document"
# Stream file content into BytesIO
content = BytesIO()
while True:
chunk = await upload_file.read(1 << 20) # 1MB chunks
if not chunk:
break
content.write(chunk)
content.seek(0) # Reset to beginning for reading
# Create DocumentStream and process with docling
doc_stream = DocumentStream(name=filename, stream=content)
result = clients.converter.convert(doc_stream)
full_doc = result.document.export_to_dict()
slim_doc = extract_relevant(full_doc)
# Extract all text content
all_text = []
for chunk in slim_doc["chunks"]:
all_text.append(f"Page {chunk['page']}:\n{chunk['text']}")
full_content = "\n\n".join(all_text)
return {
"filename": filename,
"content": full_content,
"pages": len(slim_doc["chunks"]),
"content_length": len(full_content)
}
async def process_single_file_task(self, upload_task, file_path: str):
"""Process a single file and update task tracking - used by task service"""
from models.tasks import TaskStatus
import time
import asyncio
file_task = upload_task.file_tasks[file_path]
file_task.status = TaskStatus.RUNNING
file_task.updated_at = time.time()
try:
# Check if file already exists in index
loop = asyncio.get_event_loop()
# Run CPU-intensive docling processing in separate process
slim_doc = await loop.run_in_executor(self.process_pool, process_document_sync, file_path)
# Check if already indexed
exists = await clients.opensearch.exists(index=INDEX_NAME, id=slim_doc["id"])
if exists:
result = {"status": "unchanged", "id": slim_doc["id"]}
else:
# Generate embeddings and index (I/O bound, keep in main process)
texts = [c["text"] for c in slim_doc["chunks"]]
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
embeddings = [d.embedding for d in resp.data]
# Index each chunk
for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
chunk_doc = {
"document_id": slim_doc["id"],
"filename": slim_doc["filename"],
"mimetype": slim_doc["mimetype"],
"page": chunk["page"],
"text": chunk["text"],
"chunk_embedding": vect
}
chunk_id = f"{slim_doc['id']}_{i}"
await clients.opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
result = {"status": "indexed", "id": slim_doc["id"]}
result["path"] = file_path
file_task.status = TaskStatus.COMPLETED
file_task.result = result
upload_task.successful_files += 1
except Exception as e:
print(f"[ERROR] Failed to process file {file_path}: {e}")
import traceback
traceback.print_exc()
file_task.status = TaskStatus.FAILED
file_task.error = str(e)
upload_task.failed_files += 1
finally:
file_task.updated_at = time.time()
upload_task.processed_files += 1
upload_task.updated_at = time.time()
if upload_task.processed_files >= upload_task.total_files:
upload_task.status = TaskStatus.COMPLETED
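
Because every chunk is indexed as its own OpenSearch document with the id pattern {file_hash}_{i} and a shared document_id, reassembling a file later is a single term query. A sketch (assuming the same clients and index as above):

from config.settings import clients, INDEX_NAME

async def get_document_chunks(document_id):
    resp = await clients.opensearch.search(index=INDEX_NAME, body={
        "query": {"term": {"document_id": document_id}},
        "sort": [{"page": "asc"}],
        "size": 1000,
        "_source": ["page", "text", "filename"],
    })
    return [hit["_source"] for hit in resp["hits"]["hits"]]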

80
src/services/search_service.py Normal file

@@ -0,0 +1,80 @@
from typing import Any, Dict, Optional
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
class SearchService:
@tool
async def search_tool(self, query: str, user_id: str = None) -> Dict[str, Any]:
"""
Use this tool to search for documents relevant to the query.
Args:
query (str): query string to search the corpus
user_id (str): user ID for access control (optional)
Returns:
dict (str, Any): {"results": [chunks]} on success
"""
# Embed the query
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
query_embedding = resp.data[0].embedding
# Base query structure
search_body = {
"query": {
"bool": {
"must": [
{
"knn": {
"chunk_embedding": {
"vector": query_embedding,
"k": 10
}
}
}
]
}
},
"_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"],
"size": 10
}
# Require authentication - no anonymous access to search
if not user_id:
return {"results": [], "error": "Authentication required"}
# Authenticated user access control
# User can access documents if:
# 1. They own the document (owner field matches user_id)
# 2. They're in allowed_users list
# 3. Document has no ACL (public documents)
# TODO: Add group access control later
should_clauses = [
{"term": {"owner": user_id}},
{"term": {"allowed_users": user_id}},
{"bool": {"must_not": {"exists": {"field": "owner"}}}} # Public docs
]
search_body["query"]["bool"]["should"] = should_clauses
search_body["query"]["bool"]["minimum_should_match"] = 1
results = await clients.opensearch.search(index=INDEX_NAME, body=search_body)
# Transform results
chunks = []
for hit in results["hits"]["hits"]:
chunks.append({
"filename": hit["_source"]["filename"],
"mimetype": hit["_source"]["mimetype"],
"page": hit["_source"]["page"],
"text": hit["_source"]["text"],
"score": hit["_score"],
"source_url": hit["_source"].get("source_url"),
"owner": hit["_source"].get("owner")
})
return {"results": chunks}
async def search(self, query: str, user_id: str = None) -> Dict[str, Any]:
"""Public search method for API endpoints"""
return await self.search_tool(query, user_id)
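
Put together, the access-controlled query that actually reaches OpenSearch looks like this for a hypothetical user_id of "alice" (illustrative, with the embedding elided): the k-NN clause must match, and at least one ACL clause must also match.

example_body = {
    "query": {
        "bool": {
            "must": [
                {"knn": {"chunk_embedding": {"vector": "<1536-dim query embedding>", "k": 10}}}
            ],
            "should": [
                {"term": {"owner": "alice"}},
                {"term": {"allowed_users": "alice"}},
                {"bool": {"must_not": {"exists": {"field": "owner"}}}},  # ownerless docs are public
            ],
            "minimum_should_match": 1,  # a vector hit alone is not enough; ACL must also pass
        }
    },
    "_source": ["filename", "mimetype", "page", "text", "source_url",
                "owner", "allowed_users", "allowed_groups"],
    "size": 10,
}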

112
src/services/task_service.py Normal file

@@ -0,0 +1,112 @@
import asyncio
import uuid
import time
import random
from typing import Dict
from concurrent.futures import ProcessPoolExecutor
from models.tasks import TaskStatus, UploadTask, FileTask
from utils.gpu_detection import get_worker_count
class TaskService:
def __init__(self, document_service=None):
self.document_service = document_service
self.task_store: Dict[str, Dict[str, UploadTask]] = {} # user_id -> {task_id -> UploadTask}
self.background_tasks = set()
# Initialize process pool
max_workers = get_worker_count()
self.process_pool = ProcessPoolExecutor(max_workers=max_workers)
print(f"Process pool initialized with {max_workers} workers")
async def exponential_backoff_delay(self, retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
"""Apply exponential backoff with jitter"""
delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
await asyncio.sleep(delay)
async def create_upload_task(self, user_id: str, file_paths: list) -> str:
"""Create a new upload task for bulk file processing"""
task_id = str(uuid.uuid4())
upload_task = UploadTask(
task_id=task_id,
total_files=len(file_paths),
file_tasks={path: FileTask(file_path=path) for path in file_paths}
)
if user_id not in self.task_store:
self.task_store[user_id] = {}
self.task_store[user_id][task_id] = upload_task
# Start background processing
background_task = asyncio.create_task(self.background_upload_processor(user_id, task_id))
self.background_tasks.add(background_task)
background_task.add_done_callback(self.background_tasks.discard)
return task_id
async def background_upload_processor(self, user_id: str, task_id: str) -> None:
"""Background task to process all files in an upload job with concurrency control"""
try:
upload_task = self.task_store[user_id][task_id]
upload_task.status = TaskStatus.RUNNING
upload_task.updated_at = time.time()
# Process files with limited concurrency to avoid overwhelming the system
max_workers = get_worker_count()
semaphore = asyncio.Semaphore(max_workers * 2) # Allow 2x process pool size for async I/O
async def process_with_semaphore(file_path: str):
async with semaphore:
await self.document_service.process_single_file_task(upload_task, file_path)
tasks = [
process_with_semaphore(file_path)
for file_path in upload_task.file_tasks.keys()
]
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
import traceback
traceback.print_exc()
if user_id in self.task_store and task_id in self.task_store[user_id]:
self.task_store[user_id][task_id].status = TaskStatus.FAILED
self.task_store[user_id][task_id].updated_at = time.time()
def get_task_status(self, user_id: str, task_id: str) -> dict:
"""Get the status of a specific upload task"""
if (not task_id or
user_id not in self.task_store or
task_id not in self.task_store[user_id]):
return None
upload_task = self.task_store[user_id][task_id]
file_statuses = {}
for file_path, file_task in upload_task.file_tasks.items():
file_statuses[file_path] = {
"status": file_task.status.value,
"result": file_task.result,
"error": file_task.error,
"retry_count": file_task.retry_count,
"created_at": file_task.created_at,
"updated_at": file_task.updated_at
}
return {
"task_id": upload_task.task_id,
"status": upload_task.status.value,
"total_files": upload_task.total_files,
"processed_files": upload_task.processed_files,
"successful_files": upload_task.successful_files,
"failed_files": upload_task.failed_files,
"created_at": upload_task.created_at,
"updated_at": upload_task.updated_at,
"files": file_statuses
}
def shutdown(self):
"""Cleanup process pool"""
if hasattr(self, 'process_pool'):
self.process_pool.shutdown(wait=True)
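
Worked example of exponential_backoff_delay, ignoring the 0-1 s jitter term: with the defaults, retries 0 through 6 wait roughly 1, 2, 4, 8, 16, 32, and then 60 seconds (the cap).

delays = [min(1.0 * (2 ** retry), 60.0) for retry in range(7)]
# [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0]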

0
src/utils/__init__.py Normal file

149
src/utils/document_processing.py Normal file

@@ -0,0 +1,149 @@
import hashlib
import os
from collections import defaultdict
from docling.document_converter import DocumentConverter
from .gpu_detection import detect_gpu_devices
# Global converter cache for worker processes
_worker_converter = None
def get_worker_converter():
"""Get or create a DocumentConverter instance for this worker process"""
global _worker_converter
if _worker_converter is None:
from docling.document_converter import DocumentConverter
# Configure GPU settings for this worker
has_gpu_devices, _ = detect_gpu_devices()
if not has_gpu_devices:
# Force CPU-only mode in subprocess
os.environ['USE_CPU_ONLY'] = 'true'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['TORCH_USE_CUDA_DSA'] = '0'
# Try to disable CUDA in torch if available
try:
import torch
torch.cuda.is_available = lambda: False
except ImportError:
pass
else:
# GPU mode - let libraries use GPU if available
os.environ.pop('USE_CPU_ONLY', None)
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # Still disable progress bars
print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
_worker_converter = DocumentConverter()
print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")
return _worker_converter
def extract_relevant(doc_dict: dict) -> dict:
"""
Given the full export_to_dict() result:
- Grabs origin metadata (hash, filename, mimetype)
- Finds every text fragment in `texts`, groups them by page_no
- Flattens tables in `tables` into tab-separated text, grouping by row
- Concatenates each page's fragments and each table into its own chunk
Returns a slimmed dict ready for indexing, with each chunk under "text".
"""
origin = doc_dict.get("origin", {})
chunks = []
# 1) process free-text fragments
page_texts = defaultdict(list)
for txt in doc_dict.get("texts", []):
prov = txt.get("prov", [])
page_no = prov[0].get("page_no") if prov else None
if page_no is not None:
page_texts[page_no].append(txt.get("text", "").strip())
for page in sorted(page_texts):
chunks.append({
"page": page,
"type": "text",
"text": "\n".join(page_texts[page])
})
# 2) process tables
for t_idx, table in enumerate(doc_dict.get("tables", [])):
prov = table.get("prov", [])
page_no = prov[0].get("page_no") if prov else None
# group cells by their row index
rows = defaultdict(list)
for cell in table.get("data", {}).get("table_cells", []):
r = cell.get("start_row_offset_idx")
c = cell.get("start_col_offset_idx")
text = cell.get("text", "").strip()
rows[r].append((c, text))
# build a tab-separated line for each row, in order
flat_rows = []
for r in sorted(rows):
cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])]
flat_rows.append("\t".join(cells))
chunks.append({
"page": page_no,
"type": "table",
"table_index": t_idx,
"text": "\n".join(flat_rows)
})
return {
"id": origin.get("binary_hash"),
"filename": origin.get("filename"),
"mimetype": origin.get("mimetype"),
"chunks": chunks
}
def process_document_sync(file_path: str):
"""Synchronous document processing function for multiprocessing"""
from collections import defaultdict
# Get the cached converter for this worker
converter = get_worker_converter()
# Compute file hash
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
while True:
chunk = f.read(1 << 20)
if not chunk:
break
sha256.update(chunk)
file_hash = sha256.hexdigest()
# Convert with docling
result = converter.convert(file_path)
full_doc = result.document.export_to_dict()
# Extract relevant text content (text-only subset of extract_relevant; tables are not flattened here)
origin = full_doc.get("origin", {})
texts = full_doc.get("texts", [])
page_texts = defaultdict(list)
for txt in texts:
prov = txt.get("prov", [])
page_no = prov[0].get("page_no") if prov else None
if page_no is not None:
page_texts[page_no].append(txt.get("text", "").strip())
chunks = []
for page in sorted(page_texts):
joined = "\n".join(page_texts[page])
chunks.append({
"page": page,
"text": joined
})
return {
"id": file_hash,
"filename": origin.get("filename"),
"mimetype": origin.get("mimetype"),
"chunks": chunks,
"file_path": file_path
}
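
A tiny illustration of extract_relevant on a hand-built export_to_dict()-style payload (assuming this module is importable as utils.document_processing):

from utils.document_processing import extract_relevant

doc = {
    "origin": {"binary_hash": "abc123", "filename": "report.pdf", "mimetype": "application/pdf"},
    "texts": [
        {"text": "Intro paragraph.", "prov": [{"page_no": 1}]},
        {"text": "More on page one.", "prov": [{"page_no": 1}]},
    ],
    "tables": [
        {"prov": [{"page_no": 2}],
         "data": {"table_cells": [
             {"start_row_offset_idx": 0, "start_col_offset_idx": 0, "text": "name"},
             {"start_row_offset_idx": 0, "start_col_offset_idx": 1, "text": "value"},
             {"start_row_offset_idx": 1, "start_col_offset_idx": 0, "text": "alpha"},
             {"start_row_offset_idx": 1, "start_col_offset_idx": 1, "text": "42"},
         ]}},
    ],
}

slim = extract_relevant(doc)
# slim["chunks"][0] -> {"page": 1, "type": "text", "text": "Intro paragraph.\nMore on page one."}
# slim["chunks"][1] -> {"page": 2, "type": "table", "table_index": 0, "text": "name\tvalue\nalpha\t42"}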

34
src/utils/gpu_detection.py Normal file

@@ -0,0 +1,34 @@
import multiprocessing
import os
def detect_gpu_devices():
"""Detect if GPU devices are actually available"""
try:
import torch
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
return True, torch.cuda.device_count()
except ImportError:
pass
try:
import subprocess
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
if result.returncode == 0:
return True, "detected"
except (subprocess.SubprocessError, FileNotFoundError):
pass
return False, 0
def get_worker_count():
"""Get optimal worker count based on GPU availability"""
has_gpu_devices, gpu_count = detect_gpu_devices()
if has_gpu_devices:
default_workers = max(1, min(4, multiprocessing.cpu_count() // 2))  # never drop below one worker
print(f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)")
else:
default_workers = multiprocessing.cpu_count()
print(f"CPU-only mode enabled - using full concurrency ({default_workers} workers)")
return int(os.getenv("MAX_WORKERS", default_workers))