From a05c61d3256fd9fee1e832987f01a14b03d391a1 Mon Sep 17 00:00:00 2001
From: phact
Date: Wed, 30 Jul 2025 22:46:23 -0400
Subject: [PATCH] rm old app.py

---
 src/app.py | 1277 ----------------------------------------------------
 1 file changed, 1277 deletions(-)
 delete mode 100644 src/app.py

diff --git a/src/app.py b/src/app.py
deleted file mode 100644
index 1056fe0d..00000000
--- a/src/app.py
+++ /dev/null
@@ -1,1277 +0,0 @@
-# app.py
-import datetime
-import os
-from collections import defaultdict
-from typing import Any
-import uuid
-import time
-import random
-from dataclasses import dataclass, field
-from enum import Enum
-from concurrent.futures import ProcessPoolExecutor
-import multiprocessing
-
-from agent import async_chat, async_langflow
-
-# Import connector components
-from connectors.service import ConnectorService
-from connectors.google_drive import GoogleDriveConnector
-from session_manager import SessionManager
-from auth_middleware import require_auth, optional_auth
-
-import hashlib
-import tempfile
-import asyncio
-
-from starlette.applications import Starlette
-from starlette.requests import Request
-from starlette.responses import JSONResponse, StreamingResponse
-from starlette.routing import Route
-
-import aiofiles
-from opensearchpy import AsyncOpenSearch
-from opensearchpy._async.http_aiohttp import AIOHttpConnection
-from docling.document_converter import DocumentConverter
-from agentd.patch import patch_openai_with_mcp
-from openai import AsyncOpenAI
-from agentd.tool_decorator import tool
-from dotenv import load_dotenv
-
-load_dotenv()
-load_dotenv("../")
-
-import torch
-print("CUDA available:", torch.cuda.is_available())
-print("CUDA version PyTorch was built with:", torch.version.cuda)
-
-# Initialize Docling converter
-converter = DocumentConverter() # basic converter; tweak via PipelineOptions if you need OCR, etc.
-
-# Initialize Async OpenSearch (adjust hosts/auth as needed)
-opensearch_host = os.getenv("OPENSEARCH_HOST", "localhost")
-opensearch_port = int(os.getenv("OPENSEARCH_PORT", "9200"))
-opensearch_username = os.getenv("OPENSEARCH_USERNAME", "admin")
-opensearch_password = os.getenv("OPENSEARCH_PASSWORD")
-langflow_url = os.getenv("LANGFLOW_URL", "http://localhost:7860")
-flow_id = os.getenv("FLOW_ID")
-langflow_key = os.getenv("LANGFLOW_SECRET_KEY")
-
-
-
-opensearch = AsyncOpenSearch(
-    hosts=[{"host": opensearch_host, "port": opensearch_port}],
-    connection_class=AIOHttpConnection,
-    scheme="https",
-    use_ssl=True,
-    verify_certs=False,
-    ssl_assert_fingerprint=None,
-    http_auth=(opensearch_username, opensearch_password),
-    http_compress=True,
-)
-
-INDEX_NAME = "documents"
-VECTOR_DIM = 1536 # e.g. text-embedding-3-small output size
-EMBED_MODEL = "text-embedding-3-small"
-index_body = {
-    "settings": {
-        "index": {"knn": True},
-        "number_of_shards": 1,
-        "number_of_replicas": 1
-    },
-    "mappings": {
-        "properties": {
-            "document_id": { "type": "keyword" },
-            "filename": { "type": "keyword" },
-            "mimetype": { "type": "keyword" },
-            "page": { "type": "integer" },
-            "text": { "type": "text" },
-            "chunk_embedding": {
-                "type": "knn_vector",
-                "dimension": VECTOR_DIM,
-                "method": {
-                    "name": "disk_ann",
-                    "engine": "jvector",
-                    "space_type": "l2",
-                    "parameters": {
-                        "ef_construction": 100,
-                        "m": 16
-                    }
-                }
-            },
-            # Connector and source information
-            "source_url": { "type": "keyword" },
-            "connector_type": { "type": "keyword" },
-            # ACL fields
-            "owner": { "type": "keyword" },
-            "allowed_users": { "type": "keyword" },
-            "allowed_groups": { "type": "keyword" },
-            "user_permissions": { "type": "object" },
-            "group_permissions": { "type": "object" },
-            # Timestamps
-            "created_time": { "type": "date" },
-            "modified_time": { "type": "date" },
-            "indexed_time": { "type": "date" },
-            # Additional metadata
-            "metadata": { "type": "object" }
-        }
-    }
-}
-
-langflow_client = AsyncOpenAI(
-    base_url=f"{langflow_url}/api/v1",
-    api_key=langflow_key
-)
-patched_async_client = patch_openai_with_mcp(AsyncOpenAI()) # Get the patched client back
-
-# Initialize connector service
-connector_service = ConnectorService(
-    opensearch_client=opensearch,
-    patched_async_client=patched_async_client,
-    process_pool=None, # Will be set after process_pool is initialized
-    embed_model=EMBED_MODEL,
-    index_name=INDEX_NAME
-)
-
-# Initialize session manager
-session_secret = os.getenv("SESSION_SECRET", "your-secret-key-change-in-production")
-session_manager = SessionManager(session_secret)
-
-# Track used authorization codes to prevent duplicate usage
-used_auth_codes = set()
-
-class TaskStatus(Enum):
-    PENDING = "pending"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
-
-@dataclass
-class FileTask:
-    file_path: str
-    status: TaskStatus = TaskStatus.PENDING
-    result: dict = None
-    error: str = None
-    retry_count: int = 0
-    created_at: float = field(default_factory=time.time)
-    updated_at: float = field(default_factory=time.time)
-
-@dataclass
-class UploadTask:
-    task_id: str
-    total_files: int
-    processed_files: int = 0
-    successful_files: int = 0
-    failed_files: int = 0
-    file_tasks: dict = field(default_factory=dict)
-    status: TaskStatus = TaskStatus.PENDING
-    created_at: float = field(default_factory=time.time)
-    updated_at: float = field(default_factory=time.time)
-
-task_store = {} # user_id -> {task_id -> UploadTask}
-background_tasks = set()
-
-# GPU device detection
-def detect_gpu_devices():
-    """Detect if GPU devices are actually available"""
-    try:
-        import torch
-        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
-            return True, torch.cuda.device_count()
-    except ImportError:
-        pass
-
-    try:
-        import subprocess
-        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
-        if result.returncode == 0:
-            return True, "detected"
-    except (subprocess.SubprocessError, FileNotFoundError):
-        pass
-
-    return False, 0
-
-# GPU and concurrency configuration
-HAS_GPU_DEVICES, GPU_COUNT = detect_gpu_devices()
-
-if HAS_GPU_DEVICES:
-    # GPU mode with actual GPU devices: Lower concurrency due to memory constraints
-    DEFAULT_WORKERS = min(4, multiprocessing.cpu_count() // 2)
-    print(f"GPU mode enabled with {GPU_COUNT} GPU(s) - using limited concurrency ({DEFAULT_WORKERS} workers)")
-elif HAS_GPU_DEVICES:
-    # GPU mode requested but no devices found: Use full CPU concurrency
-    DEFAULT_WORKERS = multiprocessing.cpu_count()
-    print(f"GPU mode requested but no GPU devices found - falling back to full CPU concurrency ({DEFAULT_WORKERS} workers)")
-else:
-    # CPU mode: Higher concurrency since no GPU memory constraints
-    DEFAULT_WORKERS = multiprocessing.cpu_count()
-    print(f"CPU-only mode enabled - using full concurrency ({DEFAULT_WORKERS} workers)")
-
-MAX_WORKERS = int(os.getenv("MAX_WORKERS", DEFAULT_WORKERS))
-process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
-connector_service.process_pool = process_pool # Set the process pool for connector service
-
-print(f"Process pool initialized with {MAX_WORKERS} workers")
-
-# Global converter cache for worker processes
-_worker_converter = None
-
-def get_worker_converter():
-    """Get or create a DocumentConverter instance for this worker process"""
-    global _worker_converter
-    if _worker_converter is None:
-        import os
-        from docling.document_converter import DocumentConverter
-
-        # Configure GPU settings for this worker
-        has_gpu_devices, _ = detect_gpu_devices()
-        if not has_gpu_devices:
-            # Force CPU-only mode in subprocess
-            os.environ['USE_CPU_ONLY'] = 'true'
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
-            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
-            os.environ['TRANSFORMERS_OFFLINE'] = '0'
-            os.environ['TORCH_USE_CUDA_DSA'] = '0'
-
-            # Try to disable CUDA in torch if available
-            try:
-                import torch
-                torch.cuda.is_available = lambda: False
-            except ImportError:
-                pass
-        else:
-            # GPU mode - let libraries use GPU if available
-            os.environ.pop('USE_CPU_ONLY', None)
-            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # Still disable progress bars
-
-        print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
-        _worker_converter = DocumentConverter()
-        print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")
-
-    return _worker_converter
-
-def detect_gpu_devices():
-    """Detect if GPU devices are actually available"""
-    try:
-        import torch
-        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
-            return True, torch.cuda.device_count()
-    except ImportError:
-        pass
-
-    try:
-        import subprocess
-        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
-        if result.returncode == 0:
-            return True, "detected"
-    except (subprocess.SubprocessError, FileNotFoundError):
-        pass
-
-    return False, 0
-
-def process_document_sync(file_path: str):
-    """Synchronous document processing function for multiprocessing"""
-    import hashlib
-    from collections import defaultdict
-
-    # Get the cached converter for this worker
-    converter = get_worker_converter()
-
-    # Compute file hash
-    sha256 = hashlib.sha256()
-    with open(file_path, "rb") as f:
-        while True:
-            chunk = f.read(1 << 20)
-            if not chunk:
-                break
-            sha256.update(chunk)
-    file_hash = sha256.hexdigest()
-
-    # Convert with docling
-    result = converter.convert(file_path)
-    full_doc = result.document.export_to_dict()
-
-    # Extract relevant content (same logic as extract_relevant)
-    origin = full_doc.get("origin", {})
-    texts = full_doc.get("texts", [])
-
-    page_texts = defaultdict(list)
-    for txt in texts:
-        prov = txt.get("prov", [])
-        page_no = prov[0].get("page_no") if prov else None
-        if page_no is not None:
-            page_texts[page_no].append(txt.get("text", "").strip())
-
-    chunks = []
-    for page in sorted(page_texts):
-        joined = "\n".join(page_texts[page])
-        chunks.append({
-            "page": page,
-            "text": joined
-        })
-
-    return {
-        "id": file_hash,
-        "filename": origin.get("filename"),
-        "mimetype": origin.get("mimetype"),
-        "chunks": chunks,
-        "file_path": file_path
-    }
-
-async def wait_for_opensearch():
-    """Wait for OpenSearch to be ready with retries"""
-    max_retries = 30
-    retry_delay = 2
-
-    for attempt in range(max_retries):
-        try:
-            await opensearch.info()
-            print("OpenSearch is ready!")
-            return
-        except Exception as e:
-            print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
-            if attempt < max_retries - 1:
-                await asyncio.sleep(retry_delay)
-            else:
-                raise Exception("OpenSearch failed to become ready")
-
-async def init_index():
-    await wait_for_opensearch()
-
-    if not await opensearch.indices.exists(index=INDEX_NAME):
-        await opensearch.indices.create(index=INDEX_NAME, body=index_body)
-        print(f"Created index '{INDEX_NAME}'")
-    else:
-        print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
-
-from collections import defaultdict
-
-
-def extract_relevant(doc_dict: dict) -> dict:
-    """
-    Given the full export_to_dict() result:
-      - Grabs origin metadata (hash, filename, mimetype)
-      - Finds every text fragment in `texts`, groups them by page_no
-      - Flattens tables in `tables` into tab-separated text, grouping by row
-      - Concatenates each page’s fragments and each table into its own chunk
-    Returns a slimmed dict ready for indexing, with each chunk under "text".
-    """
-    origin = doc_dict.get("origin", {})
-    chunks = []
-
-    # 1) process free-text fragments
-    page_texts = defaultdict(list)
-    for txt in doc_dict.get("texts", []):
-        prov = txt.get("prov", [])
-        page_no = prov[0].get("page_no") if prov else None
-        if page_no is not None:
-            page_texts[page_no].append(txt.get("text", "").strip())
-
-    for page in sorted(page_texts):
-        chunks.append({
-            "page": page,
-            "type": "text",
-            "text": "\n".join(page_texts[page])
-        })
-
-    # 2) process tables
-    for t_idx, table in enumerate(doc_dict.get("tables", [])):
-        prov = table.get("prov", [])
-        page_no = prov[0].get("page_no") if prov else None
-
-        # group cells by their row index
-        rows = defaultdict(list)
-        for cell in table.get("data").get("table_cells", []):
-            r = cell.get("start_row_offset_idx")
-            c = cell.get("start_col_offset_idx")
-            text = cell.get("text", "").strip()
-            rows[r].append((c, text))
-
-        # build a tab‑separated line for each row, in order
-        flat_rows = []
-        for r in sorted(rows):
-            cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])]
-            flat_rows.append("\t".join(cells))
-
-        chunks.append({
-            "page": page_no,
-            "type": "table",
-            "table_index": t_idx,
-            "text": "\n".join(flat_rows)
-        })
-
-    return {
-        "id": origin.get("binary_hash"),
-        "filename": origin.get("filename"),
-        "mimetype": origin.get("mimetype"),
-        "chunks": chunks
-    }
-
-async def exponential_backoff_delay(retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
-    """Apply exponential backoff with jitter"""
-    delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
-    await asyncio.sleep(delay)
-
-async def process_file_with_retry(file_path: str, max_retries: int = 3) -> dict:
-    """Process a file with retry logic - retries everything up to max_retries times"""
-    last_error = None
-
-    for attempt in range(max_retries + 1):
-        try:
-            return await process_file_common(file_path)
-        except Exception as e:
-            last_error = e
-            if attempt < max_retries:
-                await exponential_backoff_delay(attempt)
-                continue
-            else:
-                raise last_error
-
-async def process_file_common(file_path: str, file_hash: str = None, owner_user_id: str = None):
-    """
-    Common processing logic for both upload and upload_path.
-    1. Optionally compute SHA256 hash if not provided.
-    2. Convert with docling and extract relevant content.
-    3. Add embeddings.
-    4. Index into OpenSearch.
-    """
-    if file_hash is None:
-        sha256 = hashlib.sha256()
-        async with aiofiles.open(file_path, "rb") as f:
-            while True:
-                chunk = await f.read(1 << 20)
-                if not chunk:
-                    break
-                sha256.update(chunk)
-        file_hash = sha256.hexdigest()
-
-    exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
-    if exists:
-        return {"status": "unchanged", "id": file_hash}
-
-    # convert and extract
-    # TODO: Check if docling can handle in-memory bytes instead of file path
-    # This would eliminate the need for temp files in upload flow
-    result = converter.convert(file_path)
-    full_doc = result.document.export_to_dict()
-    slim_doc = extract_relevant(full_doc)
-
-    texts = [c["text"] for c in slim_doc["chunks"]]
-    resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-    embeddings = [d.embedding for d in resp.data]
-
-    # Index each chunk as a separate document
-    for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
-        chunk_doc = {
-            "document_id": file_hash,
-            "filename": slim_doc["filename"],
-            "mimetype": slim_doc["mimetype"],
-            "page": chunk["page"],
-            "text": chunk["text"],
-            "chunk_embedding": vect,
-            "owner": owner_user_id, # User who uploaded/owns this document
-            "indexed_time": datetime.datetime.now().isoformat()
-        }
-        chunk_id = f"{file_hash}_{i}"
-        await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
-    return {"status": "indexed", "id": file_hash}
-
-async def process_file_on_disk(path: str):
-    """
-    Process a file already on disk.
-    """
-    result = await process_file_common(path)
-    result["path"] = path
-    return result
-
-async def process_single_file_task(upload_task: UploadTask, file_path: str) -> None:
-    """Process a single file and update task tracking"""
-    file_task = upload_task.file_tasks[file_path]
-    file_task.status = TaskStatus.RUNNING
-    file_task.updated_at = time.time()
-
-    try:
-        # Check if file already exists in index
-        import asyncio
-        loop = asyncio.get_event_loop()
-
-        # Run CPU-intensive docling processing in separate process
-        slim_doc = await loop.run_in_executor(process_pool, process_document_sync, file_path)
-
-        # Check if already indexed
-        exists = await opensearch.exists(index=INDEX_NAME, id=slim_doc["id"])
-        if exists:
-            result = {"status": "unchanged", "id": slim_doc["id"]}
-        else:
-            # Generate embeddings and index (I/O bound, keep in main process)
-            texts = [c["text"] for c in slim_doc["chunks"]]
-            resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
-            embeddings = [d.embedding for d in resp.data]
-
-            # Index each chunk
-            for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
-                chunk_doc = {
-                    "document_id": slim_doc["id"],
-                    "filename": slim_doc["filename"],
-                    "mimetype": slim_doc["mimetype"],
-                    "page": chunk["page"],
-                    "text": chunk["text"],
-                    "chunk_embedding": vect
-                }
-                chunk_id = f"{slim_doc['id']}_{i}"
-                await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
-
-            result = {"status": "indexed", "id": slim_doc["id"]}
-
-        result["path"] = file_path
-        file_task.status = TaskStatus.COMPLETED
-        file_task.result = result
-        upload_task.successful_files += 1
-
-    except Exception as e:
-        print(f"[ERROR] Failed to process file {file_path}: {e}")
-        import traceback
-        traceback.print_exc()
-        file_task.status = TaskStatus.FAILED
-        file_task.error = str(e)
-        upload_task.failed_files += 1
-    finally:
-        file_task.updated_at = time.time()
-        upload_task.processed_files += 1
-        upload_task.updated_at = time.time()
-
-        if upload_task.processed_files >= upload_task.total_files:
-            upload_task.status = TaskStatus.COMPLETED
-
-async def background_upload_processor(user_id: str, task_id: str) -> None:
-    """Background task to process all files in an upload job with concurrency control"""
-    try:
-        upload_task = task_store[user_id][task_id]
-        upload_task.status = TaskStatus.RUNNING
-        upload_task.updated_at = time.time()
-
-        # Process files with limited concurrency to avoid overwhelming the system
-        semaphore = asyncio.Semaphore(MAX_WORKERS * 2) # Allow 2x process pool size for async I/O
-
-        async def process_with_semaphore(file_path: str):
-            async with semaphore:
-                await process_single_file_task(upload_task, file_path)
-
-        tasks = [
-            process_with_semaphore(file_path)
-            for file_path in upload_task.file_tasks.keys()
-        ]
-
-        await asyncio.gather(*tasks, return_exceptions=True)
-
-    except Exception as e:
-        print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
-        import traceback
-        traceback.print_exc()
-        if user_id in task_store and task_id in task_store[user_id]:
-            task_store[user_id][task_id].status = TaskStatus.FAILED
-            task_store[user_id][task_id].updated_at = time.time()
-
-@require_auth(session_manager)
-async def upload(request: Request):
-    form = await request.form()
-    upload_file = form["file"]
-
-    sha256 = hashlib.sha256()
-    tmp = tempfile.NamedTemporaryFile(delete=False)
-    try:
-        while True:
-            chunk = await upload_file.read(1 << 20)
-            if not chunk:
-                break
-            sha256.update(chunk)
-            tmp.write(chunk)
-        tmp.flush()
-
-        file_hash = sha256.hexdigest()
-        exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
-        if exists:
-            return JSONResponse({"status": "unchanged", "id": file_hash})
-
-        user = request.state.user
-        result = await process_file_common(tmp.name, file_hash, owner_user_id=user.user_id)
-        return JSONResponse(result)
-
-    finally:
-        tmp.close()
-        os.remove(tmp.name)
-
-@require_auth(session_manager)
-async def upload_path(request: Request):
-    payload = await request.json()
-    base_dir = payload.get("path")
-    if not base_dir or not os.path.isdir(base_dir):
-        return JSONResponse({"error": "Invalid path"}, status_code=400)
-
-    file_paths = [os.path.join(root, fn)
-                  for root, _, files in os.walk(base_dir)
-                  for fn in files]
-
-    if not file_paths:
-        return JSONResponse({"error": "No files found in directory"}, status_code=400)
-
-    task_id = str(uuid.uuid4())
-    upload_task = UploadTask(
-        task_id=task_id,
-        total_files=len(file_paths),
-        file_tasks={path: FileTask(file_path=path) for path in file_paths}
-    )
-
-    user = request.state.user
-    if user.user_id not in task_store:
-        task_store[user.user_id] = {}
-    task_store[user.user_id][task_id] = upload_task
-
-    background_task = asyncio.create_task(background_upload_processor(user.user_id, task_id))
-    background_tasks.add(background_task)
-    background_task.add_done_callback(background_tasks.discard)
-
-    return JSONResponse({
-        "task_id": task_id,
-        "total_files": len(file_paths),
-        "status": "accepted"
-    }, status_code=201)
-
-@require_auth(session_manager)
-async def upload_context(request: Request):
-    """Upload a file and add its content as context to the current conversation"""
-    import io
-    from docling_core.types.io import DocumentStream
-
-    form = await request.form()
-    upload_file = form["file"]
-    filename = upload_file.filename or "uploaded_document"
-
-    # Get optional parameters
-    previous_response_id = form.get("previous_response_id")
-    endpoint = form.get("endpoint", "langflow") # default to langflow
-
-    # Stream file content into BytesIO
-    content = io.BytesIO()
-    while True:
-        chunk = await upload_file.read(1 << 20) # 1MB chunks
-        if not chunk:
-            break
-        content.write(chunk)
-    content.seek(0) # Reset to beginning for reading
-
-    # Create DocumentStream and process with docling
-    doc_stream = DocumentStream(name=filename, stream=content)
-    result = converter.convert(doc_stream)
-    full_doc = result.document.export_to_dict()
-    slim_doc = extract_relevant(full_doc)
-
-    # Extract all text content
-    all_text = []
-    for chunk in slim_doc["chunks"]:
-        all_text.append(f"Page {chunk['page']}:\n{chunk['text']}")
-
-    full_content = "\n\n".join(all_text)
-
-    # Send document content as user message to get proper response_id
-    document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{full_content}\n\nPlease confirm you've received this document and are ready to answer questions about it."
-
-    if endpoint == "langflow":
-        from agent import async_langflow
-        response_text, response_id = await async_langflow(langflow_client, flow_id, document_prompt, previous_response_id=previous_response_id)
-    else: # chat
-        from agent import async_chat
-        response_text, response_id = await async_chat(patched_async_client, document_prompt, previous_response_id=previous_response_id)
-
-    response_data = {
-        "status": "context_added",
-        "filename": filename,
-        "pages": len(slim_doc["chunks"]),
-        "content_length": len(full_content),
-        "response_id": response_id,
-        "confirmation": response_text
-    }
-
-    return JSONResponse(response_data)
-
-@require_auth(session_manager)
-async def task_status(request: Request):
-    """Get the status of an upload task"""
-    task_id = request.path_params.get("task_id")
-
-    user = request.state.user
-
-    if (not task_id or
-        user.user_id not in task_store or
-        task_id not in task_store[user.user_id]):
-        return JSONResponse({"error": "Task not found"}, status_code=404)
-
-    upload_task = task_store[user.user_id][task_id]
-
-    file_statuses = {}
-    for file_path, file_task in upload_task.file_tasks.items():
-        file_statuses[file_path] = {
-            "status": file_task.status.value,
-            "result": file_task.result,
-            "error": file_task.error,
-            "retry_count": file_task.retry_count,
-            "created_at": file_task.created_at,
-            "updated_at": file_task.updated_at
-        }
-
-    return JSONResponse({
-        "task_id": upload_task.task_id,
-        "status": upload_task.status.value,
-        "total_files": upload_task.total_files,
-        "processed_files": upload_task.processed_files,
-        "successful_files": upload_task.successful_files,
-        "failed_files": upload_task.failed_files,
-        "created_at": upload_task.created_at,
-        "updated_at": upload_task.updated_at,
-        "files": file_statuses
-    })
-
-@require_auth(session_manager)
-async def search(request: Request):
-    payload = await request.json()
-    query = payload.get("query")
-    if not query:
-        return JSONResponse({"error": "Query is required"}, status_code=400)
-
-    user = request.state.user
-    return JSONResponse(await search_tool(query, user_id=user.user_id))
-
-
-@tool
-async def search_tool(query: str, user_id: str = None)-> dict[str, Any]:
-    """
-    Use this tool to search for documents relevant to the query.
-
-    Args:
-        query (str): query string to search the corpus
-        user_id (str): user ID for access control (optional)
-
-    Returns:
-        dict (str, Any): {"results": [chunks]} on success
-    """
-    # Embed the query
-    resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
-    query_embedding = resp.data[0].embedding
-
-    # Base query structure
-    search_body = {
-        "query": {
-            "bool": {
-                "must": [
-                    {
-                        "knn": {
-                            "chunk_embedding": {
-                                "vector": query_embedding,
-                                "k": 10
-                            }
-                        }
-                    }
-                ]
-            }
-        },
-        "_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"],
-        "size": 10
-    }
-
-    # Require authentication - no anonymous access to search
-    if not user_id:
-        return {"results": [], "error": "Authentication required"}
-
-    # Authenticated user access control
-    # User can access documents if:
-    # 1. They own the document (owner field matches user_id)
-    # 2. They're in allowed_users list
-    # 3. Document has no ACL (public documents)
-    # TODO: Add group access control later
-    should_clauses = [
-        {"term": {"owner": user_id}},
-        {"term": {"allowed_users": user_id}},
-        {"bool": {"must_not": {"exists": {"field": "owner"}}}} # Public docs
-    ]
-
-    search_body["query"]["bool"]["should"] = should_clauses
-    search_body["query"]["bool"]["minimum_should_match"] = 1
-
-    results = await opensearch.search(index=INDEX_NAME, body=search_body)
-
-    # Transform results
-    chunks = []
-    for hit in results["hits"]["hits"]:
-        chunks.append({
-            "filename": hit["_source"]["filename"],
-            "mimetype": hit["_source"]["mimetype"],
-            "page": hit["_source"]["page"],
-            "text": hit["_source"]["text"],
-            "score": hit["_score"],
-            "source_url": hit["_source"].get("source_url"),
-            "owner": hit["_source"].get("owner")
-        })
-    return {"results": chunks}
-
-@require_auth(session_manager)
-async def chat_endpoint(request):
-    data = await request.json()
-    prompt = data.get("prompt", "")
-    previous_response_id = data.get("previous_response_id")
-    stream = data.get("stream", False)
-
-    # Get authenticated user
-    user = request.state.user
-    user_id = user.user_id
-
-    if not prompt:
-        return JSONResponse({"error": "Prompt is required"}, status_code=400)
-
-    if stream:
-        from agent import async_chat_stream
-        return StreamingResponse(
-            async_chat_stream(patched_async_client, prompt, user_id, previous_response_id=previous_response_id),
-            media_type="text/event-stream",
-            headers={
-                "Cache-Control": "no-cache",
-                "Connection": "keep-alive",
-                "Access-Control-Allow-Origin": "*",
-                "Access-Control-Allow-Headers": "Cache-Control"
-            }
-        )
-    else:
-        response_text, response_id = await async_chat(patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
-        response_data = {"response": response_text}
-        if response_id:
-            response_data["response_id"] = response_id
-        return JSONResponse(response_data)
-
-@require_auth(session_manager)
-async def langflow_endpoint(request):
-    data = await request.json()
-    prompt = data.get("prompt", "")
-    previous_response_id = data.get("previous_response_id")
-    stream = data.get("stream", False)
-
-    if not prompt:
-        return JSONResponse({"error": "Prompt is required"}, status_code=400)
-
-    if not langflow_url or not flow_id or not langflow_key:
-        return JSONResponse({"error": "LANGFLOW_URL, FLOW_ID, and LANGFLOW_KEY environment variables are required"}, status_code=500)
-
-    try:
-        if stream:
-            from agent import async_langflow_stream
-            return StreamingResponse(
-                async_langflow_stream(langflow_client, flow_id, prompt, previous_response_id=previous_response_id),
-                media_type="text/event-stream",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                    "Access-Control-Allow-Origin": "*",
-                    "Access-Control-Allow-Headers": "Cache-Control"
-                }
-            )
-        else:
-            response_text, response_id = await async_langflow(langflow_client, flow_id, prompt, previous_response_id=previous_response_id)
-            response_data = {"response": response_text}
-            if response_id:
-                response_data["response_id"] = response_id
-            return JSONResponse(response_data)
-
-    except Exception as e:
-        return JSONResponse({"error": f"Langflow request failed: {str(e)}"}, status_code=500)
-
-
-# Authentication endpoints
-@optional_auth(session_manager) # Allow both authenticated and non-authenticated users
-async def auth_init(request: Request):
-    """Initialize OAuth flow for authentication or data source connection"""
-    try:
-        data = await request.json()
-        provider = data.get("provider") # "google", "microsoft", etc.
-        purpose = data.get("purpose", "data_source") # "app_auth" or "data_source"
-        connection_name = data.get("name", f"{provider}_{purpose}")
-        redirect_uri = data.get("redirect_uri") # Frontend provides this
-
-        # Get user from authentication if available
-        user = getattr(request.state, 'user', None)
-        user_id = user.user_id if user else None
-
-        if provider != "google":
-            return JSONResponse({"error": "Unsupported provider"}, status_code=400)
-
-        if not redirect_uri:
-            return JSONResponse({"error": "redirect_uri is required"}, status_code=400)
-
-        # Get OAuth client configuration from environment
-        google_client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
-        if not google_client_id:
-            return JSONResponse({"error": "Google OAuth client ID not configured"}, status_code=500)
-
-        # Create connection configuration
-        token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
-        config = {
-            "client_id": google_client_id,
-            "token_file": token_file,
-            "provider": provider,
-            "purpose": purpose,
-            "redirect_uri": redirect_uri # Store redirect_uri for use in callback
-        }
-
-        # Create connection in manager
-        # For data sources, use provider name (e.g. "google_drive")
-        # For app auth, connector_type doesn't matter since it gets deleted
-        connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
-        connection_id = await connector_service.connection_manager.create_connection(
-            connector_type=connector_type,
-            name=connection_name,
-            config=config,
-            user_id=user_id
-        )
-
-        # Return OAuth configuration for client-side flow
-        # Include both identity and data access scopes
-        scopes = [
-            # Identity scopes (for app auth)
-            'openid',
-            'email',
-            'profile',
-            # Data access scopes (for connectors)
-            'https://www.googleapis.com/auth/drive.readonly',
-            'https://www.googleapis.com/auth/drive.metadata.readonly'
-        ]
-
-        oauth_config = {
-            "client_id": google_client_id,
-            "scopes": scopes,
-            "redirect_uri": redirect_uri, # Use the redirect_uri from frontend
-            "authorization_endpoint":
-                "https://accounts.google.com/o/oauth2/v2/auth",
-            "token_endpoint":
-                "https://oauth2.googleapis.com/token"
-        }
-
-        return JSONResponse({
-            "connection_id": connection_id,
-            "oauth_config": oauth_config
-        })
-
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        return JSONResponse({"error": f"Failed to initialize OAuth: {str(e)}"}, status_code=500)
-
-
-async def auth_callback(request: Request):
-    """Handle OAuth callback - exchange authorization code for tokens"""
-    try:
-        data = await request.json()
-        connection_id = data.get("connection_id")
-        authorization_code = data.get("authorization_code")
-        state = data.get("state")
-
-        if not all([connection_id, authorization_code]):
-            return JSONResponse({"error": "Missing required parameters (connection_id, authorization_code)"}, status_code=400)
-
-        # Check if authorization code has already been used
-        if authorization_code in used_auth_codes:
-            return JSONResponse({"error": "Authorization code already used"}, status_code=400)
-
-        # Mark code as used to prevent duplicate requests
-        used_auth_codes.add(authorization_code)
-
-        try:
-            # Get connection config
-            connection_config = await connector_service.connection_manager.get_connection(connection_id)
-            if not connection_config:
-                return JSONResponse({"error": "Connection not found"}, status_code=404)
-
-            # Exchange authorization code for tokens
-            import httpx
-
-            # Use the redirect_uri that was stored during auth_init
-            redirect_uri = connection_config.config.get("redirect_uri")
-            if not redirect_uri:
-                return JSONResponse({"error": "Redirect URI not found in connection config"}, status_code=400)
-
-            token_url = "https://oauth2.googleapis.com/token"
-
-            token_payload = {
-                "code": authorization_code,
-                "client_id": connection_config.config["client_id"],
-                "client_secret": os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"), # Need this for server-side
-                "redirect_uri": redirect_uri,
-                "grant_type": "authorization_code"
-            }
-
-            async with httpx.AsyncClient() as client:
-                token_response = await client.post(token_url, data=token_payload)
-
-                if token_response.status_code != 200:
-                    raise Exception(f"Token exchange failed: {token_response.text}")
-
-                token_data = token_response.json()
-
-            # Store tokens in the token file
-            token_file_data = {
-                "token": token_data["access_token"],
-                "refresh_token": token_data.get("refresh_token"),
-                "scopes": [
-                    "openid",
-                    "email",
-                    "profile",
-                    "https://www.googleapis.com/auth/drive.readonly",
-                    "https://www.googleapis.com/auth/drive.metadata.readonly"
-                ]
-            }
-
-            # Add expiry if provided
-            if token_data.get("expires_in"):
-                from datetime import datetime, timedelta
-                expiry = datetime.now() + timedelta(seconds=int(token_data["expires_in"]))
-                token_file_data["expiry"] = expiry.isoformat()
-
-            # Save tokens to file
-            import json
-            token_file_path = connection_config.config["token_file"]
-            async with aiofiles.open(token_file_path, 'w') as f:
-                await f.write(json.dumps(token_file_data, indent=2))
-
-            # Route based on purpose
-            purpose = connection_config.config.get("purpose", "data_source")
-
-            if purpose == "app_auth":
-                # Handle app authentication - create user session
-                jwt_token = await session_manager.create_user_session(token_data["access_token"])
-
-                if jwt_token:
-                    # Get the user info to create a persistent Google Drive connection
-                    user_info = await session_manager.get_user_info_from_token(token_data["access_token"])
-                    user_id = user_info["id"] if user_info else None
-
-                    if user_id:
-                        # Convert the temporary auth connection to a persistent Google Drive connection
-                        # Update the connection to be a data source connection with the user_id
-                        await connector_service.connection_manager.update_connection(
-                            connection_id=connection_id,
-                            connector_type="google_drive",
-                            name=f"Google Drive ({user_info.get('email', 'Unknown')})",
-                            user_id=user_id,
-                            config={
-                                **connection_config.config,
-                                "purpose": "data_source", # Convert to data source
-                                "user_email": user_info.get("email")
-                            }
-                        )
-
-                        response = JSONResponse({
-                            "status": "authenticated",
-                            "purpose": "app_auth",
-                            "redirect": "/", # Redirect to home page instead of dashboard
-                            "google_drive_connection_id": connection_id # Return connection ID for frontend
-                        })
-                    else:
-                        # Fallback: delete connection if we can't get user info
-                        await connector_service.connection_manager.delete_connection(connection_id)
-                        response = JSONResponse({
-                            "status": "authenticated",
-                            "purpose": "app_auth",
-                            "redirect": "/"
-                        })
-
-                    # Set JWT as HTTP-only cookie for security
-                    response.set_cookie(
-                        key="auth_token",
-                        value=jwt_token,
-                        httponly=True,
-                        secure=False, # False for development/testing
-                        samesite="lax",
-                        max_age=7 * 24 * 60 * 60 # 7 days
-                    )
-                    return response
-                else:
-                    # Clean up connection if session creation failed
-                    await connector_service.connection_manager.delete_connection(connection_id)
-                    return JSONResponse({"error": "Failed to create user session"}, status_code=500)
-            else:
-                # Handle data source connection - keep the connection for syncing
-                return JSONResponse({
-                    "status": "authenticated",
-                    "connection_id": connection_id,
-                    "purpose": "data_source",
-                    "connector_type": connection_config.connector_type
-                })
-
-        except Exception as e:
-            import traceback
-            traceback.print_exc()
-            return JSONResponse({"error": f"OAuth callback failed: {str(e)}"}, status_code=500)
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        return JSONResponse({"error": f"Callback failed: {str(e)}"}, status_code=500)
-
-
-@optional_auth(session_manager)
-async def auth_me(request: Request):
-    """Get current user information"""
-    user = getattr(request.state, 'user', None)
-
-    if user:
-        return JSONResponse({
-            "authenticated": True,
-            "user": {
-                "user_id": user.user_id,
-                "email": user.email,
-                "name": user.name,
-                "picture": user.picture,
-                "provider": user.provider,
-                "last_login": user.last_login.isoformat() if user.last_login else None
-            }
-        })
-    else:
-        return JSONResponse({
-            "authenticated": False,
-            "user": None
-        })
-
-@require_auth(session_manager)
-async def auth_logout(request: Request):
-    """Logout user by clearing auth cookie"""
-    response = JSONResponse({
-        "status": "logged_out",
-        "message": "Successfully logged out"
-    })
-
-    # Clear the auth cookie
-    response.delete_cookie(
-        key="auth_token",
-        httponly=True,
-        secure=False, # False for development/testing
-        samesite="lax"
-    )
-
-    return response
-
-
-@require_auth(session_manager)
-async def connector_sync(request: Request):
-    """Sync files from a connector connection"""
-    data = await request.json()
-    connection_id = data.get("connection_id")
-    max_files = data.get("max_files")
-
-    if not connection_id:
-        return JSONResponse({"error": "connection_id is required"}, status_code=400)
-
-    try:
-        print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")
-
-        # Verify user owns this connection
-        user = request.state.user
-        print(f"[DEBUG] User: {user.user_id}")
-
-        connection_config = await connector_service.connection_manager.get_connection(connection_id)
-        print(f"[DEBUG] Got connection config: {connection_config is not None}")
-
-        if not connection_config:
-            return JSONResponse({"error": "Connection not found"}, status_code=404)
-
-        if connection_config.user_id != user.user_id:
-            return JSONResponse({"error": "Access denied"}, status_code=403)
-
-        print(f"[DEBUG] About to call sync_connector_files")
-        task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
-        print(f"[DEBUG] Got task_id: {task_id}")
-
-        return JSONResponse({
-            "task_id": task_id,
-            "status": "sync_started",
-            "message": f"Started syncing files from connection {connection_id}"
-        },
-            status_code=201
-        )
-
-    except Exception as e:
-        import sys
-        import traceback
-
-        error_msg = f"[ERROR] Connector sync failed: {str(e)}"
-        print(error_msg, file=sys.stderr, flush=True)
-        traceback.print_exc(file=sys.stderr)
-        sys.stderr.flush()
-
-        return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500)
-
-
-@require_auth(session_manager)
-async def connector_status(request: Request):
-    """Get connector status for authenticated user"""
-    connector_type = request.path_params.get("connector_type", "google_drive")
-    user = request.state.user
-
-    # Get connections for this connector type and user
-    connections = await connector_service.connection_manager.list_connections(
-        user_id=user.user_id,
-        connector_type=connector_type
-    )
-
-    # Check if there are any active connections
-    active_connections = [conn for conn in connections if conn.is_active]
-    has_authenticated_connection = len(active_connections) > 0
-
-    return JSONResponse({
-        "connector_type": connector_type,
-        "authenticated": has_authenticated_connection, # For frontend compatibility
-        "status": "connected" if has_authenticated_connection else "not_connected",
-        "connections": [
-            {
-                "connection_id": conn.connection_id,
-                "name": conn.name,
-                "is_active": conn.is_active,
-                "created_at": conn.created_at.isoformat(),
-                "last_sync": conn.last_sync.isoformat() if conn.last_sync else None
-            }
-            for conn in connections
-        ]
-    })
-
-
-app = Starlette(debug=True, routes=[
-    Route("/upload", upload, methods=["POST"]),
-    Route("/upload_context", upload_context, methods=["POST"]),
-    Route("/upload_path", upload_path, methods=["POST"]),
-    Route("/tasks/{task_id}", task_status, methods=["GET"]),
-    Route("/search", search, methods=["POST"]),
-    Route("/chat", chat_endpoint, methods=["POST"]),
-    Route("/langflow", langflow_endpoint, methods=["POST"]),
-    # Authentication endpoints
-    Route("/auth/init", auth_init, methods=["POST"]),
-    Route("/auth/callback", auth_callback, methods=["POST"]),
-    Route("/auth/me", auth_me, methods=["GET"]),
-    Route("/auth/logout", auth_logout, methods=["POST"]),
-    Route("/connectors/sync", connector_sync, methods=["POST"]),
-    Route("/connectors/status/{connector_type}", connector_status, methods=["GET"]),
-])
-
-if __name__ == "__main__":
-    import uvicorn
-    import atexit
-
-    async def main():
-        await init_index()
-        await connector_service.initialize()
-
-    # Cleanup process pool on exit
-    def cleanup():
-        process_pool.shutdown(wait=True)
-
-    atexit.register(cleanup)
-
-    asyncio.run(main())
-    uvicorn.run(
-        "app:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=True,
-    )