sharepoint / onedrive
This commit is contained in:
parent
c2992d7680
commit
c111295bac
26 changed files with 931 additions and 138 deletions
|
|
@ -7,6 +7,9 @@ OPENSEARCH_PASSWORD=OSisgendb1!
|
|||
# Create these credentials at https://console.cloud.google.com/apis/credentials
|
||||
GOOGLE_OAUTH_CLIENT_ID=
|
||||
GOOGLE_OAUTH_CLIENT_SECRET=
|
||||
# Azure app registration credentials for SharePoint/OneDrive
|
||||
MICROSOFT_GRAPH_OAUTH_CLIENT_ID=
|
||||
MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET=
|
||||
# Optional: DNS name routable from Google (etc.) to handle continuous ingest (ngrok works)
|
||||
WEBHOOK_BASE_URL=
|
||||
OPENAI_API_KEY=
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client"
|
||||
|
||||
import { useState, useEffect, Suspense } from "react"
|
||||
import { useState, useEffect, useCallback, Suspense } from "react"
|
||||
import { useSearchParams } from "next/navigation"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
||||
|
|
@ -52,23 +52,43 @@ function ConnectorsPage() {
|
|||
const [syncResults, setSyncResults] = useState<{[key: string]: SyncResult | null}>({})
|
||||
const [maxFiles, setMaxFiles] = useState<number>(10)
|
||||
|
||||
// Function definitions first
|
||||
const checkConnectorStatuses = async () => {
|
||||
// Initialize connectors list
|
||||
setConnectors([
|
||||
{
|
||||
id: "google_drive",
|
||||
name: "Google Drive",
|
||||
description: "Connect your Google Drive to automatically sync documents",
|
||||
icon: <div className="w-8 h-8 bg-blue-500 rounded flex items-center justify-center text-white font-bold">G</div>,
|
||||
status: "not_connected",
|
||||
type: "google_drive"
|
||||
},
|
||||
])
|
||||
// Map the icon identifier sent by the backend to the badge shown on the connector card.
// A switch is used instead of a plain-object lookup so that inherited keys such as
// "constructor" or "toString" cannot resolve to Object.prototype members; any
// unrecognised identifier always renders the grey "?" badge.
const getConnectorIcon = (iconName: string): React.ReactElement => {
  switch (iconName) {
    case 'google-drive':
      return <div className="w-8 h-8 bg-blue-500 rounded flex items-center justify-center text-white font-bold">G</div>
    case 'sharepoint':
      return <div className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold">SP</div>
    case 'onedrive':
      return <div className="w-8 h-8 bg-blue-400 rounded flex items-center justify-center text-white font-bold">OD</div>
    default:
      return <div className="w-8 h-8 bg-gray-500 rounded flex items-center justify-center text-white font-bold">?</div>
  }
}
|
||||
|
||||
// Function definitions first
|
||||
const checkConnectorStatuses = useCallback(async () => {
|
||||
try {
|
||||
// Fetch available connectors from backend
|
||||
const connectorsResponse = await fetch('/api/connectors')
|
||||
if (!connectorsResponse.ok) {
|
||||
throw new Error('Failed to load connectors')
|
||||
}
|
||||
|
||||
const connectorsResult = await connectorsResponse.json()
|
||||
const connectorTypes = Object.keys(connectorsResult.connectors)
|
||||
|
||||
// Initialize connectors list with metadata from backend
|
||||
const initialConnectors = connectorTypes
|
||||
.filter(type => connectorsResult.connectors[type].available) // Only show available connectors
|
||||
.map(type => ({
|
||||
id: type,
|
||||
name: connectorsResult.connectors[type].name,
|
||||
description: connectorsResult.connectors[type].description,
|
||||
icon: getConnectorIcon(connectorsResult.connectors[type].icon),
|
||||
status: "not_connected" as const,
|
||||
type: type
|
||||
}))
|
||||
|
||||
setConnectors(initialConnectors)
|
||||
|
||||
// Check status for each connector type
|
||||
const connectorTypes = ["google_drive"]
|
||||
|
||||
for (const connectorType of connectorTypes) {
|
||||
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
||||
|
|
@ -92,7 +112,7 @@ function ConnectorsPage() {
|
|||
} catch (error) {
|
||||
console.error('Failed to check connector statuses:', error)
|
||||
}
|
||||
}
|
||||
}, [setConnectors])
|
||||
|
||||
const handleConnect = async (connector: Connector) => {
|
||||
setIsConnecting(connector.id)
|
||||
|
|
@ -110,8 +130,8 @@ function ConnectorsPage() {
|
|||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: connector.type.replace('_drive', ''), // "google_drive" -> "google"
|
||||
purpose: "data_source",
|
||||
connector_type: connector.type,
|
||||
purpose: "data_source",
|
||||
name: `${connector.name} Connection`,
|
||||
redirect_uri: redirectUri
|
||||
}),
|
||||
|
|
@ -262,7 +282,7 @@ function ConnectorsPage() {
|
|||
url.searchParams.delete('oauth_success')
|
||||
window.history.replaceState({}, '', url.toString())
|
||||
}
|
||||
}, [searchParams, isAuthenticated])
|
||||
}, [searchParams, isAuthenticated, checkConnectorStatuses])
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client"
|
||||
|
||||
import { useState, useEffect, Suspense } from "react"
|
||||
import { useState, useEffect, useCallback, Suspense } from "react"
|
||||
import { useSearchParams } from "next/navigation"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
||||
|
|
@ -150,27 +150,59 @@ function KnowledgeSourcesPage() {
|
|||
}
|
||||
}
|
||||
|
||||
// Connector functions
|
||||
const checkConnectorStatuses = async () => {
|
||||
setConnectors([
|
||||
{
|
||||
id: "google_drive",
|
||||
name: "Google Drive",
|
||||
description: "Connect your Google Drive to automatically sync documents",
|
||||
icon: (
|
||||
<div
|
||||
className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold leading-none shrink-0"
|
||||
>
|
||||
G
|
||||
</div>
|
||||
),
|
||||
status: "not_connected",
|
||||
type: "google_drive"
|
||||
},
|
||||
])
|
||||
// Map the icon identifier sent by the backend to the badge shown on the connector card.
// A switch is used instead of a plain-object lookup so that inherited keys such as
// "constructor" or "toString" cannot resolve to Object.prototype members; any
// unrecognised identifier always renders the grey "?" badge.
const getConnectorIcon = (iconName: string): React.ReactElement => {
  switch (iconName) {
    case 'google-drive':
      return (
        <div className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          G
        </div>
      )
    case 'sharepoint':
      return (
        <div className="w-8 h-8 bg-blue-700 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          SP
        </div>
      )
    case 'onedrive':
      return (
        <div className="w-8 h-8 bg-blue-400 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          OD
        </div>
      )
    default:
      return (
        <div className="w-8 h-8 bg-gray-500 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
          ?
        </div>
      )
  }
}
|
||||
|
||||
// Connector functions
|
||||
const checkConnectorStatuses = useCallback(async () => {
|
||||
try {
|
||||
const connectorTypes = ["google_drive"]
|
||||
// Fetch available connectors from backend
|
||||
const connectorsResponse = await fetch('/api/connectors')
|
||||
if (!connectorsResponse.ok) {
|
||||
throw new Error('Failed to load connectors')
|
||||
}
|
||||
|
||||
const connectorsResult = await connectorsResponse.json()
|
||||
const connectorTypes = Object.keys(connectorsResult.connectors)
|
||||
|
||||
// Initialize connectors list with metadata from backend
|
||||
const initialConnectors = connectorTypes
|
||||
.filter(type => connectorsResult.connectors[type].available) // Only show available connectors
|
||||
.map(type => ({
|
||||
id: type,
|
||||
name: connectorsResult.connectors[type].name,
|
||||
description: connectorsResult.connectors[type].description,
|
||||
icon: getConnectorIcon(connectorsResult.connectors[type].icon),
|
||||
status: "not_connected" as const,
|
||||
type: type
|
||||
}))
|
||||
|
||||
setConnectors(initialConnectors)
|
||||
|
||||
// Check status for each connector type
|
||||
|
||||
for (const connectorType of connectorTypes) {
|
||||
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
||||
|
|
@ -194,18 +226,27 @@ function KnowledgeSourcesPage() {
|
|||
} catch (error) {
|
||||
console.error('Failed to check connector statuses:', error)
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleConnect = async (connector: Connector) => {
|
||||
setIsConnecting(connector.id)
|
||||
setSyncResults(prev => ({ ...prev, [connector.id]: null }))
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/connectors/${connector.type}/connect`, {
|
||||
// Use the shared auth callback URL, same as connectors page
|
||||
const redirectUri = `${window.location.origin}/auth/callback`
|
||||
|
||||
const response = await fetch('/api/auth/init', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
connector_type: connector.type,
|
||||
purpose: "data_source",
|
||||
name: `${connector.name} Connection`,
|
||||
redirect_uri: redirectUri
|
||||
}),
|
||||
})
|
||||
|
||||
if (response.ok) {
|
||||
|
|
@ -305,7 +346,7 @@ function KnowledgeSourcesPage() {
|
|||
url.searchParams.delete('oauth_success')
|
||||
window.history.replaceState({}, '', url.toString())
|
||||
}
|
||||
}, [searchParams, isAuthenticated])
|
||||
}, [searchParams, isAuthenticated, checkConnectorStatuses])
|
||||
|
||||
// Fetch global stats using match-all wildcard
|
||||
const fetchStats = async () => {
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ export function AuthProvider({ children }: AuthProviderProps) {
|
|||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'google',
|
||||
connector_type: 'google_drive',
|
||||
purpose: 'app_auth',
|
||||
name: 'App Authentication',
|
||||
redirect_uri: redirectUri
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client"
|
||||
|
||||
import React, { createContext, useContext, useState, useEffect, ReactNode } from 'react'
|
||||
import React, { createContext, useContext, useState, ReactNode } from 'react'
|
||||
|
||||
interface KnowledgeFilter {
|
||||
id: string
|
||||
|
|
@ -61,9 +61,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
|||
const parsed = JSON.parse(filter.query_data) as ParsedQueryData
|
||||
setParsedFilterData(parsed)
|
||||
|
||||
// Store in localStorage for persistence across page reloads
|
||||
localStorage.setItem('selectedKnowledgeFilter', JSON.stringify(filter))
|
||||
|
||||
// Auto-open panel when filter is selected
|
||||
setIsPanelOpen(true)
|
||||
} catch (error) {
|
||||
|
|
@ -72,7 +69,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
|||
}
|
||||
} else {
|
||||
setParsedFilterData(null)
|
||||
localStorage.removeItem('selectedKnowledgeFilter')
|
||||
setIsPanelOpen(false)
|
||||
}
|
||||
}
|
||||
|
|
@ -93,19 +89,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
|||
setIsPanelOpen(false) // Close panel but keep filter selected
|
||||
}
|
||||
|
||||
// Load persisted filter on mount
|
||||
useEffect(() => {
|
||||
try {
|
||||
const saved = localStorage.getItem('selectedKnowledgeFilter')
|
||||
if (saved) {
|
||||
const filter = JSON.parse(saved) as KnowledgeFilter
|
||||
setSelectedFilter(filter)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading persisted filter:', error)
|
||||
localStorage.removeItem('selectedKnowledgeFilter')
|
||||
}
|
||||
}, [])
|
||||
|
||||
const value: KnowledgeFilterContextType = {
|
||||
selectedFilter,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ dependencies = [
|
|||
"google-api-python-client>=2.143.0",
|
||||
"google-auth-httplib2>=0.2.0",
|
||||
"google-auth-oauthlib>=1.2.0",
|
||||
"msal>=1.29.0",
|
||||
"httpx>=0.27.0",
|
||||
"opensearch-py[async]>=3.0.0",
|
||||
"pyjwt>=2.8.0",
|
||||
|
|
|
|||
|
|
@ -5,16 +5,16 @@ async def auth_init(request: Request, auth_service, session_manager):
|
|||
"""Initialize OAuth flow for authentication or data source connection"""
|
||||
try:
|
||||
data = await request.json()
|
||||
provider = data.get("provider")
|
||||
connector_type = data.get("connector_type")
|
||||
purpose = data.get("purpose", "data_source")
|
||||
connection_name = data.get("name", f"{provider}_{purpose}")
|
||||
connection_name = data.get("name", f"{connector_type}_{purpose}")
|
||||
redirect_uri = data.get("redirect_uri")
|
||||
|
||||
user = getattr(request.state, 'user', None)
|
||||
user_id = user.user_id if user else None
|
||||
|
||||
result = await auth_service.init_oauth(
|
||||
provider, purpose, connection_name, redirect_uri, user_id
|
||||
connector_type, purpose, connection_name, redirect_uri, user_id
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,14 @@
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.responses import JSONResponse, PlainTextResponse
|
||||
|
||||
async def list_connectors(request: Request, connector_service, session_manager):
    """Return the connector catalogue exposed by the connection manager.

    On success the response body is ``{"connectors": <catalogue>}``. Any exception
    raised while building the response is logged to stdout and reported as
    ``{"error": <message>}`` with HTTP status 500.
    """
    try:
        manager = connector_service.connection_manager
        body = {"connectors": manager.get_available_connector_types()}
        return JSONResponse(body)
    except Exception as exc:
        print(f"Error listing connectors: {exc}")
        failure = {"error": str(exc)}
        return JSONResponse(failure, status_code=500)
|
||||
|
||||
async def connector_sync(request: Request, connector_service, session_manager):
|
||||
"""Sync files from all active connections of a connector type"""
|
||||
|
|
@ -85,7 +94,29 @@ async def connector_status(request: Request, connector_service, session_manager)
|
|||
async def connector_webhook(request: Request, connector_service, session_manager):
|
||||
"""Handle webhook notifications from any connector type"""
|
||||
connector_type = request.path_params.get("connector_type")
|
||||
|
||||
|
||||
# Handle webhook validation (connector-specific)
|
||||
temp_config = {"token_file": "temp.json"}
|
||||
from connectors.connection_manager import ConnectionConfig
|
||||
temp_connection = ConnectionConfig(
|
||||
connection_id="temp",
|
||||
connector_type=connector_type,
|
||||
name="temp",
|
||||
config=temp_config
|
||||
)
|
||||
try:
|
||||
temp_connector = connector_service.connection_manager._create_connector(temp_connection)
|
||||
validation_response = temp_connector.handle_webhook_validation(
|
||||
request.method,
|
||||
dict(request.headers),
|
||||
dict(request.query_params)
|
||||
)
|
||||
if validation_response:
|
||||
return PlainTextResponse(validation_response)
|
||||
except (NotImplementedError, ValueError):
|
||||
# Connector type not found or validation not needed
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get the raw payload and headers
|
||||
payload = {}
|
||||
|
|
@ -109,8 +140,13 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
|
||||
print(f"[WEBHOOK] {connector_type} notification received")
|
||||
|
||||
# Extract channel/subscription ID from headers (Google Drive specific)
|
||||
channel_id = headers.get('x-goog-channel-id')
|
||||
# Extract channel/subscription ID using connector-specific method
|
||||
try:
|
||||
temp_connector = connector_service.connection_manager._create_connector(temp_connection)
|
||||
channel_id = temp_connector.extract_webhook_channel_id(payload, headers)
|
||||
except (NotImplementedError, ValueError):
|
||||
channel_id = None
|
||||
|
||||
if not channel_id:
|
||||
print(f"[WEBHOOK] No channel ID found in {connector_type} webhook")
|
||||
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
|
||||
|
|
@ -132,7 +168,6 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
connector = await connector_service._get_connector(active_connections[0].connection_id)
|
||||
if connector:
|
||||
print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
|
||||
resource_id = headers.get('x-goog-resource-id')
|
||||
await connector.cleanup_subscription(channel_id, resource_id)
|
||||
print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,27 +3,25 @@ from starlette.responses import JSONResponse
|
|||
|
||||
async def search(request: Request, search_service, session_manager):
|
||||
"""Search for documents"""
|
||||
payload = await request.json()
|
||||
query = payload.get("query")
|
||||
if not query:
|
||||
return JSONResponse({"error": "Query is required"}, status_code=400)
|
||||
|
||||
filters = payload.get("filters", {}) # Optional filters, defaults to empty dict
|
||||
limit = payload.get("limit", 10) # Optional limit, defaults to 10
|
||||
score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0
|
||||
|
||||
user = request.state.user
|
||||
# Extract JWT token from cookie for OpenSearch OIDC auth
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold)
|
||||
|
||||
# Return appropriate HTTP status codes
|
||||
if result.get("success"):
|
||||
try:
|
||||
payload = await request.json()
|
||||
query = payload.get("query")
|
||||
if not query:
|
||||
return JSONResponse({"error": "Query is required"}, status_code=400)
|
||||
|
||||
filters = payload.get("filters", {}) # Optional filters, defaults to empty dict
|
||||
limit = payload.get("limit", 10) # Optional limit, defaults to 10
|
||||
score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0
|
||||
|
||||
user = request.state.user
|
||||
# Extract JWT token from cookie for OpenSearch OIDC auth
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold)
|
||||
return JSONResponse(result, status_code=200)
|
||||
else:
|
||||
error_msg = result.get("error", "")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
||||
return JSONResponse(result, status_code=403)
|
||||
return JSONResponse({"error": error_msg}, status_code=403)
|
||||
else:
|
||||
return JSONResponse(result, status_code=500)
|
||||
return JSONResponse({"error": error_msg}, status_code=500)
|
||||
|
|
@ -4,22 +4,20 @@ from starlette.responses import JSONResponse
|
|||
|
||||
async def upload(request: Request, document_service, session_manager):
|
||||
"""Upload a single file"""
|
||||
form = await request.form()
|
||||
upload_file = form["file"]
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token)
|
||||
|
||||
# Return appropriate HTTP status codes
|
||||
if result.get("success"):
|
||||
try:
|
||||
form = await request.form()
|
||||
upload_file = form["file"]
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token)
|
||||
return JSONResponse(result, status_code=201) # Created
|
||||
else:
|
||||
error_msg = result.get("error", "")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
||||
return JSONResponse(result, status_code=403)
|
||||
return JSONResponse({"error": error_msg}, status_code=403)
|
||||
else:
|
||||
return JSONResponse(result, status_code=500)
|
||||
return JSONResponse({"error": error_msg}, status_code=500)
|
||||
|
||||
async def upload_path(request: Request, task_service, session_manager):
|
||||
"""Upload all files from a directory path"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from .base import BaseConnector
|
||||
from .google_drive import GoogleDriveConnector
|
||||
from .sharepoint import SharePointConnector
|
||||
from .onedrive import OneDriveConnector
|
||||
|
||||
__all__ = ["BaseConnector", "GoogleDriveConnector"]
|
||||
__all__ = ["BaseConnector", "GoogleDriveConnector", "SharePointConnector", "OneDriveConnector"]
|
||||
|
|
|
|||
|
|
@ -54,6 +54,11 @@ class BaseConnector(ABC):
|
|||
CLIENT_ID_ENV_VAR: str = None
|
||||
CLIENT_SECRET_ENV_VAR: str = None
|
||||
|
||||
# Connector metadata for UI
|
||||
CONNECTOR_NAME: str = None
|
||||
CONNECTOR_DESCRIPTION: str = None
|
||||
CONNECTOR_ICON: str = None # Icon identifier or emoji
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self._authenticated = False
|
||||
|
|
@ -105,6 +110,17 @@ class BaseConnector(ABC):
|
|||
"""Handle webhook notification. Returns list of affected file IDs."""
|
||||
pass
|
||||
|
||||
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
    """Hook for connectors whose provider performs a webhook validation handshake.

    Subclasses may override this to return the text that should be sent back to
    the provider. This base implementation ignores all arguments and returns
    ``None``, meaning no validation response is produced.
    """
    return None
|
||||
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
    """Return the channel/subscription id carried by a webhook notification.

    The base class has no provider-specific knowledge, so it always raises
    ``NotImplementedError`` naming the concrete class; each connector is
    expected to override this method.
    """
    owner = type(self).__name__
    raise NotImplementedError(f"{owner} must implement extract_webhook_channel_id")
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||
"""Clean up subscription"""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from pathlib import Path
|
|||
|
||||
from .base import BaseConnector
|
||||
from .google_drive import GoogleDriveConnector
|
||||
from .sharepoint import SharePointConnector
|
||||
from .onedrive import OneDriveConnector
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -186,10 +188,54 @@ class ConnectionManager:
|
|||
|
||||
return None
|
||||
|
||||
def get_available_connector_types(self) -> Dict[str, Dict[str, str]]:
    """Describe every connector type this manager knows about.

    Returns a dict keyed by connector type. Each entry holds the connector
    class's ``name``, ``description`` and ``icon`` metadata plus an
    ``available`` flag computed by ``_is_connector_available``.
    """
    registry = (
        ("google_drive", GoogleDriveConnector),
        ("sharepoint", SharePointConnector),
        ("onedrive", OneDriveConnector),
    )

    catalogue = {}
    for type_key, connector_cls in registry:
        catalogue[type_key] = {
            "name": connector_cls.CONNECTOR_NAME,
            "description": connector_cls.CONNECTOR_DESCRIPTION,
            "icon": connector_cls.CONNECTOR_ICON,
            "available": self._is_connector_available(type_key),
        }
    return catalogue
|
||||
|
||||
def _is_connector_available(self, connector_type: str) -> bool:
    """Report whether a connector of ``connector_type`` can be built and configured.

    A throwaway connector is created from an empty config and asked for its
    client id and client secret. A ``ValueError`` or ``NotImplementedError``
    from any of those steps means the connector is not available.
    """
    try:
        probe_config = ConnectionConfig(
            connection_id="temp",
            connector_type=connector_type,
            name="temp",
            config={},
        )
        probe = self._create_connector(probe_config)
        # The return values are irrelevant; only whether the getters raise matters.
        probe.get_client_id()
        probe.get_client_secret()
    except (ValueError, NotImplementedError):
        return False
    else:
        return True
|
||||
|
||||
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
||||
"""Factory method to create connector instances"""
|
||||
if config.connector_type == "google_drive":
|
||||
return GoogleDriveConnector(config.config)
|
||||
elif config.connector_type == "sharepoint":
|
||||
return SharePointConnector(config.config)
|
||||
elif config.connector_type == "onedrive":
|
||||
return OneDriveConnector(config.config)
|
||||
elif config.connector_type == "box":
|
||||
# Future: BoxConnector(config.config)
|
||||
raise NotImplementedError("Box connector not implemented yet")
|
||||
|
|
|
|||
|
|
@ -133,6 +133,11 @@ class GoogleDriveConnector(BaseConnector):
|
|||
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
|
||||
|
||||
# Connector metadata
|
||||
CONNECTOR_NAME = "Google Drive"
|
||||
CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents"
|
||||
CONNECTOR_ICON = "google-drive"
|
||||
|
||||
# Supported file types that can be processed by docling
|
||||
SUPPORTED_MIMETYPES = {
|
||||
'application/pdf',
|
||||
|
|
@ -363,6 +368,10 @@ class GoogleDriveConnector(BaseConnector):
|
|||
group_permissions=group_permissions
|
||||
)
|
||||
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
    """Read the Google Drive channel id from the ``x-goog-channel-id`` header.

    The payload is not consulted. Returns ``None`` when the header is absent.
    """
    channel_header = 'x-goog-channel-id'
    return headers.get(channel_header)
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||
"""Handle Google Drive webhook notification"""
|
||||
if not self._authenticated:
|
||||
|
|
|
|||
|
|
@ -13,10 +13,16 @@ class GoogleDriveOAuth:
|
|||
"""Handles Google Drive OAuth authentication flow"""
|
||||
|
||||
SCOPES = [
|
||||
'openid',
|
||||
'email',
|
||||
'profile',
|
||||
'https://www.googleapis.com/auth/drive.readonly',
|
||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
|
||||
|
||||
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
|
|
|||
4
src/connectors/onedrive/__init__.py
Normal file
4
src/connectors/onedrive/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .connector import OneDriveConnector
|
||||
from .oauth import OneDriveOAuth
|
||||
|
||||
__all__ = ["OneDriveConnector", "OneDriveOAuth"]
|
||||
193
src/connectors/onedrive/connector.py
Normal file
193
src/connectors/onedrive/connector.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||
from .oauth import OneDriveOAuth
|
||||
|
||||
|
||||
class OneDriveConnector(BaseConnector):
|
||||
"""OneDrive connector using Microsoft Graph API"""
|
||||
|
||||
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||
|
||||
# Connector metadata
|
||||
CONNECTOR_NAME = "OneDrive"
|
||||
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
|
||||
CONNECTOR_ICON = "onedrive"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
    """Build the connector from its connection config.

    Reads ``token_file`` (default ``onedrive_token.json``) for the OAuth helper
    and restores a previously stored subscription id from ``subscription_id``
    or, failing that, ``webhook_channel_id``.
    """
    super().__init__(config)

    client_id = self.get_client_id()
    client_secret = self.get_client_secret()
    token_path = config.get("token_file", "onedrive_token.json")
    self.oauth = OneDriveOAuth(
        client_id=client_id,
        client_secret=client_secret,
        token_file=token_path,
    )

    stored_subscription = config.get("subscription_id")
    if not stored_subscription:
        stored_subscription = config.get("webhook_channel_id")
    self.subscription_id = stored_subscription
    self.base_url = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
async def authenticate(self) -> bool:
    """Mark the connector authenticated when the OAuth helper reports valid credentials.

    Returns ``True`` (and sets ``self._authenticated``) on success. On failure
    returns ``False`` and leaves ``self._authenticated`` untouched.
    """
    has_credentials = await self.oauth.is_authenticated()
    if not has_credentials:
        return False

    self._authenticated = True
    return True
|
||||
|
||||
async def setup_subscription(self) -> str:
    """Create a Microsoft Graph change-notification subscription on the drive root.

    Returns:
        The subscription id returned by Graph; it is also stored on
        ``self.subscription_id``.

    Raises:
        ValueError: if the connector has not been authenticated, or if
            ``webhook_url`` is missing from the connector config.
        httpx.HTTPStatusError: if Graph answers with a non-success status.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")

    webhook_url = self.config.get("webhook_url")
    if not webhook_url:
        raise ValueError("webhook_url required in config for subscriptions")

    expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
    body = {
        # Graph only supports the "updated" change type for drive root items;
        # requesting "created"/"deleted" as well makes the request fail.
        "changeType": "updated",
        "notificationUrl": webhook_url,
        "resource": "/me/drive/root",
        "expirationDateTime": expiration,
        # NOTE(review): this value is not stored anywhere in this method, so
        # incoming notifications cannot be checked against it — confirm intent.
        "clientState": str(uuid.uuid4()),
    }

    token = self.oauth.get_access_token()
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{self.base_url}/subscriptions",
            json=body,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        data = resp.json()

    self.subscription_id = data["id"]
    return self.subscription_id
|
||||
|
||||
async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]:
    """List the files that sit directly under the drive root.

    Args:
        page_token: ``$skiptoken`` value from a previous call, if any.
        limit: page size sent to Graph as ``$top``.

    Returns:
        ``{"files": [...], "nextPageToken": <token or None>}``. Only items that
        carry a ``file`` facet are included in ``files``.

    Raises:
        ValueError: if the connector has not been authenticated.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")

    query = {"$top": str(limit)}
    if page_token:
        query["$skiptoken"] = page_token

    auth_header = {"Authorization": f"Bearer {self.oauth.get_access_token()}"}
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{self.base_url}/me/drive/root/children",
            params=query,
            headers=auth_header,
        )
        resp.raise_for_status()
        listing = resp.json()

    files = [
        {
            "id": entry["id"],
            "name": entry["name"],
            "mimeType": entry.get("file", {}).get("mimeType", "application/octet-stream"),
            "webViewLink": entry.get("webUrl"),
            "createdTime": entry.get("createdDateTime"),
            "modifiedTime": entry.get("lastModifiedDateTime"),
        }
        for entry in listing.get("value", [])
        if entry.get("file")
    ]

    # The continuation token is the $skiptoken query parameter of @odata.nextLink.
    next_token = None
    continuation = listing.get("@odata.nextLink")
    if continuation:
        from urllib.parse import urlparse, parse_qs

        query_string = urlparse(continuation).query
        next_token = parse_qs(query_string).get("$skiptoken", [None])[0]

    return {"files": files, "nextPageToken": next_token}
|
||||
|
||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
    """Fetch a drive item's metadata, bytes and permissions as a ConnectorDocument.

    Args:
        file_id: Graph drive item id.

    Returns:
        A ``ConnectorDocument`` built from the item's metadata, downloaded
        content and parsed ACL. Timestamps are stored as naive datetimes.

    Raises:
        ValueError: if the connector has not been authenticated.
        httpx.HTTPStatusError: if any of the three Graph requests fails.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")

    token = self.oauth.get_access_token()
    headers = {"Authorization": f"Bearer {token}"}
    # Graph answers the /content request with a 302 to a pre-authenticated
    # download URL. httpx does not follow redirects unless asked, and
    # raise_for_status() treats the 302 as an error, so redirects are enabled.
    async with httpx.AsyncClient(follow_redirects=True) as client:
        meta_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}", headers=headers)
        meta_resp.raise_for_status()
        metadata = meta_resp.json()

        content_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers)
        content_resp.raise_for_status()
        content = content_resp.content

        perm_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers)
        perm_resp.raise_for_status()
        permissions = perm_resp.json()

    acl = self._parse_permissions(metadata, permissions)
    # Rewrite a trailing "Z" as an explicit offset so fromisoformat accepts it,
    # then drop the tzinfo to keep the stored datetimes naive.
    modified = datetime.fromisoformat(metadata["lastModifiedDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)
    created = datetime.fromisoformat(metadata["createdDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)

    document = ConnectorDocument(
        id=metadata["id"],
        filename=metadata["name"],
        mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"),
        content=content,
        source_url=metadata.get("webUrl"),
        acl=acl,
        modified_time=modified,
        created_time=created,
        metadata={"size": metadata.get("size")},
    )
    return document
|
||||
|
||||
def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL:
|
||||
acl = DocumentACL()
|
||||
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
||||
if owner:
|
||||
acl.owner = owner
|
||||
for perm in permissions.get("value", []):
|
||||
role = perm.get("roles", ["read"])[0]
|
||||
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
||||
if not grantee:
|
||||
continue
|
||||
user = grantee.get("user")
|
||||
if user and user.get("email"):
|
||||
acl.user_permissions[user["email"]] = role
|
||||
group = grantee.get("group")
|
||||
if group and group.get("email"):
|
||||
acl.group_permissions[group["email"]] = role
|
||||
return acl
|
||||
|
||||
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
    """Return the Microsoft Graph validation token carried by a request.

    Args:
        request_method: HTTP method of the incoming webhook request.
        headers: Request headers (not used by this check).
        query_params: Query-string parameters of the request.

    Returns:
        The value of the ``validationtoken`` / ``validationToken`` query
        parameter for GET or POST requests, or ``None`` when the request is
        not a validation handshake.
    """
    # Graph performs the endpoint validation handshake with a POST that
    # carries the token in the query string; GET stays accepted so existing
    # callers keep working.
    if (request_method or "").upper() not in ("GET", "POST"):
        return None

    validation_token = query_params.get("validationtoken") or query_params.get("validationToken")
    return validation_token or None
|
||||
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
    """Return the subscription id of the first notification in a webhook payload.

    Only the first entry of the payload's ``value`` list is inspected;
    ``headers`` is not used. Returns ``None`` when the list is missing or
    empty, or when the first entry has no ``subscriptionId``.
    """
    notifications = payload.get('value', [])
    if not notifications:
        return None
    first_notification = notifications[0]
    return first_notification.get('subscriptionId')
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Collect the item ids referenced by a webhook notification payload.

    Each entry of the payload's ``value`` list contributes its
    ``resourceData.id`` when that id is present and truthy; entries without
    one are skipped. Order follows the payload.
    """
    candidate_ids = (
        notification.get("resourceData", {}).get("id")
        for notification in payload.get("value", [])
    )
    return [item_id for item_id in candidate_ids if item_id]
|
||||
|
||||
async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
    """Delete a Graph subscription.

    Args:
        subscription_id: Id of the subscription to delete.
        resource_id: Accepted for interface compatibility; not used here.

    Returns:
        ``True`` when the DELETE request answers 200 or 204, ``False`` for
        any other status or when the connector is not authenticated.
    """
    if not self._authenticated:
        return False

    auth_headers = {"Authorization": f"Bearer {self.oauth.get_access_token()}"}
    async with httpx.AsyncClient() as http:
        endpoint = f"{self.base_url}/subscriptions/{subscription_id}"
        response = await http.delete(endpoint, headers=auth_headers)

    deleted = response.status_code in (200, 204)
    return deleted
|
||||
83
src/connectors/onedrive/oauth.py
Normal file
83
src/connectors/onedrive/oauth.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import os
|
||||
import aiofiles
|
||||
from typing import Optional
|
||||
import msal
|
||||
|
||||
|
||||
class OneDriveOAuth:
    """Handles Microsoft Graph OAuth authentication flow.

    Wraps an ``msal.ConfidentialClientApplication`` whose token cache is
    loaded from, and persisted to, ``token_file``.
    """

    SCOPES = [
        "offline_access",
        "Files.Read.All",
    ]

    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"

    # Scope names MSAL adds on its own; it raises ValueError when a caller
    # passes any of them explicitly, so they are filtered before MSAL calls.
    _MSAL_RESERVED_SCOPES = frozenset({"openid", "profile", "offline_access"})

    def __init__(self, client_id: str, client_secret: str, token_file: str = "onedrive_token.json", authority: str = "https://login.microsoftonline.com/common"):
        """Create the MSAL application and load any cached tokens.

        Args:
            client_id: Application (client) id of the app registration.
            client_secret: Client secret of the app registration.
            token_file: Path of the serialized MSAL token cache.
            authority: Authority URL handed to MSAL.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_file = token_file
        self.authority = authority
        self.token_cache = msal.SerializableTokenCache()

        # Load existing cache if available
        if os.path.exists(self.token_file):
            with open(self.token_file, "r") as f:
                self.token_cache.deserialize(f.read())

        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )

    def _msal_scopes(self) -> list:
        """Return ``SCOPES`` without the names MSAL reserves for itself."""
        return [scope for scope in self.SCOPES if scope not in self._MSAL_RESERVED_SCOPES]

    async def save_cache(self):
        """Persist the token cache to file"""
        async with aiofiles.open(self.token_file, "w") as f:
            await f.write(self.token_cache.serialize())

    def create_authorization_url(self, redirect_uri: str) -> str:
        """Create authorization URL for OAuth flow"""
        return self.app.get_authorization_request_url(self._msal_scopes(), redirect_uri=redirect_uri)

    async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool:
        """Exchange an authorization code for tokens and persist the cache.

        Returns:
            ``True`` when an access token was obtained.

        Raises:
            ValueError: If the token response carries no access token; the
                message is the response's ``error_description`` when present.
        """
        result = self.app.acquire_token_by_authorization_code(
            authorization_code,
            scopes=self._msal_scopes(),
            redirect_uri=redirect_uri,
        )
        if "access_token" in result:
            await self.save_cache()
            return True
        raise ValueError(result.get("error_description") or "Authorization failed")

    async def is_authenticated(self) -> bool:
        """Check if we have valid credentials.

        Returns ``True`` (and persists the cache) when a token can be
        acquired silently for the first cached account, ``False`` otherwise.
        """
        accounts = self.app.get_accounts()
        if not accounts:
            return False
        result = self.app.acquire_token_silent(self._msal_scopes(), account=accounts[0])
        # acquire_token_silent yields None when nothing usable is cached, so
        # guard before the membership test.
        if result and "access_token" in result:
            await self.save_cache()
            return True
        return False

    def get_access_token(self) -> str:
        """Get an access token for Microsoft Graph.

        Raises:
            ValueError: If no account is cached or no token can be acquired
                silently for the first cached account.
        """
        accounts = self.app.get_accounts()
        if not accounts:
            raise ValueError("Not authenticated")
        result = self.app.acquire_token_silent(self._msal_scopes(), account=accounts[0])
        if not result or "access_token" not in result:
            description = result.get("error_description") if result else None
            raise ValueError(description or "Failed to acquire access token")
        return result["access_token"]

    async def revoke_credentials(self):
        """Clear token cache and remove token file"""
        # NOTE(review): confirm the pinned msal version exposes clear() on
        # SerializableTokenCache; otherwise this call raises AttributeError.
        self.token_cache.clear()
        if os.path.exists(self.token_file):
            os.remove(self.token_file)
|
||||
|
|
@ -5,6 +5,8 @@ from typing import Dict, Any, List, Optional
|
|||
|
||||
from .base import BaseConnector, ConnectorDocument
|
||||
from .google_drive import GoogleDriveConnector
|
||||
from .sharepoint import SharePointConnector
|
||||
from .onedrive import OneDriveConnector
|
||||
from .connection_manager import ConnectionManager
|
||||
|
||||
|
||||
|
|
@ -28,7 +30,7 @@ class ConnectorService:
|
|||
"""Get a connector by connection ID"""
|
||||
return await self.connection_manager.get_connector(connection_id)
|
||||
|
||||
async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None) -> Dict[str, Any]:
|
||||
async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Process a document from a connector using existing processing pipeline"""
|
||||
|
||||
# Create temporary file from document content
|
||||
|
|
@ -54,7 +56,7 @@ class ConnectorService:
|
|||
# If successfully indexed, update the indexed documents with connector metadata
|
||||
if result["status"] == "indexed":
|
||||
# Update all chunks with connector-specific metadata
|
||||
await self._update_connector_metadata(document, owner_user_id, jwt_token)
|
||||
await self._update_connector_metadata(document, owner_user_id, connector_type, jwt_token)
|
||||
|
||||
return {
|
||||
**result,
|
||||
|
|
@ -66,7 +68,7 @@ class ConnectorService:
|
|||
# Clean up temporary file
|
||||
os.unlink(tmp_file.name)
|
||||
|
||||
async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None):
|
||||
async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None):
|
||||
"""Update indexed chunks with connector-specific metadata"""
|
||||
# Find all chunks for this document
|
||||
query = {
|
||||
|
|
@ -86,7 +88,7 @@ class ConnectorService:
|
|||
update_body = {
|
||||
"doc": {
|
||||
"source_url": document.source_url,
|
||||
"connector_type": "google_drive", # Could be passed as parameter
|
||||
"connector_type": connector_type,
|
||||
# Additional ACL info beyond owner (already set by process_file_common)
|
||||
"allowed_users": document.acl.allowed_users,
|
||||
"allowed_groups": document.acl.allowed_groups,
|
||||
|
|
|
|||
4
src/connectors/sharepoint/__init__.py
Normal file
4
src/connectors/sharepoint/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .connector import SharePointConnector
|
||||
from .oauth import SharePointOAuth
|
||||
|
||||
__all__ = ["SharePointConnector", "SharePointOAuth"]
|
||||
196
src/connectors/sharepoint/connector.py
Normal file
196
src/connectors/sharepoint/connector.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||
from .oauth import SharePointOAuth
|
||||
|
||||
|
||||
class SharePointConnector(BaseConnector):
    """SharePoint Sites connector using Microsoft Graph API.

    Files are read from the ``/drive`` resource of the site whose id is given
    as ``site_id`` in the connector config.
    """

    CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
    CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"

    # Connector metadata
    CONNECTOR_NAME = "SharePoint"
    CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
    CONNECTOR_ICON = "sharepoint"

    def __init__(self, config: Dict[str, Any]):
        """Create the connector from its config.

        Config keys read here: ``token_file``, ``subscription_id`` /
        ``webhook_channel_id`` and ``site_id``. A missing ``site_id`` is
        tolerated at construction time and only rejected by the methods that
        need it.
        """
        super().__init__(config)
        self.oauth = SharePointOAuth(
            client_id=self.get_client_id(),
            client_secret=self.get_client_secret(),
            token_file=config.get("token_file", "sharepoint_token.json"),
        )
        self.subscription_id = config.get("subscription_id") or config.get("webhook_channel_id")
        self.base_url = "https://graph.microsoft.com/v1.0"

        # SharePoint site configuration
        self.site_id = config.get("site_id")  # Required for SharePoint

    def _require_site_id(self) -> str:
        """Return the configured site id, or raise ``ValueError`` if it is missing."""
        if not self.site_id:
            # Fail fast instead of sending requests to "/sites/None/...".
            raise ValueError("site_id required in config for SharePoint")
        return self.site_id

    async def authenticate(self) -> bool:
        """Mark the connector authenticated when the OAuth helper has valid credentials."""
        if await self.oauth.is_authenticated():
            self._authenticated = True
            return True
        return False

    async def setup_subscription(self) -> str:
        """Create a Graph change-notification subscription for the site drive.

        Returns:
            The id of the created subscription (also stored on the instance).

        Raises:
            ValueError: If not authenticated, or ``webhook_url`` / ``site_id``
                is missing from the config.
            httpx.HTTPStatusError: If Graph rejects the request.
        """
        if not self._authenticated:
            raise ValueError("Not authenticated")

        webhook_url = self.config.get("webhook_url")
        if not webhook_url:
            raise ValueError("webhook_url required in config for subscriptions")

        site_id = self._require_site_id()

        expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
        body = {
            "changeType": "created,updated,deleted",
            "notificationUrl": webhook_url,
            "resource": f"/sites/{site_id}/drive/root",
            "expirationDateTime": expiration,
            "clientState": str(uuid.uuid4()),
        }

        token = self.oauth.get_access_token()
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{self.base_url}/subscriptions",
                json=body,
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            data = resp.json()

        self.subscription_id = data["id"]
        return self.subscription_id

    async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]:
        """List files among the children of the site drive's root folder.

        Args:
            page_token: ``$skiptoken`` value from a previous call, if any.
            limit: Page size sent as ``$top``.

        Returns:
            ``{"files": [...], "nextPageToken": token_or_None}``; only items
            that carry a ``file`` facet are included.

        Raises:
            ValueError: If not authenticated or ``site_id`` is missing.
            httpx.HTTPStatusError: If Graph rejects the request.
        """
        if not self._authenticated:
            raise ValueError("Not authenticated")

        site_id = self._require_site_id()

        params = {"$top": str(limit)}
        if page_token:
            params["$skiptoken"] = page_token

        token = self.oauth.get_access_token()
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{self.base_url}/sites/{site_id}/drive/root/children",
                params=params,
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            data = resp.json()

        files = []
        for item in data.get("value", []):
            if item.get("file"):
                files.append({
                    "id": item["id"],
                    "name": item["name"],
                    "mimeType": item.get("file", {}).get("mimeType", "application/octet-stream"),
                    "webViewLink": item.get("webUrl"),
                    "createdTime": item.get("createdDateTime"),
                    "modifiedTime": item.get("lastModifiedDateTime"),
                })

        # The next page is advertised as a full URL; keep only its $skiptoken.
        next_token = None
        next_link = data.get("@odata.nextLink")
        if next_link:
            from urllib.parse import urlparse, parse_qs

            parsed = urlparse(next_link)
            next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]

        return {"files": files, "nextPageToken": next_token}

    async def get_file_content(self, file_id: str) -> ConnectorDocument:
        """Fetch a site drive item and wrap it in a ``ConnectorDocument``.

        Issues three Graph requests for the item: metadata, raw content and
        permissions. Timestamps are stored as naive datetimes.

        Raises:
            ValueError: If not authenticated or ``site_id`` is missing.
            httpx.HTTPStatusError: If any Graph request fails.
        """
        if not self._authenticated:
            raise ValueError("Not authenticated")

        site_id = self._require_site_id()

        token = self.oauth.get_access_token()
        headers = {"Authorization": f"Bearer {token}"}
        item_url = f"{self.base_url}/sites/{site_id}/drive/items/{file_id}"

        async with httpx.AsyncClient() as client:
            meta_resp = await client.get(item_url, headers=headers)
            meta_resp.raise_for_status()
            metadata = meta_resp.json()

            # Graph answers /content with a 302 to a short-lived download
            # URL. httpx does not follow redirects unless asked, and
            # raise_for_status() treats a 3xx as an error.
            content_resp = await client.get(
                f"{item_url}/content",
                headers=headers,
                follow_redirects=True,
            )
            content_resp.raise_for_status()
            content = content_resp.content

            perm_resp = await client.get(f"{item_url}/permissions", headers=headers)
            perm_resp.raise_for_status()
            permissions = perm_resp.json()

        acl = self._parse_permissions(metadata, permissions)

        # Replace the trailing "Z" so fromisoformat accepts the value, then
        # drop tzinfo to keep naive datetimes.
        modified = datetime.fromisoformat(
            metadata["lastModifiedDateTime"].replace("Z", "+00:00")
        ).replace(tzinfo=None)
        created = datetime.fromisoformat(
            metadata["createdDateTime"].replace("Z", "+00:00")
        ).replace(tzinfo=None)

        document = ConnectorDocument(
            id=metadata["id"],
            filename=metadata["name"],
            mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"),
            content=content,
            source_url=metadata.get("webUrl"),
            acl=acl,
            modified_time=modified,
            created_time=created,
            metadata={"size": metadata.get("size")},
        )
        return document

    def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL:
        """Build a ``DocumentACL`` from item metadata and its permissions list.

        The owner comes from ``createdBy.user.email``; each permission maps
        the grantee's e-mail (user and/or group) to its first role. Grantees
        without an ``email`` field are ignored.
        """
        acl = DocumentACL()

        owner = metadata.get("createdBy", {}).get("user", {}).get("email")
        if owner:
            acl.owner = owner

        for perm in permissions.get("value", []):
            # A missing or empty roles list falls back to "read" rather than
            # raising IndexError.
            roles = perm.get("roles") or ["read"]
            role = roles[0]

            grantee = perm.get("grantedToV2") or perm.get("grantedTo")
            if not grantee:
                continue

            user = grantee.get("user")
            if user and user.get("email"):
                acl.user_permissions[user["email"]] = role

            group = grantee.get("group")
            if group and group.get("email"):
                acl.group_permissions[group["email"]] = role

        return acl

    def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
        """Handle Microsoft Graph webhook validation.

        Returns the ``validationtoken`` / ``validationToken`` query parameter
        for GET or POST requests, or ``None`` when there is none.
        """
        # Graph performs the validation handshake with a POST; GET remains
        # accepted so existing callers keep working.
        if (request_method or "").upper() not in ("GET", "POST"):
            return None
        validation_token = query_params.get("validationtoken") or query_params.get("validationToken")
        return validation_token or None

    def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
        """Extract SharePoint subscription ID from webhook payload"""
        values = payload.get('value', [])
        return values[0].get('subscriptionId') if values else None

    async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
        """Return the ``resourceData.id`` of every notification that has one."""
        values = payload.get("value", [])
        file_ids = []
        for item in values:
            resource_data = item.get("resourceData", {})
            file_id = resource_data.get("id")
            if file_id:
                file_ids.append(file_id)
        return file_ids

    async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
        """Delete a Graph subscription.

        Returns ``True`` on a 200/204 response, ``False`` otherwise or when
        the connector is not authenticated. ``resource_id`` is not used.
        """
        if not self._authenticated:
            return False
        token = self.oauth.get_access_token()
        async with httpx.AsyncClient() as client:
            resp = await client.delete(
                f"{self.base_url}/subscriptions/{subscription_id}",
                headers={"Authorization": f"Bearer {token}"},
            )
            return resp.status_code in (200, 204)
|
||||
84
src/connectors/sharepoint/oauth.py
Normal file
84
src/connectors/sharepoint/oauth.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
import aiofiles
|
||||
from typing import Optional
|
||||
import msal
|
||||
|
||||
|
||||
class SharePointOAuth:
    """Handles Microsoft Graph OAuth authentication flow.

    Wraps an ``msal.ConfidentialClientApplication`` whose token cache is
    loaded from, and persisted to, ``token_file``.
    """

    SCOPES = [
        "offline_access",
        "Files.Read.All",
        "Sites.Read.All",
    ]

    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"

    # Scope names MSAL adds on its own; it raises ValueError when a caller
    # passes any of them explicitly, so they are filtered before MSAL calls.
    _MSAL_RESERVED_SCOPES = frozenset({"openid", "profile", "offline_access"})

    def __init__(self, client_id: str, client_secret: str, token_file: str = "sharepoint_token.json", authority: str = "https://login.microsoftonline.com/common"):
        """Create the MSAL application and load any cached tokens.

        Args:
            client_id: Application (client) id of the app registration.
            client_secret: Client secret of the app registration.
            token_file: Path of the serialized MSAL token cache.
            authority: Authority URL handed to MSAL.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_file = token_file
        self.authority = authority
        self.token_cache = msal.SerializableTokenCache()

        # Load existing cache if available
        if os.path.exists(self.token_file):
            with open(self.token_file, "r") as f:
                self.token_cache.deserialize(f.read())

        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )

    def _msal_scopes(self) -> list:
        """Return ``SCOPES`` without the names MSAL reserves for itself."""
        return [scope for scope in self.SCOPES if scope not in self._MSAL_RESERVED_SCOPES]

    async def save_cache(self):
        """Persist the token cache to file"""
        async with aiofiles.open(self.token_file, "w") as f:
            await f.write(self.token_cache.serialize())

    def create_authorization_url(self, redirect_uri: str) -> str:
        """Create authorization URL for OAuth flow"""
        return self.app.get_authorization_request_url(self._msal_scopes(), redirect_uri=redirect_uri)

    async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool:
        """Exchange an authorization code for tokens and persist the cache.

        Returns:
            ``True`` when an access token was obtained.

        Raises:
            ValueError: If the token response carries no access token; the
                message is the response's ``error_description`` when present.
        """
        result = self.app.acquire_token_by_authorization_code(
            authorization_code,
            scopes=self._msal_scopes(),
            redirect_uri=redirect_uri,
        )
        if "access_token" in result:
            await self.save_cache()
            return True
        raise ValueError(result.get("error_description") or "Authorization failed")

    async def is_authenticated(self) -> bool:
        """Check if we have valid credentials.

        Returns ``True`` (and persists the cache) when a token can be
        acquired silently for the first cached account, ``False`` otherwise.
        """
        accounts = self.app.get_accounts()
        if not accounts:
            return False
        result = self.app.acquire_token_silent(self._msal_scopes(), account=accounts[0])
        # acquire_token_silent yields None when nothing usable is cached, so
        # guard before the membership test.
        if result and "access_token" in result:
            await self.save_cache()
            return True
        return False

    def get_access_token(self) -> str:
        """Get an access token for Microsoft Graph.

        Raises:
            ValueError: If no account is cached or no token can be acquired
                silently for the first cached account.
        """
        accounts = self.app.get_accounts()
        if not accounts:
            raise ValueError("Not authenticated")
        result = self.app.acquire_token_silent(self._msal_scopes(), account=accounts[0])
        if not result or "access_token" not in result:
            description = result.get("error_description") if result else None
            raise ValueError(description or "Failed to acquire access token")
        return result["access_token"]

    async def revoke_credentials(self):
        """Clear token cache and remove token file"""
        # NOTE(review): confirm the pinned msal version exposes clear() on
        # SerializableTokenCache; otherwise this call raises AttributeError.
        self.token_cache.clear()
        if os.path.exists(self.token_file):
            os.remove(self.token_file)
|
||||
|
|
@ -278,6 +278,13 @@ def create_app():
|
|||
), methods=["POST"]),
|
||||
|
||||
# Connector endpoints
|
||||
Route("/connectors",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(connectors.list_connectors,
|
||||
connector_service=services['connector_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["GET"]),
|
||||
|
||||
Route("/connectors/{connector_type}/sync",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(connectors.connector_sync,
|
||||
|
|
|
|||
|
|
@ -63,9 +63,10 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
file_id = item # item is the connector file ID
|
||||
file_info = self.file_info_map.get(file_id)
|
||||
|
||||
# Get the connector
|
||||
# Get the connector and connection info
|
||||
connector = await self.connector_service.get_connector(self.connection_id)
|
||||
if not connector:
|
||||
connection = await self.connector_service.connection_manager.get_connection(self.connection_id)
|
||||
if not connector or not connection:
|
||||
raise ValueError(f"Connection '{self.connection_id}' not found")
|
||||
|
||||
# Get file content from connector (the connector will fetch metadata if needed)
|
||||
|
|
@ -76,7 +77,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
raise ValueError("user_id not provided to ConnectorFileProcessor")
|
||||
|
||||
# Process using existing pipeline
|
||||
result = await self.connector_service.process_connector_document(document, self.user_id)
|
||||
result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type)
|
||||
|
||||
file_task.status = TaskStatus.COMPLETED
|
||||
file_task.result = result
|
||||
|
|
|
|||
|
|
@ -8,6 +8,12 @@ from typing import Optional
|
|||
|
||||
from config.settings import WEBHOOK_BASE_URL
|
||||
from session_manager import SessionManager
|
||||
from connectors.google_drive.oauth import GoogleDriveOAuth
|
||||
from connectors.onedrive.oauth import OneDriveOAuth
|
||||
from connectors.sharepoint.oauth import SharePointOAuth
|
||||
from connectors.google_drive import GoogleDriveConnector
|
||||
from connectors.onedrive import OneDriveConnector
|
||||
from connectors.sharepoint import SharePointConnector
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, session_manager: SessionManager, connector_service=None):
|
||||
|
|
@ -15,11 +21,16 @@ class AuthService:
|
|||
self.connector_service = connector_service
|
||||
self.used_auth_codes = set() # Track used authorization codes
|
||||
|
||||
async def init_oauth(self, provider: str, purpose: str, connection_name: str,
|
||||
async def init_oauth(self, connector_type: str, purpose: str, connection_name: str,
|
||||
redirect_uri: str, user_id: str = None) -> dict:
|
||||
"""Initialize OAuth flow for authentication or data source connection"""
|
||||
if provider != "google":
|
||||
raise ValueError("Unsupported provider")
|
||||
# Validate connector_type based on purpose
|
||||
if purpose == "app_auth" and connector_type != "google_drive":
|
||||
raise ValueError("Only Google login supported for app authentication")
|
||||
elif purpose == "data_source" and connector_type not in ["google_drive", "onedrive", "sharepoint"]:
|
||||
raise ValueError(f"Unsupported connector type: {connector_type}")
|
||||
elif purpose not in ["app_auth", "data_source"]:
|
||||
raise ValueError(f"Unsupported purpose: {purpose}")
|
||||
|
||||
if not redirect_uri:
|
||||
raise ValueError("redirect_uri is required")
|
||||
|
|
@ -27,20 +38,19 @@ class AuthService:
|
|||
# We'll validate client credentials when creating the connector
|
||||
|
||||
# Create connection configuration
|
||||
token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
|
||||
token_file = f"{connector_type}_{purpose}_{uuid.uuid4().hex[:8]}.json"
|
||||
config = {
|
||||
"token_file": token_file,
|
||||
"provider": provider,
|
||||
"connector_type": connector_type,
|
||||
"purpose": purpose,
|
||||
"redirect_uri": redirect_uri
|
||||
}
|
||||
|
||||
# Only add webhook URL if WEBHOOK_BASE_URL is configured
|
||||
if WEBHOOK_BASE_URL:
|
||||
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{provider}_drive/webhook"
|
||||
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{connector_type}/webhook"
|
||||
|
||||
# Create connection in manager (always use _drive connector type as it handles OAuth)
|
||||
connector_type = f"{provider}_drive"
|
||||
# Create connection in manager
|
||||
connection_id = await self.connector_service.connection_manager.create_connection(
|
||||
connector_type=connector_type,
|
||||
name=connection_name,
|
||||
|
|
@ -48,25 +58,38 @@ class AuthService:
|
|||
user_id=user_id
|
||||
)
|
||||
|
||||
# Return OAuth configuration for client-side flow
|
||||
scopes = [
|
||||
'openid', 'email', 'profile',
|
||||
'https://www.googleapis.com/auth/drive.readonly',
|
||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
# Get client_id from environment variable (same as connector would do)
|
||||
# Get OAuth configuration from connector and OAuth classes
|
||||
import os
|
||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
|
||||
# Map connector types to their connector and OAuth classes
|
||||
connector_class_map = {
|
||||
"google_drive": (GoogleDriveConnector, GoogleDriveOAuth),
|
||||
"onedrive": (OneDriveConnector, OneDriveOAuth),
|
||||
"sharepoint": (SharePointConnector, SharePointOAuth)
|
||||
}
|
||||
|
||||
connector_class, oauth_class = connector_class_map.get(connector_type, (None, None))
|
||||
if not connector_class or not oauth_class:
|
||||
raise ValueError(f"No classes found for connector type: {connector_type}")
|
||||
|
||||
# Get scopes from OAuth class
|
||||
scopes = oauth_class.SCOPES
|
||||
|
||||
# Get endpoints from OAuth class
|
||||
auth_endpoint = oauth_class.AUTH_ENDPOINT
|
||||
token_endpoint = oauth_class.TOKEN_ENDPOINT
|
||||
|
||||
# Get client_id from environment variable using connector's env var name
|
||||
client_id = os.getenv(connector_class.CLIENT_ID_ENV_VAR)
|
||||
if not client_id:
|
||||
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable not set")
|
||||
raise ValueError(f"{connector_class.CLIENT_ID_ENV_VAR} environment variable not set")
|
||||
|
||||
oauth_config = {
|
||||
"client_id": client_id,
|
||||
"scopes": scopes,
|
||||
"redirect_uri": redirect_uri,
|
||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_endpoint": "https://oauth2.googleapis.com/token"
|
||||
"authorization_endpoint": auth_endpoint,
|
||||
"token_endpoint": token_endpoint
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -98,10 +121,23 @@ class AuthService:
|
|||
if not redirect_uri:
|
||||
raise ValueError("Redirect URI not found in connection config")
|
||||
|
||||
token_url = "https://oauth2.googleapis.com/token"
|
||||
# Get connector to access client credentials
|
||||
# Get connector to access client credentials and endpoints
|
||||
connector = self.connector_service.connection_manager._create_connector(connection_config)
|
||||
|
||||
# Get token endpoint from connector type
|
||||
connector_type = connection_config.connector_type
|
||||
connector_class_map = {
|
||||
"google_drive": (GoogleDriveConnector, GoogleDriveOAuth),
|
||||
"onedrive": (OneDriveConnector, OneDriveOAuth),
|
||||
"sharepoint": (SharePointConnector, SharePointOAuth)
|
||||
}
|
||||
|
||||
connector_class, oauth_class = connector_class_map.get(connector_type, (None, None))
|
||||
if not connector_class or not oauth_class:
|
||||
raise ValueError(f"No classes found for connector type: {connector_type}")
|
||||
|
||||
token_url = oauth_class.TOKEN_ENDPOINT
|
||||
|
||||
token_payload = {
|
||||
"code": authorization_code,
|
||||
"client_id": connector.get_client_id(),
|
||||
|
|
@ -119,14 +155,18 @@ class AuthService:
|
|||
token_data = token_response.json()
|
||||
|
||||
# Store tokens in the token file (without client_secret)
|
||||
# Use actual scopes from OAuth response
|
||||
granted_scopes = token_data.get("scope")
|
||||
if not granted_scopes:
|
||||
raise ValueError(f"OAuth provider for {connector_type} did not return granted scopes in token response")
|
||||
|
||||
# OAuth providers typically return scopes as a space-separated string
|
||||
scopes = granted_scopes.split(" ") if isinstance(granted_scopes, str) else granted_scopes
|
||||
|
||||
token_file_data = {
|
||||
"token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
"scopes": [
|
||||
"openid", "email", "profile",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly"
|
||||
]
|
||||
"scopes": scopes
|
||||
}
|
||||
|
||||
# Add expiry if provided
|
||||
|
|
|
|||
21
uv.lock
generated
21
uv.lock
generated
|
|
@ -989,6 +989,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "msal"
|
||||
version = "1.33.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
{ name = "pyjwt", extra = ["crypto"] },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d5/da/81acbe0c1fd7e9e4ec35f55dadeba9833a847b9a6ba2e2d1e4432da901dd/msal-1.33.0.tar.gz", hash = "sha256:836ad80faa3e25a7d71015c990ce61f704a87328b1e73bcbb0623a18cbf17510", size = 153801, upload-time = "2025-07-22T19:36:33.693Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/86/5b/fbc73e91f7727ae1e79b21ed833308e99dc11cc1cd3d4717f579775de5e9/msal-1.33.0-py3-none-any.whl", hash = "sha256:c0cd41cecf8eaed733ee7e3be9e040291eba53b0f262d3ae9c58f38b04244273", size = 116853, upload-time = "2025-07-22T19:36:32.403Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multidict"
|
||||
version = "6.6.3"
|
||||
|
|
@ -1328,6 +1342,7 @@ dependencies = [
|
|||
{ name = "google-auth-httplib2" },
|
||||
{ name = "google-auth-oauthlib" },
|
||||
{ name = "httpx" },
|
||||
{ name = "msal" },
|
||||
{ name = "opensearch-py", extra = ["async"] },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "python-multipart" },
|
||||
|
|
@ -1346,6 +1361,7 @@ requires-dist = [
|
|||
{ name = "google-auth-httplib2", specifier = ">=0.2.0" },
|
||||
{ name = "google-auth-oauthlib", specifier = ">=1.2.0" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "msal", specifier = ">=1.29.0" },
|
||||
{ name = "opensearch-py", extras = ["async"], specifier = ">=3.0.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||
|
|
@ -1661,6 +1677,11 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
crypto = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pylatexenc"
|
||||
version = "2.10"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue