Move OAuth client ID / client secret env vars to connector
This commit is contained in:
parent
40cac9950c
commit
02e39286fc
4 changed files with 90 additions and 25 deletions
|
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Dict, List, Any, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -49,10 +50,36 @@ class ConnectorDocument:
|
|||
class BaseConnector(ABC):
|
||||
"""Base class for all document connectors"""
|
||||
|
||||
# Names of the environment variables that hold this connector's OAuth
# credentials. The base class leaves both unset (None); each concrete
# connector must override them, otherwise get_client_id() and
# get_client_secret() raise NotImplementedError.
CLIENT_ID_ENV_VAR: Optional[str] = None
CLIENT_SECRET_ENV_VAR: Optional[str] = None
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self._authenticated = False
|
||||
|
||||
def get_client_id(self) -> str:
    """Return the OAuth client ID read from the configured environment variable.

    Raises:
        NotImplementedError: CLIENT_ID_ENV_VAR is unset (falsy) on the class.
        ValueError: the named environment variable is missing or empty.
    """
    env_name = self.CLIENT_ID_ENV_VAR
    if not env_name:
        # The subclass never declared which variable holds its client ID.
        raise NotImplementedError(f"{self.__class__.__name__} must define CLIENT_ID_ENV_VAR")

    value = os.environ.get(env_name)
    if value:
        return value

    # An empty string is treated the same as a missing variable.
    raise ValueError(f"Environment variable {self.CLIENT_ID_ENV_VAR} is not set")
|
||||
|
||||
def get_client_secret(self) -> str:
    """Return the OAuth client secret read from the configured environment variable.

    Raises:
        NotImplementedError: CLIENT_SECRET_ENV_VAR is unset (falsy) on the class.
        ValueError: the named environment variable is missing or empty.
    """
    env_name = self.CLIENT_SECRET_ENV_VAR
    if not env_name:
        # The subclass never declared which variable holds its client secret.
        raise NotImplementedError(f"{self.__class__.__name__} must define CLIENT_SECRET_ENV_VAR")

    value = os.environ.get(env_name)
    if value:
        return value

    # An empty string is treated the same as a missing variable.
    raise ValueError(f"Environment variable {self.CLIENT_SECRET_ENV_VAR} is not set")
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self) -> bool:
|
||||
"""Authenticate with the service"""
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from .oauth import GoogleDriveOAuth
|
|||
# Global worker service cache for process pools
|
||||
_worker_drive_service = None
|
||||
|
||||
def get_worker_drive_service(client_id: str, token_file: str):
|
||||
def get_worker_drive_service(client_id: str, client_secret: str, token_file: str):
|
||||
"""Get or create a Google Drive service instance for this worker process"""
|
||||
global _worker_drive_service
|
||||
if _worker_drive_service is None:
|
||||
|
|
@ -23,7 +23,7 @@ def get_worker_drive_service(client_id: str, token_file: str):
|
|||
|
||||
# Create OAuth instance and load credentials in worker
|
||||
from .oauth import GoogleDriveOAuth
|
||||
oauth = GoogleDriveOAuth(client_id=client_id, token_file=token_file)
|
||||
oauth = GoogleDriveOAuth(client_id=client_id, client_secret=client_secret, token_file=token_file)
|
||||
|
||||
# Load credentials synchronously in worker
|
||||
import asyncio
|
||||
|
|
@ -40,9 +40,9 @@ def get_worker_drive_service(client_id: str, token_file: str):
|
|||
|
||||
|
||||
# Module-level functions for process pool execution (must be pickleable)
|
||||
def _sync_list_files_worker(client_id, token_file, query, page_token, page_size):
|
||||
def _sync_list_files_worker(client_id, client_secret, token_file, query, page_token, page_size):
|
||||
"""Worker function for listing files in process pool"""
|
||||
service = get_worker_drive_service(client_id, token_file)
|
||||
service = get_worker_drive_service(client_id, client_secret, token_file)
|
||||
return service.files().list(
|
||||
q=query,
|
||||
pageSize=page_size,
|
||||
|
|
@ -51,16 +51,16 @@ def _sync_list_files_worker(client_id, token_file, query, page_token, page_size)
|
|||
).execute()
|
||||
|
||||
|
||||
def _sync_get_metadata_worker(client_id, token_file, file_id):
|
||||
def _sync_get_metadata_worker(client_id, client_secret, token_file, file_id):
|
||||
"""Worker function for getting file metadata in process pool"""
|
||||
service = get_worker_drive_service(client_id, token_file)
|
||||
service = get_worker_drive_service(client_id, client_secret, token_file)
|
||||
return service.files().get(
|
||||
fileId=file_id,
|
||||
fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size"
|
||||
).execute()
|
||||
|
||||
|
||||
def _sync_download_worker(client_id, token_file, file_id, mime_type, file_size=None):
|
||||
def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_type, file_size=None):
|
||||
"""Worker function for downloading files in process pool"""
|
||||
import signal
|
||||
import time
|
||||
|
|
@ -91,7 +91,7 @@ def _sync_download_worker(client_id, token_file, file_id, mime_type, file_size=N
|
|||
signal.alarm(timeout_seconds)
|
||||
|
||||
try:
|
||||
service = get_worker_drive_service(client_id, token_file)
|
||||
service = get_worker_drive_service(client_id, client_secret, token_file)
|
||||
|
||||
# For Google native formats, export as PDF
|
||||
if mime_type.startswith('application/vnd.google-apps.'):
|
||||
|
|
@ -129,6 +129,10 @@ def _sync_download_worker(client_id, token_file, file_id, mime_type, file_size=N
|
|||
class GoogleDriveConnector(BaseConnector):
|
||||
"""Google Drive connector with OAuth and webhook support"""
|
||||
|
||||
# OAuth environment variables
|
||||
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
|
||||
|
||||
# Supported file types that can be processed by docling
|
||||
SUPPORTED_MIMETYPES = {
|
||||
'application/pdf',
|
||||
|
|
@ -148,7 +152,8 @@ class GoogleDriveConnector(BaseConnector):
|
|||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.oauth = GoogleDriveOAuth(
|
||||
client_id=config.get('client_id'),
|
||||
client_id=self.get_client_id(),
|
||||
client_secret=self.get_client_secret(),
|
||||
token_file=config.get('token_file', 'gdrive_token.json')
|
||||
)
|
||||
self.service = None
|
||||
|
|
@ -230,6 +235,7 @@ class GoogleDriveConnector(BaseConnector):
|
|||
process_pool,
|
||||
_sync_list_files_worker,
|
||||
self.oauth.client_id,
|
||||
self.oauth.client_secret,
|
||||
self.oauth.token_file,
|
||||
query,
|
||||
page_token, # page_token should come before page_size
|
||||
|
|
@ -274,6 +280,7 @@ class GoogleDriveConnector(BaseConnector):
|
|||
process_pool,
|
||||
_sync_get_metadata_worker,
|
||||
self.oauth.client_id,
|
||||
self.oauth.client_secret,
|
||||
self.oauth.token_file,
|
||||
file_id
|
||||
)
|
||||
|
|
@ -319,6 +326,7 @@ class GoogleDriveConnector(BaseConnector):
|
|||
process_pool,
|
||||
_sync_download_worker,
|
||||
self.oauth.client_id,
|
||||
self.oauth.client_secret,
|
||||
self.oauth.token_file,
|
||||
file_id,
|
||||
mime_type,
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ class GoogleDriveOAuth:
|
|||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
def __init__(self, client_id: str = None, token_file: str = "token.json"):
|
||||
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self.creds: Optional[Credentials] = None
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ class GoogleDriveOAuth:
|
|||
id_token=token_data.get('id_token'),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=self.client_id,
|
||||
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"), # Need for refresh
|
||||
client_secret=self.client_secret, # Need for refresh
|
||||
scopes=token_data.get('scopes', self.SCOPES)
|
||||
)
|
||||
|
||||
|
|
@ -54,15 +55,37 @@ class GoogleDriveOAuth:
|
|||
return self.creds
|
||||
|
||||
async def save_credentials(self):
|
||||
"""Save credentials to token file"""
|
||||
"""Save credentials to token file (without client_secret)"""
|
||||
if self.creds:
|
||||
# Create minimal token data without client_secret
|
||||
token_data = {
|
||||
"token": self.creds.token,
|
||||
"refresh_token": self.creds.refresh_token,
|
||||
"id_token": self.creds.id_token,
|
||||
"scopes": self.creds.scopes,
|
||||
}
|
||||
|
||||
# Add expiry if available
|
||||
if self.creds.expiry:
|
||||
token_data["expiry"] = self.creds.expiry.isoformat()
|
||||
|
||||
async with aiofiles.open(self.token_file, 'w') as f:
|
||||
await f.write(self.creds.to_json())
|
||||
await f.write(json.dumps(token_data, indent=2))
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Create authorization URL for OAuth flow"""
|
||||
flow = Flow.from_client_secrets_file(
|
||||
self.credentials_file,
|
||||
# Create flow from client credentials directly
|
||||
client_config = {
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
}
|
||||
|
||||
flow = Flow.from_client_config(
|
||||
client_config,
|
||||
scopes=self.SCOPES,
|
||||
redirect_uri=redirect_uri
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import aiofiles
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from config.settings import GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET, WEBHOOK_BASE_URL
|
||||
from config.settings import WEBHOOK_BASE_URL
|
||||
from session_manager import SessionManager
|
||||
|
||||
class AuthService:
|
||||
|
|
@ -24,13 +24,11 @@ class AuthService:
|
|||
if not redirect_uri:
|
||||
raise ValueError("redirect_uri is required")
|
||||
|
||||
if not GOOGLE_OAUTH_CLIENT_ID:
|
||||
raise ValueError("Google OAuth client ID not configured")
|
||||
# We'll validate client credentials when creating the connector
|
||||
|
||||
# Create connection configuration
|
||||
token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
|
||||
config = {
|
||||
"client_id": GOOGLE_OAUTH_CLIENT_ID,
|
||||
"token_file": token_file,
|
||||
"provider": provider,
|
||||
"purpose": purpose,
|
||||
|
|
@ -41,8 +39,8 @@ class AuthService:
|
|||
if WEBHOOK_BASE_URL:
|
||||
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{provider}_drive/webhook"
|
||||
|
||||
# Create connection in manager
|
||||
connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
|
||||
# Create connection in manager (always use _drive connector type as it handles OAuth)
|
||||
connector_type = f"{provider}_drive"
|
||||
connection_id = await self.connector_service.connection_manager.create_connection(
|
||||
connector_type=connector_type,
|
||||
name=connection_name,
|
||||
|
|
@ -57,8 +55,14 @@ class AuthService:
|
|||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
# Get client_id from environment variable (same as connector would do)
|
||||
import os
|
||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
if not client_id:
|
||||
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable not set")
|
||||
|
||||
oauth_config = {
|
||||
"client_id": GOOGLE_OAUTH_CLIENT_ID,
|
||||
"client_id": client_id,
|
||||
"scopes": scopes,
|
||||
"redirect_uri": redirect_uri,
|
||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
|
|
@ -95,10 +99,13 @@ class AuthService:
|
|||
raise ValueError("Redirect URI not found in connection config")
|
||||
|
||||
token_url = "https://oauth2.googleapis.com/token"
|
||||
# Get connector to access client credentials
|
||||
connector = self.connector_service.connection_manager._create_connector(connection_config)
|
||||
|
||||
token_payload = {
|
||||
"code": authorization_code,
|
||||
"client_id": connection_config.config["client_id"],
|
||||
"client_secret": GOOGLE_OAUTH_CLIENT_SECRET,
|
||||
"client_id": connector.get_client_id(),
|
||||
"client_secret": connector.get_client_secret(),
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
|
|
@ -111,7 +118,7 @@ class AuthService:
|
|||
|
||||
token_data = token_response.json()
|
||||
|
||||
# Store tokens in the token file
|
||||
# Store tokens in the token file (without client_secret)
|
||||
token_file_data = {
|
||||
"token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue