sharepont / opendrive
This commit is contained in:
parent
c2992d7680
commit
c111295bac
26 changed files with 931 additions and 138 deletions
|
|
@ -7,6 +7,9 @@ OPENSEARCH_PASSWORD=OSisgendb1!
|
||||||
# make here https://console.cloud.google.com/apis/credentials
|
# make here https://console.cloud.google.com/apis/credentials
|
||||||
GOOGLE_OAUTH_CLIENT_ID=
|
GOOGLE_OAUTH_CLIENT_ID=
|
||||||
GOOGLE_OAUTH_CLIENT_SECRET=
|
GOOGLE_OAUTH_CLIENT_SECRET=
|
||||||
|
# Azure app registration credentials for SharePoint/OneDrive
|
||||||
|
MICROSOFT_GRAPH_OAUTH_CLIENT_ID=
|
||||||
|
MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET=
|
||||||
# Optional dns routable from google (etc.) to handle continous ingest (ngrok works)
|
# Optional dns routable from google (etc.) to handle continous ingest (ngrok works)
|
||||||
WEBHOOK_BASE_URL=
|
WEBHOOK_BASE_URL=
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"use client"
|
"use client"
|
||||||
|
|
||||||
import { useState, useEffect, Suspense } from "react"
|
import { useState, useEffect, useCallback, Suspense } from "react"
|
||||||
import { useSearchParams } from "next/navigation"
|
import { useSearchParams } from "next/navigation"
|
||||||
import { Button } from "@/components/ui/button"
|
import { Button } from "@/components/ui/button"
|
||||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
||||||
|
|
@ -52,23 +52,43 @@ function ConnectorsPage() {
|
||||||
const [syncResults, setSyncResults] = useState<{[key: string]: SyncResult | null}>({})
|
const [syncResults, setSyncResults] = useState<{[key: string]: SyncResult | null}>({})
|
||||||
const [maxFiles, setMaxFiles] = useState<number>(10)
|
const [maxFiles, setMaxFiles] = useState<number>(10)
|
||||||
|
|
||||||
// Function definitions first
|
// Helper function to get connector icon
|
||||||
const checkConnectorStatuses = async () => {
|
const getConnectorIcon = (iconName: string) => {
|
||||||
// Initialize connectors list
|
const iconMap: { [key: string]: React.ReactElement } = {
|
||||||
setConnectors([
|
'google-drive': <div className="w-8 h-8 bg-blue-500 rounded flex items-center justify-center text-white font-bold">G</div>,
|
||||||
{
|
'sharepoint': <div className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold">SP</div>,
|
||||||
id: "google_drive",
|
'onedrive': <div className="w-8 h-8 bg-blue-400 rounded flex items-center justify-center text-white font-bold">OD</div>,
|
||||||
name: "Google Drive",
|
}
|
||||||
description: "Connect your Google Drive to automatically sync documents",
|
return iconMap[iconName] || <div className="w-8 h-8 bg-gray-500 rounded flex items-center justify-center text-white font-bold">?</div>
|
||||||
icon: <div className="w-8 h-8 bg-blue-500 rounded flex items-center justify-center text-white font-bold">G</div>,
|
}
|
||||||
status: "not_connected",
|
|
||||||
type: "google_drive"
|
|
||||||
},
|
|
||||||
])
|
|
||||||
|
|
||||||
|
// Function definitions first
|
||||||
|
const checkConnectorStatuses = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
|
// Fetch available connectors from backend
|
||||||
|
const connectorsResponse = await fetch('/api/connectors')
|
||||||
|
if (!connectorsResponse.ok) {
|
||||||
|
throw new Error('Failed to load connectors')
|
||||||
|
}
|
||||||
|
|
||||||
|
const connectorsResult = await connectorsResponse.json()
|
||||||
|
const connectorTypes = Object.keys(connectorsResult.connectors)
|
||||||
|
|
||||||
|
// Initialize connectors list with metadata from backend
|
||||||
|
const initialConnectors = connectorTypes
|
||||||
|
.filter(type => connectorsResult.connectors[type].available) // Only show available connectors
|
||||||
|
.map(type => ({
|
||||||
|
id: type,
|
||||||
|
name: connectorsResult.connectors[type].name,
|
||||||
|
description: connectorsResult.connectors[type].description,
|
||||||
|
icon: getConnectorIcon(connectorsResult.connectors[type].icon),
|
||||||
|
status: "not_connected" as const,
|
||||||
|
type: type
|
||||||
|
}))
|
||||||
|
|
||||||
|
setConnectors(initialConnectors)
|
||||||
|
|
||||||
// Check status for each connector type
|
// Check status for each connector type
|
||||||
const connectorTypes = ["google_drive"]
|
|
||||||
|
|
||||||
for (const connectorType of connectorTypes) {
|
for (const connectorType of connectorTypes) {
|
||||||
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
||||||
|
|
@ -92,7 +112,7 @@ function ConnectorsPage() {
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to check connector statuses:', error)
|
console.error('Failed to check connector statuses:', error)
|
||||||
}
|
}
|
||||||
}
|
}, [setConnectors])
|
||||||
|
|
||||||
const handleConnect = async (connector: Connector) => {
|
const handleConnect = async (connector: Connector) => {
|
||||||
setIsConnecting(connector.id)
|
setIsConnecting(connector.id)
|
||||||
|
|
@ -110,8 +130,8 @@ function ConnectorsPage() {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
provider: connector.type.replace('_drive', ''), // "google_drive" -> "google"
|
connector_type: connector.type,
|
||||||
purpose: "data_source",
|
purpose: "data_source",
|
||||||
name: `${connector.name} Connection`,
|
name: `${connector.name} Connection`,
|
||||||
redirect_uri: redirectUri
|
redirect_uri: redirectUri
|
||||||
}),
|
}),
|
||||||
|
|
@ -262,7 +282,7 @@ function ConnectorsPage() {
|
||||||
url.searchParams.delete('oauth_success')
|
url.searchParams.delete('oauth_success')
|
||||||
window.history.replaceState({}, '', url.toString())
|
window.history.replaceState({}, '', url.toString())
|
||||||
}
|
}
|
||||||
}, [searchParams, isAuthenticated])
|
}, [searchParams, isAuthenticated, checkConnectorStatuses])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-8">
|
<div className="space-y-8">
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"use client"
|
"use client"
|
||||||
|
|
||||||
import { useState, useEffect, Suspense } from "react"
|
import { useState, useEffect, useCallback, Suspense } from "react"
|
||||||
import { useSearchParams } from "next/navigation"
|
import { useSearchParams } from "next/navigation"
|
||||||
import { Button } from "@/components/ui/button"
|
import { Button } from "@/components/ui/button"
|
||||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
|
||||||
|
|
@ -150,27 +150,59 @@ function KnowledgeSourcesPage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connector functions
|
// Helper function to get connector icon
|
||||||
const checkConnectorStatuses = async () => {
|
const getConnectorIcon = (iconName: string) => {
|
||||||
setConnectors([
|
const iconMap: { [key: string]: React.ReactElement } = {
|
||||||
{
|
'google-drive': (
|
||||||
id: "google_drive",
|
<div className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
|
||||||
name: "Google Drive",
|
G
|
||||||
description: "Connect your Google Drive to automatically sync documents",
|
</div>
|
||||||
icon: (
|
),
|
||||||
<div
|
'sharepoint': (
|
||||||
className="w-8 h-8 bg-blue-600 rounded flex items-center justify-center text-white font-bold leading-none shrink-0"
|
<div className="w-8 h-8 bg-blue-700 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
|
||||||
>
|
SP
|
||||||
G
|
</div>
|
||||||
</div>
|
),
|
||||||
),
|
'onedrive': (
|
||||||
status: "not_connected",
|
<div className="w-8 h-8 bg-blue-400 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
|
||||||
type: "google_drive"
|
OD
|
||||||
},
|
</div>
|
||||||
])
|
),
|
||||||
|
}
|
||||||
|
return iconMap[iconName] || (
|
||||||
|
<div className="w-8 h-8 bg-gray-500 rounded flex items-center justify-center text-white font-bold leading-none shrink-0">
|
||||||
|
?
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connector functions
|
||||||
|
const checkConnectorStatuses = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
const connectorTypes = ["google_drive"]
|
// Fetch available connectors from backend
|
||||||
|
const connectorsResponse = await fetch('/api/connectors')
|
||||||
|
if (!connectorsResponse.ok) {
|
||||||
|
throw new Error('Failed to load connectors')
|
||||||
|
}
|
||||||
|
|
||||||
|
const connectorsResult = await connectorsResponse.json()
|
||||||
|
const connectorTypes = Object.keys(connectorsResult.connectors)
|
||||||
|
|
||||||
|
// Initialize connectors list with metadata from backend
|
||||||
|
const initialConnectors = connectorTypes
|
||||||
|
.filter(type => connectorsResult.connectors[type].available) // Only show available connectors
|
||||||
|
.map(type => ({
|
||||||
|
id: type,
|
||||||
|
name: connectorsResult.connectors[type].name,
|
||||||
|
description: connectorsResult.connectors[type].description,
|
||||||
|
icon: getConnectorIcon(connectorsResult.connectors[type].icon),
|
||||||
|
status: "not_connected" as const,
|
||||||
|
type: type
|
||||||
|
}))
|
||||||
|
|
||||||
|
setConnectors(initialConnectors)
|
||||||
|
|
||||||
|
// Check status for each connector type
|
||||||
|
|
||||||
for (const connectorType of connectorTypes) {
|
for (const connectorType of connectorTypes) {
|
||||||
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
const response = await fetch(`/api/connectors/${connectorType}/status`)
|
||||||
|
|
@ -194,18 +226,27 @@ function KnowledgeSourcesPage() {
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to check connector statuses:', error)
|
console.error('Failed to check connector statuses:', error)
|
||||||
}
|
}
|
||||||
}
|
}, [])
|
||||||
|
|
||||||
const handleConnect = async (connector: Connector) => {
|
const handleConnect = async (connector: Connector) => {
|
||||||
setIsConnecting(connector.id)
|
setIsConnecting(connector.id)
|
||||||
setSyncResults(prev => ({ ...prev, [connector.id]: null }))
|
setSyncResults(prev => ({ ...prev, [connector.id]: null }))
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/api/connectors/${connector.type}/connect`, {
|
// Use the shared auth callback URL, same as connectors page
|
||||||
|
const redirectUri = `${window.location.origin}/auth/callback`
|
||||||
|
|
||||||
|
const response = await fetch('/api/auth/init', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
connector_type: connector.type,
|
||||||
|
purpose: "data_source",
|
||||||
|
name: `${connector.name} Connection`,
|
||||||
|
redirect_uri: redirectUri
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
|
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
|
|
@ -305,7 +346,7 @@ function KnowledgeSourcesPage() {
|
||||||
url.searchParams.delete('oauth_success')
|
url.searchParams.delete('oauth_success')
|
||||||
window.history.replaceState({}, '', url.toString())
|
window.history.replaceState({}, '', url.toString())
|
||||||
}
|
}
|
||||||
}, [searchParams, isAuthenticated])
|
}, [searchParams, isAuthenticated, checkConnectorStatuses])
|
||||||
|
|
||||||
// Fetch global stats using match-all wildcard
|
// Fetch global stats using match-all wildcard
|
||||||
const fetchStats = async () => {
|
const fetchStats = async () => {
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ export function AuthProvider({ children }: AuthProviderProps) {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
provider: 'google',
|
connector_type: 'google_drive',
|
||||||
purpose: 'app_auth',
|
purpose: 'app_auth',
|
||||||
name: 'App Authentication',
|
name: 'App Authentication',
|
||||||
redirect_uri: redirectUri
|
redirect_uri: redirectUri
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"use client"
|
"use client"
|
||||||
|
|
||||||
import React, { createContext, useContext, useState, useEffect, ReactNode } from 'react'
|
import React, { createContext, useContext, useState, ReactNode } from 'react'
|
||||||
|
|
||||||
interface KnowledgeFilter {
|
interface KnowledgeFilter {
|
||||||
id: string
|
id: string
|
||||||
|
|
@ -61,9 +61,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
||||||
const parsed = JSON.parse(filter.query_data) as ParsedQueryData
|
const parsed = JSON.parse(filter.query_data) as ParsedQueryData
|
||||||
setParsedFilterData(parsed)
|
setParsedFilterData(parsed)
|
||||||
|
|
||||||
// Store in localStorage for persistence across page reloads
|
|
||||||
localStorage.setItem('selectedKnowledgeFilter', JSON.stringify(filter))
|
|
||||||
|
|
||||||
// Auto-open panel when filter is selected
|
// Auto-open panel when filter is selected
|
||||||
setIsPanelOpen(true)
|
setIsPanelOpen(true)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
@ -72,7 +69,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
setParsedFilterData(null)
|
setParsedFilterData(null)
|
||||||
localStorage.removeItem('selectedKnowledgeFilter')
|
|
||||||
setIsPanelOpen(false)
|
setIsPanelOpen(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -93,19 +89,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro
|
||||||
setIsPanelOpen(false) // Close panel but keep filter selected
|
setIsPanelOpen(false) // Close panel but keep filter selected
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load persisted filter on mount
|
|
||||||
useEffect(() => {
|
|
||||||
try {
|
|
||||||
const saved = localStorage.getItem('selectedKnowledgeFilter')
|
|
||||||
if (saved) {
|
|
||||||
const filter = JSON.parse(saved) as KnowledgeFilter
|
|
||||||
setSelectedFilter(filter)
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error loading persisted filter:', error)
|
|
||||||
localStorage.removeItem('selectedKnowledgeFilter')
|
|
||||||
}
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
const value: KnowledgeFilterContextType = {
|
const value: KnowledgeFilterContextType = {
|
||||||
selectedFilter,
|
selectedFilter,
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ dependencies = [
|
||||||
"google-api-python-client>=2.143.0",
|
"google-api-python-client>=2.143.0",
|
||||||
"google-auth-httplib2>=0.2.0",
|
"google-auth-httplib2>=0.2.0",
|
||||||
"google-auth-oauthlib>=1.2.0",
|
"google-auth-oauthlib>=1.2.0",
|
||||||
|
"msal>=1.29.0",
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
"opensearch-py[async]>=3.0.0",
|
"opensearch-py[async]>=3.0.0",
|
||||||
"pyjwt>=2.8.0",
|
"pyjwt>=2.8.0",
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,16 @@ async def auth_init(request: Request, auth_service, session_manager):
|
||||||
"""Initialize OAuth flow for authentication or data source connection"""
|
"""Initialize OAuth flow for authentication or data source connection"""
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
provider = data.get("provider")
|
connector_type = data.get("connector_type")
|
||||||
purpose = data.get("purpose", "data_source")
|
purpose = data.get("purpose", "data_source")
|
||||||
connection_name = data.get("name", f"{provider}_{purpose}")
|
connection_name = data.get("name", f"{connector_type}_{purpose}")
|
||||||
redirect_uri = data.get("redirect_uri")
|
redirect_uri = data.get("redirect_uri")
|
||||||
|
|
||||||
user = getattr(request.state, 'user', None)
|
user = getattr(request.state, 'user', None)
|
||||||
user_id = user.user_id if user else None
|
user_id = user.user_id if user else None
|
||||||
|
|
||||||
result = await auth_service.init_oauth(
|
result = await auth_service.init_oauth(
|
||||||
provider, purpose, connection_name, redirect_uri, user_id
|
connector_type, purpose, connection_name, redirect_uri, user_id
|
||||||
)
|
)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,14 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse, PlainTextResponse
|
||||||
|
|
||||||
|
async def list_connectors(request: Request, connector_service, session_manager):
|
||||||
|
"""List available connector types with metadata"""
|
||||||
|
try:
|
||||||
|
connector_types = connector_service.connection_manager.get_available_connector_types()
|
||||||
|
return JSONResponse({"connectors": connector_types})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error listing connectors: {e}")
|
||||||
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
async def connector_sync(request: Request, connector_service, session_manager):
|
async def connector_sync(request: Request, connector_service, session_manager):
|
||||||
"""Sync files from all active connections of a connector type"""
|
"""Sync files from all active connections of a connector type"""
|
||||||
|
|
@ -85,7 +94,29 @@ async def connector_status(request: Request, connector_service, session_manager)
|
||||||
async def connector_webhook(request: Request, connector_service, session_manager):
|
async def connector_webhook(request: Request, connector_service, session_manager):
|
||||||
"""Handle webhook notifications from any connector type"""
|
"""Handle webhook notifications from any connector type"""
|
||||||
connector_type = request.path_params.get("connector_type")
|
connector_type = request.path_params.get("connector_type")
|
||||||
|
|
||||||
|
# Handle webhook validation (connector-specific)
|
||||||
|
temp_config = {"token_file": "temp.json"}
|
||||||
|
from connectors.connection_manager import ConnectionConfig
|
||||||
|
temp_connection = ConnectionConfig(
|
||||||
|
connection_id="temp",
|
||||||
|
connector_type=connector_type,
|
||||||
|
name="temp",
|
||||||
|
config=temp_config
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
temp_connector = connector_service.connection_manager._create_connector(temp_connection)
|
||||||
|
validation_response = temp_connector.handle_webhook_validation(
|
||||||
|
request.method,
|
||||||
|
dict(request.headers),
|
||||||
|
dict(request.query_params)
|
||||||
|
)
|
||||||
|
if validation_response:
|
||||||
|
return PlainTextResponse(validation_response)
|
||||||
|
except (NotImplementedError, ValueError):
|
||||||
|
# Connector type not found or validation not needed
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the raw payload and headers
|
# Get the raw payload and headers
|
||||||
payload = {}
|
payload = {}
|
||||||
|
|
@ -109,8 +140,13 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
|
|
||||||
print(f"[WEBHOOK] {connector_type} notification received")
|
print(f"[WEBHOOK] {connector_type} notification received")
|
||||||
|
|
||||||
# Extract channel/subscription ID from headers (Google Drive specific)
|
# Extract channel/subscription ID using connector-specific method
|
||||||
channel_id = headers.get('x-goog-channel-id')
|
try:
|
||||||
|
temp_connector = connector_service.connection_manager._create_connector(temp_connection)
|
||||||
|
channel_id = temp_connector.extract_webhook_channel_id(payload, headers)
|
||||||
|
except (NotImplementedError, ValueError):
|
||||||
|
channel_id = None
|
||||||
|
|
||||||
if not channel_id:
|
if not channel_id:
|
||||||
print(f"[WEBHOOK] No channel ID found in {connector_type} webhook")
|
print(f"[WEBHOOK] No channel ID found in {connector_type} webhook")
|
||||||
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
|
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
|
||||||
|
|
@ -132,7 +168,6 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
connector = await connector_service._get_connector(active_connections[0].connection_id)
|
connector = await connector_service._get_connector(active_connections[0].connection_id)
|
||||||
if connector:
|
if connector:
|
||||||
print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
|
print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}")
|
||||||
resource_id = headers.get('x-goog-resource-id')
|
|
||||||
await connector.cleanup_subscription(channel_id, resource_id)
|
await connector.cleanup_subscription(channel_id, resource_id)
|
||||||
print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
|
print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,27 +3,25 @@ from starlette.responses import JSONResponse
|
||||||
|
|
||||||
async def search(request: Request, search_service, session_manager):
|
async def search(request: Request, search_service, session_manager):
|
||||||
"""Search for documents"""
|
"""Search for documents"""
|
||||||
payload = await request.json()
|
try:
|
||||||
query = payload.get("query")
|
payload = await request.json()
|
||||||
if not query:
|
query = payload.get("query")
|
||||||
return JSONResponse({"error": "Query is required"}, status_code=400)
|
if not query:
|
||||||
|
return JSONResponse({"error": "Query is required"}, status_code=400)
|
||||||
filters = payload.get("filters", {}) # Optional filters, defaults to empty dict
|
|
||||||
limit = payload.get("limit", 10) # Optional limit, defaults to 10
|
filters = payload.get("filters", {}) # Optional filters, defaults to empty dict
|
||||||
score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0
|
limit = payload.get("limit", 10) # Optional limit, defaults to 10
|
||||||
|
score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0
|
||||||
user = request.state.user
|
|
||||||
# Extract JWT token from cookie for OpenSearch OIDC auth
|
user = request.state.user
|
||||||
jwt_token = request.cookies.get("auth_token")
|
# Extract JWT token from cookie for OpenSearch OIDC auth
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold)
|
|
||||||
|
result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold)
|
||||||
# Return appropriate HTTP status codes
|
|
||||||
if result.get("success"):
|
|
||||||
return JSONResponse(result, status_code=200)
|
return JSONResponse(result, status_code=200)
|
||||||
else:
|
except Exception as e:
|
||||||
error_msg = result.get("error", "")
|
error_msg = str(e)
|
||||||
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
||||||
return JSONResponse(result, status_code=403)
|
return JSONResponse({"error": error_msg}, status_code=403)
|
||||||
else:
|
else:
|
||||||
return JSONResponse(result, status_code=500)
|
return JSONResponse({"error": error_msg}, status_code=500)
|
||||||
|
|
@ -4,22 +4,20 @@ from starlette.responses import JSONResponse
|
||||||
|
|
||||||
async def upload(request: Request, document_service, session_manager):
|
async def upload(request: Request, document_service, session_manager):
|
||||||
"""Upload a single file"""
|
"""Upload a single file"""
|
||||||
form = await request.form()
|
try:
|
||||||
upload_file = form["file"]
|
form = await request.form()
|
||||||
user = request.state.user
|
upload_file = form["file"]
|
||||||
jwt_token = request.cookies.get("auth_token")
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token)
|
|
||||||
|
result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token)
|
||||||
# Return appropriate HTTP status codes
|
|
||||||
if result.get("success"):
|
|
||||||
return JSONResponse(result, status_code=201) # Created
|
return JSONResponse(result, status_code=201) # Created
|
||||||
else:
|
except Exception as e:
|
||||||
error_msg = result.get("error", "")
|
error_msg = str(e)
|
||||||
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
if "AuthenticationException" in error_msg or "access denied" in error_msg.lower():
|
||||||
return JSONResponse(result, status_code=403)
|
return JSONResponse({"error": error_msg}, status_code=403)
|
||||||
else:
|
else:
|
||||||
return JSONResponse(result, status_code=500)
|
return JSONResponse({"error": error_msg}, status_code=500)
|
||||||
|
|
||||||
async def upload_path(request: Request, task_service, session_manager):
|
async def upload_path(request: Request, task_service, session_manager):
|
||||||
"""Upload all files from a directory path"""
|
"""Upload all files from a directory path"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from .base import BaseConnector
|
from .base import BaseConnector
|
||||||
from .google_drive import GoogleDriveConnector
|
from .google_drive import GoogleDriveConnector
|
||||||
|
from .sharepoint import SharePointConnector
|
||||||
|
from .onedrive import OneDriveConnector
|
||||||
|
|
||||||
__all__ = ["BaseConnector", "GoogleDriveConnector"]
|
__all__ = ["BaseConnector", "GoogleDriveConnector", "SharePointConnector", "OneDriveConnector"]
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,11 @@ class BaseConnector(ABC):
|
||||||
CLIENT_ID_ENV_VAR: str = None
|
CLIENT_ID_ENV_VAR: str = None
|
||||||
CLIENT_SECRET_ENV_VAR: str = None
|
CLIENT_SECRET_ENV_VAR: str = None
|
||||||
|
|
||||||
|
# Connector metadata for UI
|
||||||
|
CONNECTOR_NAME: str = None
|
||||||
|
CONNECTOR_DESCRIPTION: str = None
|
||||||
|
CONNECTOR_ICON: str = None # Icon identifier or emoji
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any]):
|
def __init__(self, config: Dict[str, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._authenticated = False
|
self._authenticated = False
|
||||||
|
|
@ -105,6 +110,17 @@ class BaseConnector(ABC):
|
||||||
"""Handle webhook notification. Returns list of affected file IDs."""
|
"""Handle webhook notification. Returns list of affected file IDs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Handle webhook validation (e.g., for subscription setup).
|
||||||
|
Returns validation response if applicable, None otherwise.
|
||||||
|
Default implementation returns None (no validation needed)."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract channel/subscription ID from webhook payload/headers.
|
||||||
|
Must be implemented by each connector."""
|
||||||
|
raise NotImplementedError(f"{self.__class__.__name__} must implement extract_webhook_channel_id")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||||
"""Clean up subscription"""
|
"""Clean up subscription"""
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from pathlib import Path
|
||||||
|
|
||||||
from .base import BaseConnector
|
from .base import BaseConnector
|
||||||
from .google_drive import GoogleDriveConnector
|
from .google_drive import GoogleDriveConnector
|
||||||
|
from .sharepoint import SharePointConnector
|
||||||
|
from .onedrive import OneDriveConnector
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -186,10 +188,54 @@ class ConnectionManager:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_available_connector_types(self) -> Dict[str, Dict[str, str]]:
|
||||||
|
"""Get available connector types with their metadata"""
|
||||||
|
return {
|
||||||
|
"google_drive": {
|
||||||
|
"name": GoogleDriveConnector.CONNECTOR_NAME,
|
||||||
|
"description": GoogleDriveConnector.CONNECTOR_DESCRIPTION,
|
||||||
|
"icon": GoogleDriveConnector.CONNECTOR_ICON,
|
||||||
|
"available": self._is_connector_available("google_drive")
|
||||||
|
},
|
||||||
|
"sharepoint": {
|
||||||
|
"name": SharePointConnector.CONNECTOR_NAME,
|
||||||
|
"description": SharePointConnector.CONNECTOR_DESCRIPTION,
|
||||||
|
"icon": SharePointConnector.CONNECTOR_ICON,
|
||||||
|
"available": self._is_connector_available("sharepoint")
|
||||||
|
},
|
||||||
|
"onedrive": {
|
||||||
|
"name": OneDriveConnector.CONNECTOR_NAME,
|
||||||
|
"description": OneDriveConnector.CONNECTOR_DESCRIPTION,
|
||||||
|
"icon": OneDriveConnector.CONNECTOR_ICON,
|
||||||
|
"available": self._is_connector_available("onedrive")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_connector_available(self, connector_type: str) -> bool:
|
||||||
|
"""Check if a connector type is available (has required env vars)"""
|
||||||
|
try:
|
||||||
|
temp_config = ConnectionConfig(
|
||||||
|
connection_id="temp",
|
||||||
|
connector_type=connector_type,
|
||||||
|
name="temp",
|
||||||
|
config={}
|
||||||
|
)
|
||||||
|
connector = self._create_connector(temp_config)
|
||||||
|
# Try to get credentials to check if env vars are set
|
||||||
|
connector.get_client_id()
|
||||||
|
connector.get_client_secret()
|
||||||
|
return True
|
||||||
|
except (ValueError, NotImplementedError):
|
||||||
|
return False
|
||||||
|
|
||||||
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
||||||
"""Factory method to create connector instances"""
|
"""Factory method to create connector instances"""
|
||||||
if config.connector_type == "google_drive":
|
if config.connector_type == "google_drive":
|
||||||
return GoogleDriveConnector(config.config)
|
return GoogleDriveConnector(config.config)
|
||||||
|
elif config.connector_type == "sharepoint":
|
||||||
|
return SharePointConnector(config.config)
|
||||||
|
elif config.connector_type == "onedrive":
|
||||||
|
return OneDriveConnector(config.config)
|
||||||
elif config.connector_type == "box":
|
elif config.connector_type == "box":
|
||||||
# Future: BoxConnector(config.config)
|
# Future: BoxConnector(config.config)
|
||||||
raise NotImplementedError("Box connector not implemented yet")
|
raise NotImplementedError("Box connector not implemented yet")
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,11 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
|
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
|
||||||
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
|
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
|
# Connector metadata
|
||||||
|
CONNECTOR_NAME = "Google Drive"
|
||||||
|
CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents"
|
||||||
|
CONNECTOR_ICON = "google-drive"
|
||||||
|
|
||||||
# Supported file types that can be processed by docling
|
# Supported file types that can be processed by docling
|
||||||
SUPPORTED_MIMETYPES = {
|
SUPPORTED_MIMETYPES = {
|
||||||
'application/pdf',
|
'application/pdf',
|
||||||
|
|
@ -363,6 +368,10 @@ class GoogleDriveConnector(BaseConnector):
|
||||||
group_permissions=group_permissions
|
group_permissions=group_permissions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract Google Drive channel ID from webhook headers"""
|
||||||
|
return headers.get('x-goog-channel-id')
|
||||||
|
|
||||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
"""Handle Google Drive webhook notification"""
|
"""Handle Google Drive webhook notification"""
|
||||||
if not self._authenticated:
|
if not self._authenticated:
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,16 @@ class GoogleDriveOAuth:
|
||||||
"""Handles Google Drive OAuth authentication flow"""
|
"""Handles Google Drive OAuth authentication flow"""
|
||||||
|
|
||||||
SCOPES = [
|
SCOPES = [
|
||||||
|
'openid',
|
||||||
|
'email',
|
||||||
|
'profile',
|
||||||
'https://www.googleapis.com/auth/drive.readonly',
|
'https://www.googleapis.com/auth/drive.readonly',
|
||||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
|
||||||
|
|
||||||
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
|
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
|
|
|
||||||
4
src/connectors/onedrive/__init__.py
Normal file
4
src/connectors/onedrive/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .connector import OneDriveConnector
|
||||||
|
from .oauth import OneDriveOAuth
|
||||||
|
|
||||||
|
__all__ = ["OneDriveConnector", "OneDriveOAuth"]
|
||||||
193
src/connectors/onedrive/connector.py
Normal file
193
src/connectors/onedrive/connector.py
Normal file
|
|
@ -0,0 +1,193 @@
|
||||||
|
import httpx
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||||
|
from .oauth import OneDriveOAuth
|
||||||
|
|
||||||
|
|
||||||
|
class OneDriveConnector(BaseConnector):
|
||||||
|
"""OneDrive connector using Microsoft Graph API"""
|
||||||
|
|
||||||
|
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||||
|
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
|
# Connector metadata
|
||||||
|
CONNECTOR_NAME = "OneDrive"
|
||||||
|
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
|
||||||
|
CONNECTOR_ICON = "onedrive"
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
super().__init__(config)
|
||||||
|
self.oauth = OneDriveOAuth(
|
||||||
|
client_id=self.get_client_id(),
|
||||||
|
client_secret=self.get_client_secret(),
|
||||||
|
token_file=config.get("token_file", "onedrive_token.json"),
|
||||||
|
)
|
||||||
|
self.subscription_id = config.get("subscription_id") or config.get("webhook_channel_id")
|
||||||
|
self.base_url = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
async def authenticate(self) -> bool:
|
||||||
|
if await self.oauth.is_authenticated():
|
||||||
|
self._authenticated = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def setup_subscription(self) -> str:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
webhook_url = self.config.get("webhook_url")
|
||||||
|
if not webhook_url:
|
||||||
|
raise ValueError("webhook_url required in config for subscriptions")
|
||||||
|
|
||||||
|
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
||||||
|
body = {
|
||||||
|
"changeType": "created,updated,deleted",
|
||||||
|
"notificationUrl": webhook_url,
|
||||||
|
"resource": "/me/drive/root",
|
||||||
|
"expirationDateTime": expiration,
|
||||||
|
"clientState": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{self.base_url}/subscriptions",
|
||||||
|
json=body,
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
self.subscription_id = data["id"]
|
||||||
|
return self.subscription_id
|
||||||
|
|
||||||
|
async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
params = {"$top": str(limit)}
|
||||||
|
if page_token:
|
||||||
|
params["$skiptoken"] = page_token
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{self.base_url}/me/drive/root/children",
|
||||||
|
params=params,
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for item in data.get("value", []):
|
||||||
|
if item.get("file"):
|
||||||
|
files.append({
|
||||||
|
"id": item["id"],
|
||||||
|
"name": item["name"],
|
||||||
|
"mimeType": item.get("file", {}).get("mimeType", "application/octet-stream"),
|
||||||
|
"webViewLink": item.get("webUrl"),
|
||||||
|
"createdTime": item.get("createdDateTime"),
|
||||||
|
"modifiedTime": item.get("lastModifiedDateTime"),
|
||||||
|
})
|
||||||
|
|
||||||
|
next_token = None
|
||||||
|
next_link = data.get("@odata.nextLink")
|
||||||
|
if next_link:
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
||||||
|
|
||||||
|
return {"files": files, "nextPageToken": next_token}
|
||||||
|
|
||||||
|
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
meta_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}", headers=headers)
|
||||||
|
meta_resp.raise_for_status()
|
||||||
|
metadata = meta_resp.json()
|
||||||
|
|
||||||
|
content_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers)
|
||||||
|
content_resp.raise_for_status()
|
||||||
|
content = content_resp.content
|
||||||
|
|
||||||
|
perm_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers)
|
||||||
|
perm_resp.raise_for_status()
|
||||||
|
permissions = perm_resp.json()
|
||||||
|
|
||||||
|
acl = self._parse_permissions(metadata, permissions)
|
||||||
|
modified = datetime.fromisoformat(metadata["lastModifiedDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)
|
||||||
|
created = datetime.fromisoformat(metadata["createdDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)
|
||||||
|
|
||||||
|
document = ConnectorDocument(
|
||||||
|
id=metadata["id"],
|
||||||
|
filename=metadata["name"],
|
||||||
|
mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"),
|
||||||
|
content=content,
|
||||||
|
source_url=metadata.get("webUrl"),
|
||||||
|
acl=acl,
|
||||||
|
modified_time=modified,
|
||||||
|
created_time=created,
|
||||||
|
metadata={"size": metadata.get("size")},
|
||||||
|
)
|
||||||
|
return document
|
||||||
|
|
||||||
|
def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL:
|
||||||
|
acl = DocumentACL()
|
||||||
|
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
||||||
|
if owner:
|
||||||
|
acl.owner = owner
|
||||||
|
for perm in permissions.get("value", []):
|
||||||
|
role = perm.get("roles", ["read"])[0]
|
||||||
|
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
||||||
|
if not grantee:
|
||||||
|
continue
|
||||||
|
user = grantee.get("user")
|
||||||
|
if user and user.get("email"):
|
||||||
|
acl.user_permissions[user["email"]] = role
|
||||||
|
group = grantee.get("group")
|
||||||
|
if group and group.get("email"):
|
||||||
|
acl.group_permissions[group["email"]] = role
|
||||||
|
return acl
|
||||||
|
|
||||||
|
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Handle Microsoft Graph webhook validation"""
|
||||||
|
if request_method == "GET":
|
||||||
|
validation_token = query_params.get("validationtoken") or query_params.get("validationToken")
|
||||||
|
if validation_token:
|
||||||
|
return validation_token
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract SharePoint subscription ID from webhook payload"""
|
||||||
|
values = payload.get('value', [])
|
||||||
|
return values[0].get('subscriptionId') if values else None
|
||||||
|
|
||||||
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
|
values = payload.get("value", [])
|
||||||
|
file_ids = []
|
||||||
|
for item in values:
|
||||||
|
resource_data = item.get("resourceData", {})
|
||||||
|
file_id = resource_data.get("id")
|
||||||
|
if file_id:
|
||||||
|
file_ids.append(file_id)
|
||||||
|
return file_ids
|
||||||
|
|
||||||
|
async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
|
||||||
|
if not self._authenticated:
|
||||||
|
return False
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.delete(
|
||||||
|
f"{self.base_url}/subscriptions/{subscription_id}",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
return resp.status_code in (200, 204)
|
||||||
83
src/connectors/onedrive/oauth.py
Normal file
83
src/connectors/onedrive/oauth.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
import os
|
||||||
|
import aiofiles
|
||||||
|
from typing import Optional
|
||||||
|
import msal
|
||||||
|
|
||||||
|
|
||||||
|
class OneDriveOAuth:
|
||||||
|
"""Handles Microsoft Graph OAuth authentication flow"""
|
||||||
|
|
||||||
|
SCOPES = [
|
||||||
|
"offline_access",
|
||||||
|
"Files.Read.All",
|
||||||
|
]
|
||||||
|
|
||||||
|
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||||
|
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, token_file: str = "onedrive_token.json", authority: str = "https://login.microsoftonline.com/common"):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.token_file = token_file
|
||||||
|
self.authority = authority
|
||||||
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
|
||||||
|
# Load existing cache if available
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
with open(self.token_file, "r") as f:
|
||||||
|
self.token_cache.deserialize(f.read())
|
||||||
|
|
||||||
|
self.app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_credential=self.client_secret,
|
||||||
|
authority=self.authority,
|
||||||
|
token_cache=self.token_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def save_cache(self):
|
||||||
|
"""Persist the token cache to file"""
|
||||||
|
async with aiofiles.open(self.token_file, "w") as f:
|
||||||
|
await f.write(self.token_cache.serialize())
|
||||||
|
|
||||||
|
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||||
|
"""Create authorization URL for OAuth flow"""
|
||||||
|
return self.app.get_authorization_request_url(self.SCOPES, redirect_uri=redirect_uri)
|
||||||
|
|
||||||
|
async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool:
|
||||||
|
"""Handle OAuth callback and exchange code for tokens"""
|
||||||
|
result = self.app.acquire_token_by_authorization_code(
|
||||||
|
authorization_code,
|
||||||
|
scopes=self.SCOPES,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
)
|
||||||
|
if "access_token" in result:
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
raise ValueError(result.get("error_description") or "Authorization failed")
|
||||||
|
|
||||||
|
async def is_authenticated(self) -> bool:
|
||||||
|
"""Check if we have valid credentials"""
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if not accounts:
|
||||||
|
return False
|
||||||
|
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||||
|
if "access_token" in result:
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_access_token(self) -> str:
|
||||||
|
"""Get an access token for Microsoft Graph"""
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if not accounts:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||||
|
if "access_token" not in result:
|
||||||
|
raise ValueError(result.get("error_description") or "Failed to acquire access token")
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
|
async def revoke_credentials(self):
|
||||||
|
"""Clear token cache and remove token file"""
|
||||||
|
self.token_cache.clear()
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
os.remove(self.token_file)
|
||||||
|
|
@ -5,6 +5,8 @@ from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
from .base import BaseConnector, ConnectorDocument
|
from .base import BaseConnector, ConnectorDocument
|
||||||
from .google_drive import GoogleDriveConnector
|
from .google_drive import GoogleDriveConnector
|
||||||
|
from .sharepoint import SharePointConnector
|
||||||
|
from .onedrive import OneDriveConnector
|
||||||
from .connection_manager import ConnectionManager
|
from .connection_manager import ConnectionManager
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,7 +30,7 @@ class ConnectorService:
|
||||||
"""Get a connector by connection ID"""
|
"""Get a connector by connection ID"""
|
||||||
return await self.connection_manager.get_connector(connection_id)
|
return await self.connection_manager.get_connector(connection_id)
|
||||||
|
|
||||||
async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None) -> Dict[str, Any]:
|
async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
"""Process a document from a connector using existing processing pipeline"""
|
"""Process a document from a connector using existing processing pipeline"""
|
||||||
|
|
||||||
# Create temporary file from document content
|
# Create temporary file from document content
|
||||||
|
|
@ -54,7 +56,7 @@ class ConnectorService:
|
||||||
# If successfully indexed, update the indexed documents with connector metadata
|
# If successfully indexed, update the indexed documents with connector metadata
|
||||||
if result["status"] == "indexed":
|
if result["status"] == "indexed":
|
||||||
# Update all chunks with connector-specific metadata
|
# Update all chunks with connector-specific metadata
|
||||||
await self._update_connector_metadata(document, owner_user_id, jwt_token)
|
await self._update_connector_metadata(document, owner_user_id, connector_type, jwt_token)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**result,
|
**result,
|
||||||
|
|
@ -66,7 +68,7 @@ class ConnectorService:
|
||||||
# Clean up temporary file
|
# Clean up temporary file
|
||||||
os.unlink(tmp_file.name)
|
os.unlink(tmp_file.name)
|
||||||
|
|
||||||
async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None):
|
async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None):
|
||||||
"""Update indexed chunks with connector-specific metadata"""
|
"""Update indexed chunks with connector-specific metadata"""
|
||||||
# Find all chunks for this document
|
# Find all chunks for this document
|
||||||
query = {
|
query = {
|
||||||
|
|
@ -86,7 +88,7 @@ class ConnectorService:
|
||||||
update_body = {
|
update_body = {
|
||||||
"doc": {
|
"doc": {
|
||||||
"source_url": document.source_url,
|
"source_url": document.source_url,
|
||||||
"connector_type": "google_drive", # Could be passed as parameter
|
"connector_type": connector_type,
|
||||||
# Additional ACL info beyond owner (already set by process_file_common)
|
# Additional ACL info beyond owner (already set by process_file_common)
|
||||||
"allowed_users": document.acl.allowed_users,
|
"allowed_users": document.acl.allowed_users,
|
||||||
"allowed_groups": document.acl.allowed_groups,
|
"allowed_groups": document.acl.allowed_groups,
|
||||||
|
|
|
||||||
4
src/connectors/sharepoint/__init__.py
Normal file
4
src/connectors/sharepoint/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .connector import SharePointConnector
|
||||||
|
from .oauth import SharePointOAuth
|
||||||
|
|
||||||
|
__all__ = ["SharePointConnector", "SharePointOAuth"]
|
||||||
196
src/connectors/sharepoint/connector.py
Normal file
196
src/connectors/sharepoint/connector.py
Normal file
|
|
@ -0,0 +1,196 @@
|
||||||
|
import httpx
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||||
|
from .oauth import SharePointOAuth
|
||||||
|
|
||||||
|
|
||||||
|
class SharePointConnector(BaseConnector):
|
||||||
|
"""SharePoint Sites connector using Microsoft Graph API"""
|
||||||
|
|
||||||
|
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||||
|
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
|
# Connector metadata
|
||||||
|
CONNECTOR_NAME = "SharePoint"
|
||||||
|
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
|
||||||
|
CONNECTOR_ICON = "sharepoint"
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
super().__init__(config)
|
||||||
|
self.oauth = SharePointOAuth(
|
||||||
|
client_id=self.get_client_id(),
|
||||||
|
client_secret=self.get_client_secret(),
|
||||||
|
token_file=config.get("token_file", "sharepoint_token.json"),
|
||||||
|
)
|
||||||
|
self.subscription_id = config.get("subscription_id") or config.get("webhook_channel_id")
|
||||||
|
self.base_url = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
# SharePoint site configuration
|
||||||
|
self.site_id = config.get("site_id") # Required for SharePoint
|
||||||
|
|
||||||
|
async def authenticate(self) -> bool:
|
||||||
|
if await self.oauth.is_authenticated():
|
||||||
|
self._authenticated = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def setup_subscription(self) -> str:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
webhook_url = self.config.get("webhook_url")
|
||||||
|
if not webhook_url:
|
||||||
|
raise ValueError("webhook_url required in config for subscriptions")
|
||||||
|
|
||||||
|
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
||||||
|
body = {
|
||||||
|
"changeType": "created,updated,deleted",
|
||||||
|
"notificationUrl": webhook_url,
|
||||||
|
"resource": f"/sites/{self.site_id}/drive/root",
|
||||||
|
"expirationDateTime": expiration,
|
||||||
|
"clientState": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{self.base_url}/subscriptions",
|
||||||
|
json=body,
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
self.subscription_id = data["id"]
|
||||||
|
return self.subscription_id
|
||||||
|
|
||||||
|
async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
params = {"$top": str(limit)}
|
||||||
|
if page_token:
|
||||||
|
params["$skiptoken"] = page_token
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{self.base_url}/sites/{self.site_id}/drive/root/children",
|
||||||
|
params=params,
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for item in data.get("value", []):
|
||||||
|
if item.get("file"):
|
||||||
|
files.append({
|
||||||
|
"id": item["id"],
|
||||||
|
"name": item["name"],
|
||||||
|
"mimeType": item.get("file", {}).get("mimeType", "application/octet-stream"),
|
||||||
|
"webViewLink": item.get("webUrl"),
|
||||||
|
"createdTime": item.get("createdDateTime"),
|
||||||
|
"modifiedTime": item.get("lastModifiedDateTime"),
|
||||||
|
})
|
||||||
|
|
||||||
|
next_token = None
|
||||||
|
next_link = data.get("@odata.nextLink")
|
||||||
|
if next_link:
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
||||||
|
|
||||||
|
return {"files": files, "nextPageToken": next_token}
|
||||||
|
|
||||||
|
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||||
|
if not self._authenticated:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
meta_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}", headers=headers)
|
||||||
|
meta_resp.raise_for_status()
|
||||||
|
metadata = meta_resp.json()
|
||||||
|
|
||||||
|
content_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content", headers=headers)
|
||||||
|
content_resp.raise_for_status()
|
||||||
|
content = content_resp.content
|
||||||
|
|
||||||
|
perm_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions", headers=headers)
|
||||||
|
perm_resp.raise_for_status()
|
||||||
|
permissions = perm_resp.json()
|
||||||
|
|
||||||
|
acl = self._parse_permissions(metadata, permissions)
|
||||||
|
modified = datetime.fromisoformat(metadata["lastModifiedDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)
|
||||||
|
created = datetime.fromisoformat(metadata["createdDateTime"].replace("Z", "+00:00")).replace(tzinfo=None)
|
||||||
|
|
||||||
|
document = ConnectorDocument(
|
||||||
|
id=metadata["id"],
|
||||||
|
filename=metadata["name"],
|
||||||
|
mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"),
|
||||||
|
content=content,
|
||||||
|
source_url=metadata.get("webUrl"),
|
||||||
|
acl=acl,
|
||||||
|
modified_time=modified,
|
||||||
|
created_time=created,
|
||||||
|
metadata={"size": metadata.get("size")},
|
||||||
|
)
|
||||||
|
return document
|
||||||
|
|
||||||
|
def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL:
|
||||||
|
acl = DocumentACL()
|
||||||
|
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
||||||
|
if owner:
|
||||||
|
acl.owner = owner
|
||||||
|
for perm in permissions.get("value", []):
|
||||||
|
role = perm.get("roles", ["read"])[0]
|
||||||
|
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
||||||
|
if not grantee:
|
||||||
|
continue
|
||||||
|
user = grantee.get("user")
|
||||||
|
if user and user.get("email"):
|
||||||
|
acl.user_permissions[user["email"]] = role
|
||||||
|
group = grantee.get("group")
|
||||||
|
if group and group.get("email"):
|
||||||
|
acl.group_permissions[group["email"]] = role
|
||||||
|
return acl
|
||||||
|
|
||||||
|
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Handle Microsoft Graph webhook validation"""
|
||||||
|
if request_method == "GET":
|
||||||
|
validation_token = query_params.get("validationtoken") or query_params.get("validationToken")
|
||||||
|
if validation_token:
|
||||||
|
return validation_token
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract SharePoint subscription ID from webhook payload"""
|
||||||
|
values = payload.get('value', [])
|
||||||
|
return values[0].get('subscriptionId') if values else None
|
||||||
|
|
||||||
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
|
values = payload.get("value", [])
|
||||||
|
file_ids = []
|
||||||
|
for item in values:
|
||||||
|
resource_data = item.get("resourceData", {})
|
||||||
|
file_id = resource_data.get("id")
|
||||||
|
if file_id:
|
||||||
|
file_ids.append(file_id)
|
||||||
|
return file_ids
|
||||||
|
|
||||||
|
async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool:
|
||||||
|
if not self._authenticated:
|
||||||
|
return False
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.delete(
|
||||||
|
f"{self.base_url}/subscriptions/{subscription_id}",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
return resp.status_code in (200, 204)
|
||||||
84
src/connectors/sharepoint/oauth.py
Normal file
84
src/connectors/sharepoint/oauth.py
Normal file
|
|
@ -0,0 +1,84 @@
|
||||||
|
import os
|
||||||
|
import aiofiles
|
||||||
|
from typing import Optional
|
||||||
|
import msal
|
||||||
|
|
||||||
|
|
||||||
|
class SharePointOAuth:
|
||||||
|
"""Handles Microsoft Graph OAuth authentication flow"""
|
||||||
|
|
||||||
|
SCOPES = [
|
||||||
|
"offline_access",
|
||||||
|
"Files.Read.All",
|
||||||
|
"Sites.Read.All",
|
||||||
|
]
|
||||||
|
|
||||||
|
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||||
|
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, token_file: str = "sharepoint_token.json", authority: str = "https://login.microsoftonline.com/common"):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.token_file = token_file
|
||||||
|
self.authority = authority
|
||||||
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
|
||||||
|
# Load existing cache if available
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
with open(self.token_file, "r") as f:
|
||||||
|
self.token_cache.deserialize(f.read())
|
||||||
|
|
||||||
|
self.app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_credential=self.client_secret,
|
||||||
|
authority=self.authority,
|
||||||
|
token_cache=self.token_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def save_cache(self):
|
||||||
|
"""Persist the token cache to file"""
|
||||||
|
async with aiofiles.open(self.token_file, "w") as f:
|
||||||
|
await f.write(self.token_cache.serialize())
|
||||||
|
|
||||||
|
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||||
|
"""Create authorization URL for OAuth flow"""
|
||||||
|
return self.app.get_authorization_request_url(self.SCOPES, redirect_uri=redirect_uri)
|
||||||
|
|
||||||
|
async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool:
|
||||||
|
"""Handle OAuth callback and exchange code for tokens"""
|
||||||
|
result = self.app.acquire_token_by_authorization_code(
|
||||||
|
authorization_code,
|
||||||
|
scopes=self.SCOPES,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
)
|
||||||
|
if "access_token" in result:
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
raise ValueError(result.get("error_description") or "Authorization failed")
|
||||||
|
|
||||||
|
async def is_authenticated(self) -> bool:
|
||||||
|
"""Check if we have valid credentials"""
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if not accounts:
|
||||||
|
return False
|
||||||
|
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||||
|
if "access_token" in result:
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_access_token(self) -> str:
|
||||||
|
"""Get an access token for Microsoft Graph"""
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if not accounts:
|
||||||
|
raise ValueError("Not authenticated")
|
||||||
|
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||||
|
if "access_token" not in result:
|
||||||
|
raise ValueError(result.get("error_description") or "Failed to acquire access token")
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
|
async def revoke_credentials(self):
|
||||||
|
"""Clear token cache and remove token file"""
|
||||||
|
self.token_cache.clear()
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
os.remove(self.token_file)
|
||||||
|
|
@ -278,6 +278,13 @@ def create_app():
|
||||||
), methods=["POST"]),
|
), methods=["POST"]),
|
||||||
|
|
||||||
# Connector endpoints
|
# Connector endpoints
|
||||||
|
Route("/connectors",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(connectors.list_connectors,
|
||||||
|
connector_service=services['connector_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["GET"]),
|
||||||
|
|
||||||
Route("/connectors/{connector_type}/sync",
|
Route("/connectors/{connector_type}/sync",
|
||||||
require_auth(services['session_manager'])(
|
require_auth(services['session_manager'])(
|
||||||
partial(connectors.connector_sync,
|
partial(connectors.connector_sync,
|
||||||
|
|
|
||||||
|
|
@ -63,9 +63,10 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
file_id = item # item is the connector file ID
|
file_id = item # item is the connector file ID
|
||||||
file_info = self.file_info_map.get(file_id)
|
file_info = self.file_info_map.get(file_id)
|
||||||
|
|
||||||
# Get the connector
|
# Get the connector and connection info
|
||||||
connector = await self.connector_service.get_connector(self.connection_id)
|
connector = await self.connector_service.get_connector(self.connection_id)
|
||||||
if not connector:
|
connection = await self.connector_service.connection_manager.get_connection(self.connection_id)
|
||||||
|
if not connector or not connection:
|
||||||
raise ValueError(f"Connection '{self.connection_id}' not found")
|
raise ValueError(f"Connection '{self.connection_id}' not found")
|
||||||
|
|
||||||
# Get file content from connector (the connector will fetch metadata if needed)
|
# Get file content from connector (the connector will fetch metadata if needed)
|
||||||
|
|
@ -76,7 +77,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
raise ValueError("user_id not provided to ConnectorFileProcessor")
|
raise ValueError("user_id not provided to ConnectorFileProcessor")
|
||||||
|
|
||||||
# Process using existing pipeline
|
# Process using existing pipeline
|
||||||
result = await self.connector_service.process_connector_document(document, self.user_id)
|
result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type)
|
||||||
|
|
||||||
file_task.status = TaskStatus.COMPLETED
|
file_task.status = TaskStatus.COMPLETED
|
||||||
file_task.result = result
|
file_task.result = result
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,12 @@ from typing import Optional
|
||||||
|
|
||||||
from config.settings import WEBHOOK_BASE_URL
|
from config.settings import WEBHOOK_BASE_URL
|
||||||
from session_manager import SessionManager
|
from session_manager import SessionManager
|
||||||
|
from connectors.google_drive.oauth import GoogleDriveOAuth
|
||||||
|
from connectors.onedrive.oauth import OneDriveOAuth
|
||||||
|
from connectors.sharepoint.oauth import SharePointOAuth
|
||||||
|
from connectors.google_drive import GoogleDriveConnector
|
||||||
|
from connectors.onedrive import OneDriveConnector
|
||||||
|
from connectors.sharepoint import SharePointConnector
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
def __init__(self, session_manager: SessionManager, connector_service=None):
|
def __init__(self, session_manager: SessionManager, connector_service=None):
|
||||||
|
|
@ -15,11 +21,16 @@ class AuthService:
|
||||||
self.connector_service = connector_service
|
self.connector_service = connector_service
|
||||||
self.used_auth_codes = set() # Track used authorization codes
|
self.used_auth_codes = set() # Track used authorization codes
|
||||||
|
|
||||||
async def init_oauth(self, provider: str, purpose: str, connection_name: str,
|
async def init_oauth(self, connector_type: str, purpose: str, connection_name: str,
|
||||||
redirect_uri: str, user_id: str = None) -> dict:
|
redirect_uri: str, user_id: str = None) -> dict:
|
||||||
"""Initialize OAuth flow for authentication or data source connection"""
|
"""Initialize OAuth flow for authentication or data source connection"""
|
||||||
if provider != "google":
|
# Validate connector_type based on purpose
|
||||||
raise ValueError("Unsupported provider")
|
if purpose == "app_auth" and connector_type != "google_drive":
|
||||||
|
raise ValueError("Only Google login supported for app authentication")
|
||||||
|
elif purpose == "data_source" and connector_type not in ["google_drive", "onedrive", "sharepoint"]:
|
||||||
|
raise ValueError(f"Unsupported connector type: {connector_type}")
|
||||||
|
elif purpose not in ["app_auth", "data_source"]:
|
||||||
|
raise ValueError(f"Unsupported purpose: {purpose}")
|
||||||
|
|
||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
raise ValueError("redirect_uri is required")
|
raise ValueError("redirect_uri is required")
|
||||||
|
|
@ -27,20 +38,19 @@ class AuthService:
|
||||||
# We'll validate client credentials when creating the connector
|
# We'll validate client credentials when creating the connector
|
||||||
|
|
||||||
# Create connection configuration
|
# Create connection configuration
|
||||||
token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
|
token_file = f"{connector_type}_{purpose}_{uuid.uuid4().hex[:8]}.json"
|
||||||
config = {
|
config = {
|
||||||
"token_file": token_file,
|
"token_file": token_file,
|
||||||
"provider": provider,
|
"connector_type": connector_type,
|
||||||
"purpose": purpose,
|
"purpose": purpose,
|
||||||
"redirect_uri": redirect_uri
|
"redirect_uri": redirect_uri
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only add webhook URL if WEBHOOK_BASE_URL is configured
|
# Only add webhook URL if WEBHOOK_BASE_URL is configured
|
||||||
if WEBHOOK_BASE_URL:
|
if WEBHOOK_BASE_URL:
|
||||||
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{provider}_drive/webhook"
|
config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{connector_type}/webhook"
|
||||||
|
|
||||||
# Create connection in manager (always use _drive connector type as it handles OAuth)
|
# Create connection in manager
|
||||||
connector_type = f"{provider}_drive"
|
|
||||||
connection_id = await self.connector_service.connection_manager.create_connection(
|
connection_id = await self.connector_service.connection_manager.create_connection(
|
||||||
connector_type=connector_type,
|
connector_type=connector_type,
|
||||||
name=connection_name,
|
name=connection_name,
|
||||||
|
|
@ -48,25 +58,38 @@ class AuthService:
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return OAuth configuration for client-side flow
|
# Get OAuth configuration from connector and OAuth classes
|
||||||
scopes = [
|
|
||||||
'openid', 'email', 'profile',
|
|
||||||
'https://www.googleapis.com/auth/drive.readonly',
|
|
||||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get client_id from environment variable (same as connector would do)
|
|
||||||
import os
|
import os
|
||||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
|
||||||
|
# Map connector types to their connector and OAuth classes
|
||||||
|
connector_class_map = {
|
||||||
|
"google_drive": (GoogleDriveConnector, GoogleDriveOAuth),
|
||||||
|
"onedrive": (OneDriveConnector, OneDriveOAuth),
|
||||||
|
"sharepoint": (SharePointConnector, SharePointOAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
connector_class, oauth_class = connector_class_map.get(connector_type, (None, None))
|
||||||
|
if not connector_class or not oauth_class:
|
||||||
|
raise ValueError(f"No classes found for connector type: {connector_type}")
|
||||||
|
|
||||||
|
# Get scopes from OAuth class
|
||||||
|
scopes = oauth_class.SCOPES
|
||||||
|
|
||||||
|
# Get endpoints from OAuth class
|
||||||
|
auth_endpoint = oauth_class.AUTH_ENDPOINT
|
||||||
|
token_endpoint = oauth_class.TOKEN_ENDPOINT
|
||||||
|
|
||||||
|
# Get client_id from environment variable using connector's env var name
|
||||||
|
client_id = os.getenv(connector_class.CLIENT_ID_ENV_VAR)
|
||||||
if not client_id:
|
if not client_id:
|
||||||
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable not set")
|
raise ValueError(f"{connector_class.CLIENT_ID_ENV_VAR} environment variable not set")
|
||||||
|
|
||||||
oauth_config = {
|
oauth_config = {
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"scopes": scopes,
|
"scopes": scopes,
|
||||||
"redirect_uri": redirect_uri,
|
"redirect_uri": redirect_uri,
|
||||||
"authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
|
"authorization_endpoint": auth_endpoint,
|
||||||
"token_endpoint": "https://oauth2.googleapis.com/token"
|
"token_endpoint": token_endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -98,10 +121,23 @@ class AuthService:
|
||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
raise ValueError("Redirect URI not found in connection config")
|
raise ValueError("Redirect URI not found in connection config")
|
||||||
|
|
||||||
token_url = "https://oauth2.googleapis.com/token"
|
# Get connector to access client credentials and endpoints
|
||||||
# Get connector to access client credentials
|
|
||||||
connector = self.connector_service.connection_manager._create_connector(connection_config)
|
connector = self.connector_service.connection_manager._create_connector(connection_config)
|
||||||
|
|
||||||
|
# Get token endpoint from connector type
|
||||||
|
connector_type = connection_config.connector_type
|
||||||
|
connector_class_map = {
|
||||||
|
"google_drive": (GoogleDriveConnector, GoogleDriveOAuth),
|
||||||
|
"onedrive": (OneDriveConnector, OneDriveOAuth),
|
||||||
|
"sharepoint": (SharePointConnector, SharePointOAuth)
|
||||||
|
}
|
||||||
|
|
||||||
|
connector_class, oauth_class = connector_class_map.get(connector_type, (None, None))
|
||||||
|
if not connector_class or not oauth_class:
|
||||||
|
raise ValueError(f"No classes found for connector type: {connector_type}")
|
||||||
|
|
||||||
|
token_url = oauth_class.TOKEN_ENDPOINT
|
||||||
|
|
||||||
token_payload = {
|
token_payload = {
|
||||||
"code": authorization_code,
|
"code": authorization_code,
|
||||||
"client_id": connector.get_client_id(),
|
"client_id": connector.get_client_id(),
|
||||||
|
|
@ -119,14 +155,18 @@ class AuthService:
|
||||||
token_data = token_response.json()
|
token_data = token_response.json()
|
||||||
|
|
||||||
# Store tokens in the token file (without client_secret)
|
# Store tokens in the token file (without client_secret)
|
||||||
|
# Use actual scopes from OAuth response
|
||||||
|
granted_scopes = token_data.get("scope")
|
||||||
|
if not granted_scopes:
|
||||||
|
raise ValueError(f"OAuth provider for {connector_type} did not return granted scopes in token response")
|
||||||
|
|
||||||
|
# OAuth providers typically return scopes as a space-separated string
|
||||||
|
scopes = granted_scopes.split(" ") if isinstance(granted_scopes, str) else granted_scopes
|
||||||
|
|
||||||
token_file_data = {
|
token_file_data = {
|
||||||
"token": token_data["access_token"],
|
"token": token_data["access_token"],
|
||||||
"refresh_token": token_data.get("refresh_token"),
|
"refresh_token": token_data.get("refresh_token"),
|
||||||
"scopes": [
|
"scopes": scopes
|
||||||
"openid", "email", "profile",
|
|
||||||
"https://www.googleapis.com/auth/drive.readonly",
|
|
||||||
"https://www.googleapis.com/auth/drive.metadata.readonly"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add expiry if provided
|
# Add expiry if provided
|
||||||
|
|
|
||||||
21
uv.lock
generated
21
uv.lock
generated
|
|
@ -989,6 +989,20 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
|
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "msal"
|
||||||
|
version = "1.33.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "cryptography" },
|
||||||
|
{ name = "pyjwt", extra = ["crypto"] },
|
||||||
|
{ name = "requests" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d5/da/81acbe0c1fd7e9e4ec35f55dadeba9833a847b9a6ba2e2d1e4432da901dd/msal-1.33.0.tar.gz", hash = "sha256:836ad80faa3e25a7d71015c990ce61f704a87328b1e73bcbb0623a18cbf17510", size = 153801, upload-time = "2025-07-22T19:36:33.693Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/86/5b/fbc73e91f7727ae1e79b21ed833308e99dc11cc1cd3d4717f579775de5e9/msal-1.33.0-py3-none-any.whl", hash = "sha256:c0cd41cecf8eaed733ee7e3be9e040291eba53b0f262d3ae9c58f38b04244273", size = 116853, upload-time = "2025-07-22T19:36:32.403Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "multidict"
|
name = "multidict"
|
||||||
version = "6.6.3"
|
version = "6.6.3"
|
||||||
|
|
@ -1328,6 +1342,7 @@ dependencies = [
|
||||||
{ name = "google-auth-httplib2" },
|
{ name = "google-auth-httplib2" },
|
||||||
{ name = "google-auth-oauthlib" },
|
{ name = "google-auth-oauthlib" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
{ name = "msal" },
|
||||||
{ name = "opensearch-py", extra = ["async"] },
|
{ name = "opensearch-py", extra = ["async"] },
|
||||||
{ name = "pyjwt" },
|
{ name = "pyjwt" },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
|
|
@ -1346,6 +1361,7 @@ requires-dist = [
|
||||||
{ name = "google-auth-httplib2", specifier = ">=0.2.0" },
|
{ name = "google-auth-httplib2", specifier = ">=0.2.0" },
|
||||||
{ name = "google-auth-oauthlib", specifier = ">=1.2.0" },
|
{ name = "google-auth-oauthlib", specifier = ">=1.2.0" },
|
||||||
{ name = "httpx", specifier = ">=0.27.0" },
|
{ name = "httpx", specifier = ">=0.27.0" },
|
||||||
|
{ name = "msal", specifier = ">=1.29.0" },
|
||||||
{ name = "opensearch-py", extras = ["async"], specifier = ">=3.0.0" },
|
{ name = "opensearch-py", extras = ["async"], specifier = ">=3.0.0" },
|
||||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||||
|
|
@ -1661,6 +1677,11 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
crypto = [
|
||||||
|
{ name = "cryptography" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pylatexenc"
|
name = "pylatexenc"
|
||||||
version = "2.10"
|
version = "2.10"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue