diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4be4b06be..a5a44a29c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: name: ragflow_tests # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://github.com/orgs/community/discussions/26261 - if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }} + if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable != false) }} runs-on: [ "self-hosted", "ragflow-test" ] steps: # https://github.com/hmarr/debug-action diff --git a/agent/canvas.py b/agent/canvas.py index 9e95a5611..5344d70c3 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import base64 import json import logging import re @@ -25,6 +24,7 @@ from typing import Any, Union, Tuple from agent.component import component_class from agent.component.base import ComponentBase +from api.db.services.file_service import FileService from api.db.services.task_service import has_canceled from common.misc_utils import get_uuid, hash_str2int from common.exceptions import TaskCanceledException @@ -372,7 +372,7 @@ class Canvas(Graph): for k in kwargs.keys(): if k in ["query", "user_id", "files"] and kwargs[k]: if k == "files": - self.globals[f"sys.{k}"] = self.get_files(kwargs[k]) + self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k]) else: self.globals[f"sys.{k}"] = kwargs[k] if not self.globals["sys.conversation_turns"] : @@ -621,22 +621,6 @@ class Canvas(Graph): def get_component_input_elements(self, cpnnm): return self.components[cpnnm]["obj"].get_input_elements() - def get_files(self, files: Union[None, list[dict]]) -> list[str]: - from api.db.services.file_service import FileService - if not files: - return [] - def image_to_base64(file): - return "data:{};base64,{}".format(file["mime_type"], - base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) - threads = [] - for file in files: - if file["mime_type"].find("image") >=0: - threads.append(exe.submit(image_to_base64, file)) - continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) - return [th.result() for th in threads] - def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") agent_name = self.get_component_name(agent_ids[0]) diff --git a/agent/component/begin.py b/agent/component/begin.py index b5985bb7a..1314aff74 100644 --- a/agent/component/begin.py +++ b/agent/component/begin.py @@ -14,6 +14,7 @@ # limitations under the License. 
# from agent.component.fillup import UserFillUpParam, UserFillUp +from api.db.services.file_service import FileService class BeginParam(UserFillUpParam): @@ -48,7 +49,7 @@ class Begin(UserFillUp): if v.get("optional") and v.get("value", None) is None: v = None else: - v = self._canvas.get_files([v["value"]]) + v = FileService.get_files([v["value"]]) else: v = v.get("value") self.set_output(k, v) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 214c8b947..5d05a7327 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -40,6 +40,11 @@ from api.db.services.canvas_service import ( CanvasTemplateService, UserCanvasService, ) +from functools import partial +from quart import request, Response, make_response +from agent.component import LLM +from api.db import CanvasCategory +from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService from api.db.services.document_service import DocumentService from api.db.services.file_service import FileService from api.db.services.pipeline_operation_log_service import ( @@ -63,6 +68,15 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf from common import settings from common.constants import RetCode, StatusEnum from common.misc_utils import get_uuid +from common.constants import RetCode +from common.misc_utils import get_uuid +from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \ + request_json +from agent.canvas import Canvas +from peewee import MySQLDatabase, PostgresqlDatabase +from api.db.db_models import APIToken, Task +import time + from rag.flow.pipeline import Pipeline from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN @@ -319,71 +333,10 @@ async def upload(canvas_id): return get_data_error_result(message="canvas not found.") user_id = cvs["user_id"] - def structured(filename, filetype, blob, content_type): - nonlocal user_id - if filetype == 
FileType.PDF.value: - blob = read_potential_broken_pdf(blob) - - location = get_uuid() - FileService.put_blob(user_id, location, blob) - - return { - "id": location, - "name": filename, - "size": sys.getsizeof(blob), - "extension": filename.split(".")[-1].lower(), - "mime_type": content_type, - "created_by": user_id, - "created_at": time.time(), - "preview_url": None - } - - if request.args.get("url"): - from crawl4ai import ( - AsyncWebCrawler, - BrowserConfig, - CrawlerRunConfig, - DefaultMarkdownGenerator, - PruningContentFilter, - CrawlResult - ) - try: - url = request.args.get("url") - filename = re.sub(r"\?.*", "", url.split("/")[-1]) - async def adownload(): - browser_config = BrowserConfig( - headless=True, - verbose=False, - ) - async with AsyncWebCrawler(config=browser_config) as crawler: - crawler_config = CrawlerRunConfig( - markdown_generator=DefaultMarkdownGenerator( - content_filter=PruningContentFilter() - ), - pdf=True, - screenshot=False - ) - result: CrawlResult = await crawler.arun( - url=url, - config=crawler_config - ) - return result - page = trio.run(adownload()) - if page.pdf: - if filename.split(".")[-1].lower() != "pdf": - filename += ".pdf" - return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"])) - - return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)) - - except Exception as e: - return server_error_response(e) - files = await request.files - file = files['file'] + file = files['file'] if files and files.get("file") else None try: - DocumentService.check_doc_health(user_id, file.filename) - return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type)) + return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url"))) except Exception as e: return server_error_response(e) diff --git a/api/apps/connector_app.py 
b/api/apps/connector_app.py index 4932c3fcc..34da2293b 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -28,8 +28,8 @@ from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from common.constants import RetCode, TaskStatus -from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource -from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES +from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource +from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.misc_utils import get_uuid from rag.utils.redis_conn import REDIS_CONN from api.apps import login_required, current_user @@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result" WEB_FLOW_TTL_SECS = 15 * 60 -def _web_state_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}" +def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth state. + + The default prefix keeps backward compatibility for Google Drive. + When source_type == "gmail", a different prefix is used so that + Drive/Gmail flows don't clash in Redis. + """ + if source_type == "gmail": + prefix = "gmail_web_flow_state" + else: + prefix = GOOGLE_WEB_FLOW_STATE_PREFIX + return f"{prefix}:{flow_id}" -def _web_result_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}" +def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth result. + + Mirrors _web_state_cache_key logic for result storage. 
+ """ + if source_type == "gmail": + prefix = "gmail_web_flow_result" + else: + prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX + return f"{prefix}:{flow_id}" def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]: @@ -146,19 +164,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} -async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): +async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if success else "" escaped_message = escape(message) payload_json = json.dumps( { - "type": "ragflow-google-drive-oauth", + # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed + "type": f"ragflow-google-{source}-oauth", "status": status, "flowId": flow_id or "", "message": message, } ) - html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format( + # TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type + html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format( + title=f"Google {source.capitalize()} Authorization", heading="Authorization complete" if success else "Authorization failed", message=escaped_message, payload_json=payload_json, @@ -169,20 +190,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): return response -@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") -async def start_google_drive_web_oauth(): - if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: +async def start_google_web_oauth(): + source = request.args.get("type", "google-drive") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") + + if source == "gmail": + redirect_uri = 
GMAIL_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GMAIL] + else: + redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI if source == "google-drive" else GMAIL_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE if source == "google-drive" else DocumentSource.GMAIL] + + if not redirect_uri: return get_json_result( code=RetCode.SERVER_ERROR, - message="Google Drive OAuth redirect URI is not configured on the server.", + message="Google OAuth redirect URI is not configured on the server.", ) req = await request.json or {} raw_credentials = req.get("credentials", "") + try: credentials = _load_credentials(raw_credentials) except ValueError as exc: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc)) @@ -199,8 +233,8 @@ async def start_google_drive_web_oauth(): flow_id = str(uuid.uuid4()) try: - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + flow = Flow.from_client_config(client_config, scopes=scopes) + flow.redirect_uri = redirect_uri authorization_url, _ = flow.authorization_url( access_type="offline", include_granted_scopes="true", @@ -219,7 +253,7 @@ async def start_google_drive_web_oauth(): "client_config": client_config, "created_at": int(time.time()), } - REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS) return get_json_result( data={ @@ -230,60 +264,122 @@ async def start_google_drive_web_oauth(): -@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 -async def google_drive_web_oauth_callback(): +@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def google_gmail_web_oauth_callback(): state_id = request.args.get("state") error = request.args.get("error") + source = "gmail" + if source != 
'gmail': + return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source) + error_description = request.args.get("error_description") or error if not state_id: - return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.") + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) - state_cache = REDIS_CONN.get(_web_state_cache_key(state_id)) + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) if not state_cache: - return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.") + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") if not client_config: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. 
Please retry.", source) if error: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) code = request.args.get("code") if not code: - return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.") + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) + flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL]) + flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI flow.fetch_token(code=code) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. 
Please retry.", source) creds_json = flow.credentials.to_json() result_payload = { "user_id": state_obj.get("user_id"), "credentials": creds_json, } - REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS) - REDIS_CONN.delete(_web_state_cache_key(state_id)) + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) - return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.") + print("\n\n", _web_result_cache_key(state_id, source), "\n\n") + + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) -@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821 +@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def google_drive_web_oauth_callback(): + state_id = request.args.get("state") + error = request.args.get("error") + source = "google-drive" + if source not in ("google-drive", "gmail"): + return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source) + + error_description = request.args.get("error_description") or error + + if not state_id: + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) + + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) + if not state_cache: + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) + + state_obj = json.loads(state_cache) + client_config = state_obj.get("client_config") + if not client_config: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. 
Please retry.", source) + + if error: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) + + code = request.args.get("code") + if not code: + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) + + try: + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) + flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) + flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + flow.fetch_token(code=code) + except Exception as exc: # pragma: no cover - defensive + logging.exception("Failed to exchange Google OAuth code: %s", exc) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source) + + creds_json = flow.credentials.to_json() + result_payload = { + "user_id": state_obj.get("user_id"), + "credentials": creds_json, + } + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) + +@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") -async def poll_google_drive_web_result(): +async def poll_google_web_result(): req = await request.json or {} + source = request.args.get("type") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") flow_id = req.get("flow_id") - cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id)) + cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source)) if not cache_raw: return get_json_result(code=RetCode.RUNNING, 
message="Authorization is still pending.") @@ -291,5 +387,5 @@ async def poll_google_drive_web_result(): if result.get("user_id") != current_user.id: return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") - REDIS_CONN.delete(_web_result_cache_key(flow_id)) + REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 7ec8c1587..bd2262919 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -607,7 +607,7 @@ async def get_image(image_id): @login_required @validate_request("conversation_id") async def upload_and_parse(): - files = await request.file + files = await request.files if "file" not in files: return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) @@ -705,3 +705,12 @@ async def set_meta(): return get_json_result(data=True) except Exception as e: return server_error_response(e) + +@manager.route("/upload_info", methods=["POST"]) # noqa: F821 +async def upload_info(): + files = await request.files + file = files['file'] if files and files.get("file") else None + try: + return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url"))) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index f53494541..aebf925cc 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1446,6 +1446,9 @@ async def retrieval_test(tenant_id): metadata_condition = req.get("metadata_condition", {}) or {} metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + # If metadata_condition has conditions but no docs match, return empty result + if not doc_ids and metadata_condition.get("conditions"): + return 
get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) if metadata_condition and not doc_ids: doc_ids = ["-999"] similarity_threshold = float(req.get("similarity_threshold", 0.2)) diff --git a/api/apps/user_app.py b/api/apps/user_app.py index a54831dba..cca8abb4c 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -175,8 +175,8 @@ async def login(): response_data = user.to_json() user.access_token = get_uuid() login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) + user.update_time = current_timestamp() + user.update_date = datetime_format(datetime.now()) user.save() msg = "Welcome back!" @@ -1779,8 +1779,8 @@ async def forget(): # Auto login (reuse login flow) user.access_token = get_uuid() login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) + user.update_time = current_timestamp() + user.update_date = datetime_format(datetime.now()) user.save() msg = "Password reset successful. Logged in." 
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 558ba1b0f..ae79b45a6 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -25,6 +25,7 @@ import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher +from api.db.services.file_service import FileService from common.constants import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService @@ -380,8 +381,11 @@ def chat(dialog, messages, stream=True, **kwargs): retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] + attachments_= "" if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] + if "files" in messages[-1]: + attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"])) prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) @@ -451,7 +455,7 @@ def chat(dialog, messages, stream=True, **kwargs): ), ) - for think in reasoner.thinking(kbinfos, " ".join(questions)): + for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): if isinstance(think, str): thought = think knowledges = [t for t in think.split("\n") if t] @@ -503,7 +507,7 @@ def chat(dialog, messages, stream=True, **kwargs): kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() diff --git 
a/api/db/services/file_service.py b/api/db/services/file_service.py index 1fbecdafe..11ef5b454 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -13,10 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import base64 import logging import re +import sys +import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from typing import Union from peewee import fn @@ -520,7 +525,7 @@ class FileService(CommonService): if img_base64 and file_type == FileType.VISUAL.value: return GptV4.image2base64(blob) cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs) - return "\n".join([ck["content_with_weight"] for ck in cks]) + return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks]) @staticmethod def get_parser(doc_type, filename, default): @@ -588,3 +593,80 @@ class FileService(CommonService): errors += str(e) return errors + + @staticmethod + def upload_info(user_id, file, url: str|None=None): + def structured(filename, filetype, blob, content_type): + nonlocal user_id + if filetype == FileType.PDF.value: + blob = read_potential_broken_pdf(blob) + + location = get_uuid() + FileService.put_blob(user_id, location, blob) + + return { + "id": location, + "name": filename, + "size": sys.getsizeof(blob), + "extension": filename.split(".")[-1].lower(), + "mime_type": content_type, + "created_by": user_id, + "created_at": time.time(), + "preview_url": None + } + + if url: + from crawl4ai import ( + AsyncWebCrawler, + BrowserConfig, + CrawlerRunConfig, + DefaultMarkdownGenerator, + PruningContentFilter, + CrawlResult + ) + filename = re.sub(r"\?.*", "", url.split("/")[-1]) + async def adownload(): + browser_config = BrowserConfig( + headless=True, + verbose=False, + ) + async with 
AsyncWebCrawler(config=browser_config) as crawler: + crawler_config = CrawlerRunConfig( + markdown_generator=DefaultMarkdownGenerator( + content_filter=PruningContentFilter() + ), + pdf=True, + screenshot=False + ) + result: CrawlResult = await crawler.arun( + url=url, + config=crawler_config + ) + return result + # NOTE(review): asyncio.run() raises RuntimeError inside Quart's running event loop; run the crawl in a worker thread with its own loop. + with ThreadPoolExecutor(max_workers=1) as pool: + page = pool.submit(asyncio.run, adownload()).result() + if page.pdf: + if filename.split(".")[-1].lower() != "pdf": + filename += ".pdf" + return structured(filename, "pdf", page.pdf, page.response_headers["content-type"]) + + return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"]) + + DocumentService.check_doc_health(user_id, file.filename) + return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) + + @staticmethod + def get_files(files: Union[None, list[dict]]) -> list[str]: + if not files: + return [] + def image_to_base64(file): + return "data:{};base64,{}".format(file["mime_type"], + base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + exe = ThreadPoolExecutor(max_workers=5) + threads = [] + for file in files: + if file["mime_type"].find("image") >=0: + threads.append(exe.submit(image_to_base64, file)) + continue + threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + return [th.result() for th in threads] + diff --git a/common/data_source/config.py b/common/data_source/config.py index a643e3d41..a3d86720c 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -217,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") +GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", 
"http://localhost:9380/v1/connector/gmail/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index 67ebfae98..1757f4ffe 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -1,6 +1,6 @@ import logging +import os from typing import Any - from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.errors import HttpError @@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS from common.data_source.google_util.resource import get_admin_service, get_gmail_service -from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval +from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection -from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc +from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc # Constants for Gmail API fields THREAD_LIST_FIELDS = "nextPageToken, threads(id)" @@ 
-67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, message_data += f"{name}: {value}\n" message_body_text: str = get_message_body(payload) - return TextSection(link=link, text=message_body_text + message_data), metadata @@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = message_metadata.get("subject", "") + semantic_identifier = clean_string(semantic_identifier) + semantic_identifier = sanitize_filename(semantic_identifier) if message_metadata.get("updated_at"): updated_at = message_metadata.get("updated_at") - + updated_at_datetime = None if updated_at: - updated_at_datetime = time_str_to_utc(updated_at) + updated_at_datetime = gmail_time_str_to_utc(updated_at) thread_id = full_thread.get("id") if not thread_id: @@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = "(no subject)" + combined_sections = "\n\n".join( + sec.text for sec in sections if hasattr(sec, "text") + ) + blob = combined_sections + size_bytes = len(blob) + extension = '.txt' + return Document( id=thread_id, semantic_identifier=semantic_identifier, - sections=sections, + blob=blob, + size_bytes=size_bytes, + extension=extension, source=DocumentSource.GMAIL, primary_owners=primary_owners, secondary_owners=secondary_owners, doc_updated_at=updated_at_datetime, - metadata={}, + metadata=message_metadata, external_access=ExternalAccess( external_user_emails={email_used_to_fetch_thread}, external_user_group_ids=set(), @@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): q=query, continue_on_404_or_403=True, ): - full_threads = _execute_single_retrieval( + full_thread = _execute_single_retrieval( retrieval_function=gmail_service.users().threads().get, - list_key=None, userId=user_email, fields=THREAD_FIELDS, 
id=thread["id"], continue_on_404_or_403=True, ) - full_thread = list(full_threads)[0] doc = thread_to_document(full_thread, user_email) if doc is None: continue @@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if __name__ == "__main__": - pass + import time + import os + from common.data_source.google_util.util import get_credentials_from_env + logging.basicConfig(level=logging.INFO) + try: + email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com") + creds = get_credentials_from_env(email, oauth=True, source="gmail") + print("Credentials loaded successfully") + print(f"{creds=}") + + connector = GmailConnector(batch_size=2) + print("GmailConnector initialized") + connector.load_credentials(creds) + print("Credentials loaded into connector") + + print("Gmail is ready to use") + + for file in connector._fetch_threads( + int(time.time()) - 1 * 24 * 60 * 60, + int(time.time()), + ): + print("new batch","-"*80) + for f in file: + print(f) + print("\n\n") + except Exception as e: + logging.exception(f"Error loading credentials: {e}") \ No newline at end of file diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py index fb88d0ed0..48628f490 100644 --- a/common/data_source/google_drive/connector.py +++ b/common/data_source/google_drive/connector.py @@ -1,7 +1,6 @@ """Google Drive connector""" import copy -import json import logging import os import sys @@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import ( from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS -from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict from 
common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict @@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) -def get_credentials_from_env(email: str, oauth: bool = False) -> dict: - try: - if oauth: - raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] - else: - raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] - except KeyError: - raise ValueError("Missing Google Drive credentials in environment variables") - - try: - credential_dict = json.loads(raw_credential_string) - except json.JSONDecodeError: - raise ValueError("Invalid JSON in Google Drive credentials") - - if oauth: - credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE) - - refried_credential_string = json.dumps(credential_dict) - - DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" - DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" - DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" - - cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - - return { - cred_key: refried_credential_string, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, - DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded", - } - - class CheckpointOutputWrapper: """ Wraps a CheckpointOutput generator to give things back in a more digestible format. 
@@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector( if __name__ == "__main__": import time - + from common.data_source.google_util.util import get_credentials_from_env logging.basicConfig(level=logging.DEBUG) try: @@ -1245,7 +1210,7 @@ if __name__ == "__main__": creds = get_credentials_from_env(email, oauth=True) print("Credentials loaded successfully") print(f"{creds=}") - + sys.exit(0) connector = GoogleDriveConnector( include_shared_drives=False, shared_drive_urls=None, diff --git a/common/data_source/google_util/constant.py b/common/data_source/google_util/constant.py index 8ab75fa14..858ee31c8 100644 --- a/common/data_source/google_util/constant.py +++ b/common/data_source/google_util/constant.py @@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste SCOPE_INSTRUCTIONS = "" -GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """ +GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """ - Google Drive Authorization + {title}