diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 39c526104..4be4b06be 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: name: ragflow_tests # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://github.com/orgs/community/discussions/26261 - if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }} + if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }} runs-on: [ "self-hosted", "ragflow-test" ] steps: # https://github.com/hmarr/debug-action diff --git a/agent/canvas.py b/agent/canvas.py index 3e15814aa..9e95a5611 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -281,6 +281,7 @@ class Canvas(Graph): "sys.conversation_turns": 0, "sys.files": [] } + self.variables = {} super().__init__(dsl, tenant_id, task_id) def load(self): @@ -295,6 +296,10 @@ class Canvas(Graph): "sys.conversation_turns": 0, "sys.files": [] } + if "variables" in self.dsl: + self.variables = self.dsl["variables"] + else: + self.variables = {} self.retrieval = self.dsl["retrieval"] self.memory = self.dsl.get("memory", []) @@ -311,8 +316,9 @@ class Canvas(Graph): self.history = [] self.retrieval = [] self.memory = [] + print(self.variables) for k in self.globals.keys(): - if k.startswith("sys.") or k.startswith("env."): + if k.startswith("sys."): if isinstance(self.globals[k], str): self.globals[k] = "" elif isinstance(self.globals[k], int): @@ -325,6 +331,29 @@ class Canvas(Graph): self.globals[k] = {} else: self.globals[k] = None + if k.startswith("env."): + key = k[4:] + if key in self.variables: + variable = self.variables[key] + if variable["value"]: + self.globals[k] = variable["value"] + else: + if variable["type"] == "string": + self.globals[k] = "" + elif variable["type"] == "number": + self.globals[k] = 0 + elif variable["type"] == "boolean": + self.globals[k] = False + elif variable["type"] == "object": + self.globals[k] = {} + elif variable["type"].startswith("array"): + self.globals[k] = [] + else: + self.globals[k] = "" + else: + self.globals[k] = "" + print(self.globals) + async def run(self, **kwargs): st = time.perf_counter() @@ -473,7 +502,7 @@ class Canvas(Graph): else: self.error = cpn_obj.error() - if cpn_obj.component_name.lower() != "iteration": + if cpn_obj.component_name.lower() not in ("iteration","loop"): if isinstance(cpn_obj.output("content"), partial): if self.error: cpn_obj.set_output("content", None) @@ -498,14 +527,16 @@ class Canvas(Graph): for cpn_id in cpn_ids: _append_path(cpn_id) - if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end(): + if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end(): iter = cpn_obj.get_parent() yield _node_finished(iter) _extend_path(self.get_component(cpn["parent_id"])["downstream"]) elif cpn_obj.component_name.lower() in ["categorize", "switch"]: _extend_path(cpn_obj.output("_next")) - elif cpn_obj.component_name.lower() == "iteration": + elif cpn_obj.component_name.lower() in ("iteration", "loop"): _append_path(cpn_obj.get_start()) + elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop": + _extend_path(self.get_component(cpn["parent_id"])["downstream"]) elif not cpn["downstream"] and cpn_obj.get_parent(): _append_path(cpn_obj.get_parent().get_start()) else: diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index b3470d6aa..979b636af 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json import logging import os import re @@ -29,7 +30,7 @@ from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.mcp_server_service import MCPServerService from common.connection_utils import timeout from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \ - citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in + citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool from agent.component.llm import LLMParam, LLM @@ -137,6 +138,29 @@ class Agent(LLM, ToolBase): res.update(cpn.get_input_form()) return res + def _get_output_schema(self): + try: + cand = self._param.outputs.get("structured") + except Exception: + return None + + if isinstance(cand, dict): + if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0: + return cand + for k in ("schema", "structured"): + if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0: + return cand[k] + + return None + + def _force_format_to_schema(self, text: str, schema_prompt: str) -> str: + fmt_msgs = [ + {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."}, + {"role": "user", "content": text}, + ] + _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97)) + return self._generate(fmt_msgs) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))) def _invoke(self, **kwargs): if self.check_if_canceled("Agent processing"): @@ -160,17 +184,22 @@ class Agent(LLM, ToolBase): return LLM._invoke(self, **kwargs) prompt, msg, user_defined_prompt = self._prepare_prompt_variables() + output_schema = self._get_output_schema() + schema_prompt = "" + if output_schema: + schema = json.dumps(output_schema, ensure_ascii=False, indent=2) + schema_prompt = structured_output_prompt(schema) downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] ex = self.exception_handler() - if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]): + if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema: self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt)) return _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) use_tools = [] ans = "" - for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt): + for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt): if self.check_if_canceled("Agent processing"): return ans += delta_ans @@ -183,6 +212,28 @@ class Agent(LLM, ToolBase): self.set_output("_ERROR", ans) return + if output_schema: + error = "" + for _ in range(self._param.max_retries + 1): + try: + def clean_formated_answer(ans: str) -> str: + ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) + ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) + return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) + obj = json_repair.loads(clean_formated_answer(ans)) + self.set_output("structured", obj) + if use_tools: + self.set_output("use_tools", use_tools) + return obj + except Exception: + error = "The answer cannot be parsed as JSON" + ans = self._force_format_to_schema(ans, schema_prompt) + if ans.find("**ERROR**") >= 0: + continue + + self.set_output("_ERROR", error) + return + self.set_output("content", ans) if use_tools: self.set_output("use_tools", use_tools) @@ -219,7 +270,7 @@ class Agent(LLM, ToolBase): ]): yield delta_ans - def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}): + def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): token_count = 0 tool_metas = self.tool_meta hist = deepcopy(history) @@ -256,9 +307,13 @@ class Agent(LLM, ToolBase): def complete(): nonlocal hist need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 + if schema_prompt: + need2cite = False cited = False - if hist[0]["role"] == "system" and need2cite: - if len(hist) < 7: + if hist and hist[0]["role"] == "system": + if schema_prompt: + hist[0]["content"] += "\n" + schema_prompt + if need2cite and len(hist) < 7: hist[0]["content"] += citation_prompt() cited = True yield "", token_count @@ -369,7 +424,7 @@ Respond immediately with your final comprehensive answer. """ for k in self._param.outputs.keys(): self._param.outputs[k]["value"] = None - + for k, cpn in self.tools.items(): if hasattr(cpn, "reset") and callable(cpn.reset): cpn.reset() diff --git a/agent/component/exit_loop.py b/agent/component/exit_loop.py new file mode 100644 index 000000000..9dc044912 --- /dev/null +++ b/agent/component/exit_loop.py @@ -0,0 +1,32 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase + + +class ExitLoopParam(ComponentParamBase, ABC): + def check(self): + return True + + +class ExitLoop(ComponentBase, ABC): + component_name = "ExitLoop" + + def _invoke(self, **kwargs): + pass + + def thoughts(self) -> str: + return "" \ No newline at end of file diff --git a/agent/component/llm.py b/agent/component/llm.py index 807bbc288..0f5317676 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -222,7 +222,7 @@ class LLM(ComponentBase): output_structure = self._param.outputs['structured'] except Exception: pass - if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"): + if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0: schema=json.dumps(output_structure, ensure_ascii=False, indent=2) prompt += structured_output_prompt(schema) for _ in range(self._param.max_retries+1): diff --git a/agent/component/loop.py b/agent/component/loop.py new file mode 100644 index 000000000..484dfae82 --- /dev/null +++ b/agent/component/loop.py @@ -0,0 +1,80 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase + + +class LoopParam(ComponentParamBase): + """ + Define the Loop component parameters. + """ + + def __init__(self): + super().__init__() + self.loop_variables = [] + self.loop_termination_condition=[] + self.maximum_loop_count = 0 + + def get_input_form(self) -> dict[str, dict]: + return { + "items": { + "type": "json", + "name": "Items" + } + } + + def check(self): + return True + + +class Loop(ComponentBase, ABC): + component_name = "Loop" + + def get_start(self): + for cid in self._canvas.components.keys(): + if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem": + continue + if self._canvas.get_component(cid)["parent_id"] == self._id: + return cid + + def _invoke(self, **kwargs): + if self.check_if_canceled("Loop processing"): + return + + for item in self._param.loop_variables: + if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]): + assert "Loop Variable is not complete." + if item["input_mode"]=="variable": + self.set_output(item["variable"],self._canvas.get_variable_value(item["value"])) + elif item["input_mode"]=="constant": + self.set_output(item["variable"],item["value"]) + else: + if item["type"] == "number": + self.set_output(item["variable"], 0) + elif item["type"] == "string": + self.set_output(item["variable"], "") + elif item["type"] == "boolean": + self.set_output(item["variable"], False) + elif item["type"].startswith("object"): + self.set_output(item["variable"], {}) + elif item["type"].startswith("array"): + self.set_output(item["variable"], []) + else: + self.set_output(item["variable"], "") + + + def thoughts(self) -> str: + return "Loop from canvas." \ No newline at end of file diff --git a/agent/component/loopitem.py b/agent/component/loopitem.py new file mode 100644 index 000000000..71b91c810 --- /dev/null +++ b/agent/component/loopitem.py @@ -0,0 +1,163 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase + + +class LoopItemParam(ComponentParamBase): + """ + Define the LoopItem component parameters. + """ + def check(self): + return True + +class LoopItem(ComponentBase, ABC): + component_name = "LoopItem" + + def __init__(self, canvas, id, param: ComponentParamBase): + super().__init__(canvas, id, param) + self._idx = 0 + + + def _invoke(self, **kwargs): + if self.check_if_canceled("LoopItem processing"): + return + parent = self.get_parent() + maximum_loop_count = parent._param.maximum_loop_count + if self._idx >= maximum_loop_count: + self._idx = -1 + return + if self._idx > 0: + if self.check_if_canceled("LoopItem processing"): + return + self._idx += 1 + + def evaluate_condition(self,var, operator, value): + if isinstance(var, str): + if operator == "contains": + return value in var + elif operator == "not contains": + return value not in var + elif operator == "start with": + return var.startswith(value) + elif operator == "end with": + return var.endswith(value) + elif operator == "is": + return var == value + elif operator == "is not": + return var != value + elif operator == "empty": + return var == "" + elif operator == "not empty": + return var != "" + + elif isinstance(var, (int, float)): + if operator == "=": + return var == value + elif operator == "≠": + return var != value + elif operator == ">": + return var > value + elif operator == "<": + return var < value + elif operator == "≥": + return var >= value + elif operator == "≤": + return var <= value + elif operator == "empty": + return var is None + elif operator == "not empty": + return var is not None + + elif isinstance(var, bool): + if operator == "is": + return var is value + elif operator == "is not": + return var is not value + elif operator == "empty": + return var is None + elif operator == "not empty": + return var is not None + + elif isinstance(var, dict): + if operator == "empty": + return len(var) == 0 + elif operator == "not empty": + return len(var) > 0 + + elif isinstance(var, list): + if operator == "contains": + return value in var + elif operator == "not contains": + return value not in var + + elif operator == "is": + return var == value + elif operator == "is not": + return var != value + + elif operator == "empty": + return len(var) == 0 + elif operator == "not empty": + return len(var) > 0 + + raise Exception(f"Invalid operator: {operator}") + + def end(self): + if self._idx == -1: + return True + parent = self.get_parent() + logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and" + conditions = [] + for item in parent._param.loop_termination_condition: + if not item.get("variable") or not item.get("operator"): + raise ValueError("Loop condition is incomplete.") + var = self._canvas.get_variable_value(item["variable"]) + operator = item["operator"] + input_mode = item.get("input_mode", "constant") + + if input_mode == "variable": + value = self._canvas.get_variable_value(item.get("value", "")) + elif input_mode == "constant": + value = item.get("value", "") + else: + raise ValueError("Invalid input mode.") + conditions.append(self.evaluate_condition(var, operator, value)) + should_end = ( + all(conditions) if logical_operator == "and" + else any(conditions) if logical_operator == "or" + else None + ) + if should_end is None: + raise ValueError("Invalid logical operator,should be 'and' or 'or'.") + + if should_end: + self._idx = -1 + return True + + return False + + def next(self): + if self._idx == -1: + self._idx = 0 + else: + self._idx += 1 + if self._idx >= len(self._items): + self._idx = -1 + return False + + def thoughts(self) -> str: + return "Next turn..." \ No newline at end of file diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 4932c3fcc..34da2293b 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -28,8 +28,8 @@ from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from common.constants import RetCode, TaskStatus -from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource -from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES +from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource +from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.misc_utils import get_uuid from rag.utils.redis_conn import REDIS_CONN from api.apps import login_required, current_user @@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result" WEB_FLOW_TTL_SECS = 15 * 60 -def _web_state_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}" +def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth state. + + The default prefix keeps backward compatibility for Google Drive. + When source_type == "gmail", a different prefix is used so that + Drive/Gmail flows don't clash in Redis. + """ + if source_type == "gmail": + prefix = "gmail_web_flow_state" + else: + prefix = GOOGLE_WEB_FLOW_STATE_PREFIX + return f"{prefix}:{flow_id}" -def _web_result_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}" +def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth result. + + Mirrors _web_state_cache_key logic for result storage. + """ + if source_type == "gmail": + prefix = "gmail_web_flow_result" + else: + prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX + return f"{prefix}:{flow_id}" def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]: @@ -146,19 +164,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} -async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): +async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if success else "" escaped_message = escape(message) payload_json = json.dumps( { - "type": "ragflow-google-drive-oauth", + # TODO(google-oauth): include connector type (drive/gmail) in payload type if needed + "type": f"ragflow-google-{source}-oauth", "status": status, "flowId": flow_id or "", "message": message, } ) - html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format( + # TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type + html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format( + title=f"Google {source.capitalize()} Authorization", heading="Authorization complete" if success else "Authorization failed", message=escaped_message, payload_json=payload_json, @@ -169,20 +190,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str): return response -@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") -async def start_google_drive_web_oauth(): - if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: +async def start_google_web_oauth(): + source = request.args.get("type", "google-drive") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") + + if source == "gmail": + redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GMAIL] + else: + redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI if source == "google-drive" else GMAIL_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE if source == "google-drive" else DocumentSource.GMAIL] + + if not redirect_uri: return get_json_result( code=RetCode.SERVER_ERROR, - message="Google Drive OAuth redirect URI is not configured on the server.", + message="Google OAuth redirect URI is not configured on the server.", ) req = await request.json or {} raw_credentials = req.get("credentials", "") + try: credentials = _load_credentials(raw_credentials) + print(credentials) except ValueError as exc: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc)) @@ -199,8 +233,8 @@ async def start_google_drive_web_oauth(): flow_id = str(uuid.uuid4()) try: - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + flow = Flow.from_client_config(client_config, scopes=scopes) + flow.redirect_uri = redirect_uri authorization_url, _ = flow.authorization_url( access_type="offline", include_granted_scopes="true", @@ -219,7 +253,7 @@ async def start_google_drive_web_oauth(): "client_config": client_config, "created_at": int(time.time()), } - REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS) return get_json_result( data={ @@ -230,60 +264,122 @@ async def start_google_drive_web_oauth(): ) -@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 -async def google_drive_web_oauth_callback(): +@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def google_gmail_web_oauth_callback(): state_id = request.args.get("state") error = request.args.get("error") + source = "gmail" + if source != 'gmail': + return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source) + error_description = request.args.get("error_description") or error if not state_id: - return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.") + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) - state_cache = REDIS_CONN.get(_web_state_cache_key(state_id)) + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) if not state_cache: - return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.") + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") if not client_config: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) if error: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) code = request.args.get("code") if not code: - return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.") + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) + flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL]) + flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI flow.fetch_token(code=code) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source) creds_json = flow.credentials.to_json() result_payload = { "user_id": state_obj.get("user_id"), "credentials": creds_json, } - REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS) - REDIS_CONN.delete(_web_state_cache_key(state_id)) + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) - return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.") + print("\n\n", _web_result_cache_key(state_id, source), "\n\n") + + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) -@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821 +@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def google_drive_web_oauth_callback(): + state_id = request.args.get("state") + error = request.args.get("error") + source = "google-drive" + if source not in ("google-drive", "gmail"): + return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source) + + error_description = request.args.get("error_description") or error + + if not state_id: + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) + + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) + if not state_cache: + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) + + state_obj = json.loads(state_cache) + client_config = state_obj.get("client_config") + if not client_config: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) + + if error: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) + + code = request.args.get("code") + if not code: + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) + + try: + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) + flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) + flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + flow.fetch_token(code=code) + except Exception as exc: # pragma: no cover - defensive + logging.exception("Failed to exchange Google OAuth code: %s", exc) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source) + + creds_json = flow.credentials.to_json() + result_payload = { + "user_id": state_obj.get("user_id"), + "credentials": creds_json, + } + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) + +@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") -async def poll_google_drive_web_result(): +async def poll_google_web_result(): req = await request.json or {} + source = request.args.get("type") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") flow_id = req.get("flow_id") - cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id)) + cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source)) if not cache_raw: return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.") @@ -291,5 +387,5 @@ async def poll_google_drive_web_result(): if result.get("user_id") != current_user.id: return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") - REDIS_CONN.delete(_web_result_cache_key(flow_id)) + REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 1682a0285..6377ea7c8 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -31,7 +31,7 @@ from api.db.services.file_service import FileService from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type from common import settings - +from common.constants import RetCode @manager.route('/file/upload', methods=['POST']) # noqa: F821 @token_required @@ -86,19 +86,19 @@ async def upload(tenant_id): pf_id = root_folder["id"] if 'file' not in files: - return get_json_result(data=False, message='No file part!', code=400) + return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST) file_objs = files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': - return get_json_result(data=False, message='No selected file!', code=400) + return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST) file_res = [] try: e, pf_folder = FileService.get_by_id(pf_id) if not e: - return get_json_result(data=False, message="Can't find this folder!", code=404) + return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND) for file_obj in file_objs: # Handle file path @@ -114,13 +114,13 @@ async def upload(tenant_id): if file_len != len_id_list: e, file = FileService.get_by_id(file_id_list[len_id_list - 1]) if not e: - return get_json_result(data=False, message="Folder not found!", code=404) + return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND) last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list) else: e, file = FileService.get_by_id(file_id_list[len_id_list - 2]) if not e: - return get_json_result(data=False, message="Folder not found!", code=404) + return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND) last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list) @@ -202,7 +202,7 @@ async def create(tenant_id): try: if not FileService.is_parent_folder_exist(pf_id): - return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400) + return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST) if FileService.query(name=req["name"], parent_id=pf_id): return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409) @@ -306,13 +306,13 @@ def list_files(tenant_id): try: e, file = FileService.get_by_id(pf_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords) parent_folder = FileService.get_parent_folder(pf_id) if not parent_folder: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()}) except Exception as e: @@ -392,7 +392,7 @@ def get_parent_folder(): try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) parent_folder = FileService.get_parent_folder(file_id) return get_json_result(data={"parent_folder": parent_folder.to_json()}) @@ -439,7 +439,7 @@ def get_all_parent_folders(tenant_id): try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) parent_folders = FileService.get_all_parent_folders(file_id) parent_folders_res = [folder.to_json() for folder in parent_folders] @@ -487,34 +487,34 @@ async def rm(tenant_id): for file_id in file_ids: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="File or Folder not found!", code=404) + return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND) if not file.tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if file.type == FileType.FOLDER.value: file_id_list = FileService.get_all_innermost_file_ids(file_id, []) for inner_file_id in file_id_list: e, file = FileService.get_by_id(inner_file_id) if not e: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) settings.STORAGE_IMPL.rm(file.parent_id, file.location) FileService.delete_folder_by_pf_id(tenant_id, file_id) else: settings.STORAGE_IMPL.rm(file.parent_id, file.location) if not FileService.delete(file): - return get_json_result(message="Database error (File removal)!", code=500) + return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR) informs = File2DocumentService.get_by_file_id(file_id) for inform in informs: doc_id = inform.document_id e, doc = DocumentService.get_by_id(doc_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) tenant_id = DocumentService.get_tenant_id(doc_id) if not tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if not DocumentService.remove_document(doc, tenant_id): - return get_json_result(message="Database error (Document removal)!", code=500) + return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR) File2DocumentService.delete_by_file_id(file_id) return get_json_result(data=True) @@ -560,23 +560,23 @@ async def rename(tenant_id): try: e, file = FileService.get_by_id(req["file_id"]) if not e: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path( file.name.lower()).suffix: - return get_json_result(data=False, message="The extension of file can't be changed", code=400) + return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST) for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id): if existing_file.name == req["name"]: return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409) if not FileService.update_by_id(req["file_id"], {"name": req["name"]}): - return get_json_result(message="Database error (File rename)!", code=500) + return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR) informs = File2DocumentService.get_by_file_id(req["file_id"]) if informs: if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}): - return get_json_result(message="Database error (Document rename)!", code=500) + return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR) return get_json_result(data=True) except Exception as e: @@ -606,13 +606,13 @@ async def get(tenant_id, file_id): description: File stream schema: type: file - 404: + RetCode.NOT_FOUND: description: File not found """ try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) blob = settings.STORAGE_IMPL.get(file.parent_id, file.location) if not blob: @@ -677,13 +677,13 @@ async def move(tenant_id): for file_id in file_ids: file = files_dict[file_id] if not file: - return get_json_result(message="File or Folder not found!", code=404) + return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND) if not file.tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) fe, _ = FileService.get_by_id(parent_id) if not fe: - return get_json_result(message="Parent Folder not found!", code=404) + return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND) FileService.move_file(file_ids, parent_id) return get_json_result(data=True) @@ -705,7 +705,7 @@ async def convert(tenant_id): for file_id in file_ids: file = files_set[file_id] if not file: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) file_ids_list = [file_id] if file.type == FileType.FOLDER.value: file_ids_list = FileService.get_all_innermost_file_ids(file_id, []) @@ -716,13 +716,13 @@ async def convert(tenant_id): doc_id = inform.document_id e, doc = DocumentService.get_by_id(doc_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) tenant_id = DocumentService.get_tenant_id(doc_id) if not tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if not DocumentService.remove_document(doc, tenant_id): return get_json_result( - message="Database error (Document removal)!", code=404) + message="Database error (Document removal)!", code=RetCode.NOT_FOUND) File2DocumentService.delete_by_file_id(id) # insert @@ -730,11 +730,11 @@ async def convert(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return get_json_result( - message="Can't find this knowledgebase!", code=404) + message="Can't find this knowledgebase!", code=RetCode.NOT_FOUND) e, file = FileService.get_by_id(id) if not e: return get_json_result( - message="Can't find this file!", code=404) + message="Can't find this file!", code=RetCode.NOT_FOUND) doc = DocumentService.insert({ "id": get_uuid(), diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 46e114fdf..ae1355da8 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -121,8 +121,8 @@ async def login(): response_data = user.to_json() user.access_token = get_uuid() login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) + user.update_time = current_timestamp() + user.update_date = datetime_format(datetime.now()) user.save() msg = "Welcome back!" @@ -1002,8 +1002,8 @@ async def forget(): # Auto login (reuse login flow) user.access_token = get_uuid() login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) + user.update_time = current_timestamp() + user.update_date = datetime_format(datetime.now()) user.save() msg = "Password reset successful. Logged in." return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) diff --git a/api/db/db_models.py b/api/db/db_models.py index bd3feea64..e60afbef5 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -749,7 +749,7 @@ class Knowledgebase(DataBaseModel): parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True) pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True) - parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) + parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0}) pagerank = IntegerField(default=0, index=False) graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True) @@ -774,7 +774,7 @@ class Document(DataBaseModel): kb_id = CharField(max_length=256, null=False, index=True) parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True) pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True) - parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) + parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0}) source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True) type = CharField(max_length=32, null=False, help_text="file extension", index=True) created_by = CharField(max_length=32, null=False, help_text="who created it", index=True) diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index db8c713e2..660530c82 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -214,9 +214,21 @@ class SyncLogsService(CommonService): err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) + # Create a mapping from filename to metadata for later use + metadata_map = {} + for d in docs: + if d.get("metadata"): + filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "") + metadata_map[filename] = d["metadata"] + kb_table_num_map = {} for doc, _ in doc_blob_pairs: doc_ids.append(doc["id"]) + + # Set metadata if available for this document + if doc["name"] in metadata_map: + DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]}) + if not auto_parse or auto_parse == "0": continue DocumentService.run(tenant_id, doc, kb_table_num_map) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 514b3fd87..7b7ef53ec 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -923,7 +923,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email } - parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"} + parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0} exe = ThreadPoolExecutor(max_workers=12) threads = [] doc_nm = {} diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index cbd2423f2..314211694 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -313,6 +313,10 @@ def get_parser_config(chunk_method, parser_config): chunk_method = "naive" # Define default configurations for each chunking method + base_defaults = { + "table_context_size": 0, + "image_context_size": 0, + } key_mapping = { "naive": { "layout_recognize": "DeepDOC", @@ -365,16 +369,19 @@ def get_parser_config(chunk_method, parser_config): default_config = key_mapping[chunk_method] - # If no parser_config provided, return default + # If no parser_config provided, return default merged with base defaults if not parser_config: - return default_config + if default_config is None: + return deep_merge(base_defaults, {}) + return deep_merge(base_defaults, default_config) # If parser_config is provided, merge with defaults to ensure required fields exist if default_config is None: - return parser_config + return deep_merge(base_defaults, parser_config) # Ensure raptor and graphrag fields have default values if not provided - merged_config = deep_merge(default_config, parser_config) + merged_config = deep_merge(base_defaults, default_config) + merged_config = deep_merge(merged_config, parser_config) return merged_config diff --git a/common/constants.py b/common/constants.py index a09832eed..574786d00 100644 --- a/common/constants.py +++ b/common/constants.py @@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum): RUNNING = 106 PERMISSION_ERROR = 108 AUTHENTICATION_ERROR = 109 + BAD_REQUEST = 400 UNAUTHORIZED = 401 SERVER_ERROR = 500 FORBIDDEN = 403 diff --git a/common/data_source/config.py b/common/data_source/config.py index a643e3d41..a3d86720c 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -217,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") +GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index 821f79862..a7935ff6d 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -1562,6 +1562,7 @@ class ConfluenceConnector( size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes doc_updated_at=datetime_from_string(page["version"]["when"]), primary_owners=primary_owners if primary_owners else None, + metadata=metadata if metadata else None, ) except Exception as e: logging.error(f"Error converting page {page.get('id', 'unknown')}: {e}") diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 93a0477b0..46b23443c 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -65,6 +65,7 @@ def _convert_message_to_document( blob=message.content.encode("utf-8"), extension=".txt", size_bytes=len(message.content.encode("utf-8")), + metadata=metadata if metadata else None, ) diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index 67ebfae98..1757f4ffe 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -1,6 +1,6 @@ import logging +import os from typing import Any - from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.errors import HttpError @@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS from common.data_source.google_util.resource import get_admin_service, get_gmail_service -from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval +from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection -from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc +from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc # Constants for Gmail API fields THREAD_LIST_FIELDS = "nextPageToken, threads(id)" @@ -67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, message_data += f"{name}: {value}\n" message_body_text: str = get_message_body(payload) - return TextSection(link=link, text=message_body_text + message_data), metadata @@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = message_metadata.get("subject", "") + semantic_identifier = clean_string(semantic_identifier) + semantic_identifier = sanitize_filename(semantic_identifier) if message_metadata.get("updated_at"): updated_at = message_metadata.get("updated_at") - + updated_at_datetime = None if updated_at: - updated_at_datetime = time_str_to_utc(updated_at) + updated_at_datetime = gmail_time_str_to_utc(updated_at) thread_id = full_thread.get("id") if not thread_id: @@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = "(no subject)" + combined_sections = "\n\n".join( + sec.text for sec in sections if hasattr(sec, "text") + ) + blob = combined_sections + size_bytes = len(blob) + extension = '.txt' + return Document( id=thread_id, semantic_identifier=semantic_identifier, - sections=sections, + blob=blob, + size_bytes=size_bytes, + extension=extension, source=DocumentSource.GMAIL, primary_owners=primary_owners, secondary_owners=secondary_owners, doc_updated_at=updated_at_datetime, - metadata={}, + metadata=message_metadata, external_access=ExternalAccess( external_user_emails={email_used_to_fetch_thread}, external_user_group_ids=set(), @@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): q=query, continue_on_404_or_403=True, ): - full_threads = _execute_single_retrieval( + full_thread = _execute_single_retrieval( retrieval_function=gmail_service.users().threads().get, - list_key=None, userId=user_email, fields=THREAD_FIELDS, id=thread["id"], continue_on_404_or_403=True, ) - full_thread = list(full_threads)[0] doc = thread_to_document(full_thread, user_email) if doc is None: continue @@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): if __name__ == "__main__": - pass + import time + import os + from common.data_source.google_util.util import get_credentials_from_env + logging.basicConfig(level=logging.INFO) + try: + email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com") + creds = get_credentials_from_env(email, oauth=True, source="gmail") + print("Credentials loaded successfully") + print(f"{creds=}") + + connector = GmailConnector(batch_size=2) + print("GmailConnector initialized") + connector.load_credentials(creds) + print("Credentials loaded into connector") + + print("Gmail is ready to use") + + for file in connector._fetch_threads( + int(time.time()) - 1 * 24 * 60 * 60, + int(time.time()), + ): + print("new batch","-"*80) + for f in file: + print(f) + print("\n\n") + except Exception as e: + logging.exception(f"Error loading credentials: {e}") \ No newline at end of file diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py index fb88d0ed0..48628f490 100644 --- a/common/data_source/google_drive/connector.py +++ b/common/data_source/google_drive/connector.py @@ -1,7 +1,6 @@ """Google Drive connector""" import copy -import json import logging import os import sys @@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import ( from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS -from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict @@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) -def get_credentials_from_env(email: str, oauth: bool = False) -> dict: - try: - if oauth: - raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] - else: - raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] - except KeyError: - raise ValueError("Missing Google Drive credentials in environment variables") - - try: - credential_dict = json.loads(raw_credential_string) - except json.JSONDecodeError: - raise ValueError("Invalid JSON in Google Drive credentials") - - if oauth: - credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE) - - refried_credential_string = json.dumps(credential_dict) - - DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" - DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" - DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" - - cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - - return { - cred_key: refried_credential_string, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, - DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded", - } - - class CheckpointOutputWrapper: """ Wraps a CheckpointOutput generator to give things back in a more digestible format. @@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector( if __name__ == "__main__": import time - + from common.data_source.google_util.util import get_credentials_from_env logging.basicConfig(level=logging.DEBUG) try: @@ -1245,7 +1210,7 @@ if __name__ == "__main__": creds = get_credentials_from_env(email, oauth=True) print("Credentials loaded successfully") print(f"{creds=}") - + sys.exit(0) connector = GoogleDriveConnector( include_shared_drives=False, shared_drive_urls=None, diff --git a/common/data_source/google_util/constant.py b/common/data_source/google_util/constant.py index 8ab75fa14..858ee31c8 100644 --- a/common/data_source/google_util/constant.py +++ b/common/data_source/google_util/constant.py @@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste SCOPE_INSTRUCTIONS = "" -GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """ +GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """
-