diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 4932c3fcc..51c11db4f 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -28,8 +28,8 @@ from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from common.constants import RetCode, TaskStatus -from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource -from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES +from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource +from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.misc_utils import get_uuid from rag.utils.redis_conn import REDIS_CONN from api.apps import login_required, current_user @@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result" WEB_FLOW_TTL_SECS = 15 * 60 -def _web_state_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}" +def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth state. + + The default prefix keeps backward compatibility for Google Drive. + When source_type == "gmail", a different prefix is used so that + Drive/Gmail flows don't clash in Redis. + """ + if source_type == "gmail": + prefix = "gmail_web_flow_state" + else: + prefix = GOOGLE_WEB_FLOW_STATE_PREFIX + return f"{prefix}:{flow_id}" -def _web_result_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}" +def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth result. + + Mirrors _web_state_cache_key logic for result storage. 
+ """ + if source_type == "gmail": + prefix = "gmail_web_flow_result" + else: + prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX + return f"{prefix}:{flow_id}" def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]: @@ -146,19 +164,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} -async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): +async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if success else "" escaped_message = escape(message) payload_json = json.dumps( { - "type": "ragflow-google-drive-oauth", + # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed + "type": f"ragflow-google-{source}-oauth", "status": status, "flowId": flow_id or "", "message": message, } ) - html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format( + # TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type + html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format( + title=f"Google {source.capitalize()} Authorization", heading="Authorization complete" if success else "Authorization failed", message=escaped_message, payload_json=payload_json, @@ -169,20 +190,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): return response -@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") -async def start_google_drive_web_oauth(): - if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: +async def start_google_web_oauth(): + source = request.args.get("type", "drive") + if source not in ("drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") + + if source == "gmail": + redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI + scopes 
async def _complete_google_web_oauth(source: str, scopes, redirect_uri: str):
    """Finish a Google web OAuth flow and render the popup page.

    Shared implementation behind the Gmail and Google Drive callback routes.
    It reads ``state``, ``error``, ``error_description`` and ``code`` from the
    query string, exchanges the code for tokens, and caches the resulting
    credentials in Redis under the result key for ``WEB_FLOW_TTL_SECS``
    seconds so the result endpoint can hand them to the user who started the
    flow.

    Args:
        source: ``"gmail"`` or ``"drive"``; selects the Redis key namespace
            and is forwarded to the popup renderer.
        scopes: OAuth scopes to request; must match those used when the flow
            was started.
        redirect_uri: Redirect URI registered for this connector.

    Returns:
        The popup response produced by ``_render_web_oauth_popup``.
    """
    state_id = request.args.get("state")
    error = request.args.get("error")
    error_description = request.args.get("error_description") or error

    if not state_id:
        return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)

    state_key = _web_state_cache_key(state_id, source)
    state_cache = REDIS_CONN.get(state_key)
    if not state_cache:
        return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)

    # A corrupt or unexpected cache entry is treated like a missing client
    # config instead of letting the decode error escape as a bare 500.
    try:
        state_obj = json.loads(state_cache)
    except (TypeError, ValueError):
        state_obj = None
    client_config = state_obj.get("client_config") if isinstance(state_obj, dict) else None
    if not client_config:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)

    if error:
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)

    code = request.args.get("code")
    if not code:
        return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)

    try:
        flow = Flow.from_client_config(client_config, scopes=scopes)
        flow.redirect_uri = redirect_uri
        flow.fetch_token(code=code)
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to exchange Google OAuth code: %s", exc)
        REDIS_CONN.delete(state_key)
        return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)

    result_payload = {
        "user_id": state_obj.get("user_id"),
        "credentials": flow.credentials.to_json(),
    }
    # Store the result before dropping the state so a poller never sees
    # neither entry for a flow that actually succeeded.
    REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
    REDIS_CONN.delete(state_key)

    return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)


@manager.route("/gmail/oauth/web/callback", methods=["GET"])  # noqa: F821
async def google_gmail_web_oauth_callback():
    """Handle Google's redirect at the end of the Gmail web OAuth flow."""
    return await _complete_google_web_oauth(
        "gmail",
        GOOGLE_SCOPES[DocumentSource.GMAIL],
        GMAIL_WEB_OAUTH_REDIRECT_URI,
    )


@manager.route("/google-drive/oauth/web/callback", methods=["GET"])  # noqa: F821
async def google_drive_web_oauth_callback():
    """Handle Google's redirect at the end of the Google Drive web OAuth flow."""
    return await _complete_google_web_oauth(
        "drive",
        GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE],
        GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
    )
message="Authorization is still pending.") @@ -291,5 +387,5 @@ async def poll_google_drive_web_result(): if result.get("user_id") != current_user.id: return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") - REDIS_CONN.delete(_web_result_cache_key(flow_id)) + REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) diff --git a/common/data_source/config.py b/common/data_source/config.py index 0c038c6d7..e489541c8 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -215,6 +215,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") +GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index 67ebfae98..78e2818c9 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -1,4 +1,6 @@ +import json import logging +import sys from typing import Any from google.oauth2.credentials import Credentials as OAuthCredentials @@ -12,7 +14,7 @@ from common.data_source.google_util.resource import get_admin_service, get_gmail from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, 
SlimDocument, TextSection -from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc +from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc # Constants for Gmail API fields THREAD_LIST_FIELDS = "nextPageToken, threads(id)" @@ -67,7 +69,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, message_data += f"{name}: {value}\n" message_body_text: str = get_message_body(payload) - return TextSection(link=link, text=message_body_text + message_data), metadata @@ -94,7 +95,6 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: from_emails[email] = display_name if not from_emails.get(email) else None else: other_emails[email] = display_name if not other_emails.get(email) else None - if not semantic_identifier: semantic_identifier = message_metadata.get("subject", "") @@ -103,7 +103,7 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: updated_at_datetime = None if updated_at: - updated_at_datetime = time_str_to_utc(updated_at) + updated_at_datetime = gmail_time_str_to_utc(updated_at) thread_id = full_thread.get("id") if not thread_id: @@ -115,10 +115,19 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = "(no subject)" + combined_sections = "\n\n".join( + sec.text for sec in sections if hasattr(sec, "text") + ) + blob = combined_sections.encode("utf-8") + size_bytes = len(blob) + extension = '.txt' + return Document( id=thread_id, semantic_identifier=semantic_identifier, - sections=sections, + blob=blob, + size_bytes=size_bytes, + extension=extension, source=DocumentSource.GMAIL, primary_owners=primary_owners, secondary_owners=secondary_owners, @@ -214,15 +223,13 @@ class 
def _run_gmail_smoke_test() -> None:
    """Manually exercise the Gmail connector against a real mailbox.

    Intended for local debugging only. The mailbox address is read from the
    ``GMAIL_TEST_EMAIL`` environment variable and OAuth credentials are loaded
    through ``get_credentials_from_env``. Threads from the last three days are
    fetched and printed to stdout.

    Raises:
        SystemExit: If ``GMAIL_TEST_EMAIL`` is not set. No mailbox address is
            baked into the source as a fallback.
    """
    import os
    import time

    from common.data_source.google_util.util import get_credentials_from_env

    logging.basicConfig(level=logging.INFO)

    email = os.environ.get("GMAIL_TEST_EMAIL")
    if not email:
        raise SystemExit("Set GMAIL_TEST_EMAIL to the mailbox to test against.")

    try:
        creds = get_credentials_from_env(email, oauth=True)
        # Deliberately not echoing `creds`: it carries OAuth tokens.
        print("Credentials loaded successfully")

        connector = GmailConnector(batch_size=2)
        print("GmailConnector initialized")
        connector.load_credentials(creds)
        print("Credentials loaded into connector")

        print("Gmail is ready to use")

        window_end = int(time.time())
        window_start = window_end - 3 * 24 * 60 * 60
        # Each yielded item is itself iterable (presumably a batch of
        # documents -- confirm against _fetch_threads).
        for batch in connector._fetch_threads(window_start, window_end):
            for doc in batch:
                print(doc)
            print("\n\n")
    except Exception:
        # Best-effort debug script: log the traceback and fall through, as
        # the original did, but without blaming credential loading for
        # failures that happen later.
        logging.exception("Gmail connector smoke test failed")


if __name__ == "__main__":
    _run_gmail_smoke_test()
DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS -from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict @@ -1138,39 +1137,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) -def get_credentials_from_env(email: str, oauth: bool = False) -> dict: - try: - if oauth: - raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] - else: - raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] - except KeyError: - raise ValueError("Missing Google Drive credentials in environment variables") - - try: - credential_dict = json.loads(raw_credential_string) - except json.JSONDecodeError: - raise ValueError("Invalid JSON in Google Drive credentials") - - if oauth: - credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE) - - refried_credential_string = json.dumps(credential_dict) - - DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" - DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" - DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" - - cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - - return { - cred_key: refried_credential_string, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, - DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded", - } - - class CheckpointOutputWrapper: """ Wraps a CheckpointOutput generator to give things back in a more digestible format. 
@@ -1236,7 +1202,7 @@ def yield_all_docs_from_checkpoint_connector( if __name__ == "__main__": import time - + from common.data_source.google_util.util import get_credentials_from_env logging.basicConfig(level=logging.DEBUG) try: @@ -1245,7 +1211,7 @@ if __name__ == "__main__": creds = get_credentials_from_env(email, oauth=True) print("Credentials loaded successfully") print(f"{creds=}") - + sys.exit(0) connector = GoogleDriveConnector( include_shared_drives=False, shared_drive_urls=None, diff --git a/common/data_source/google_util/constant.py b/common/data_source/google_util/constant.py index 8ab75fa14..858ee31c8 100644 --- a/common/data_source/google_util/constant.py +++ b/common/data_source/google_util/constant.py @@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste SCOPE_INSTRUCTIONS = "" -GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """ +GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """
-