diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index c3a01e517..fd9096cb6 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -137,7 +137,7 @@ class Retrieval(ToolBase, ABC): if not doc_ids: doc_ids = None elif self._param.meta_data_filter.get("method") == "manual": - filters=self._param.meta_data_filter["manual"] + filters = self._param.meta_data_filter["manual"] for flt in filters: pat = re.compile(self.variable_ref_patt) s = flt["value"] @@ -166,8 +166,8 @@ class Retrieval(ToolBase, ABC): out_parts.append(s[last:]) flt["value"] = "".join(out_parts) doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and"))) - if not doc_ids: - doc_ids = None + if filters and not doc_ids: + doc_ids = ["-999"] if self._param.cross_languages: query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index e121bcba7..b43fb9af1 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -311,8 +311,8 @@ async def retrieval_test(): doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if not doc_ids: - doc_ids = None + if meta_data_filter["manual"] and not doc_ids: + doc_ids = ["-999"] try: tenants = UserTenantService.query(user_id=current_user.id) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 84300ac3c..30fbd835e 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1445,6 +1445,8 @@ async def retrieval_test(tenant_id): metadata_condition = req.get("metadata_condition", {}) or {} metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + if metadata_condition and not doc_ids: + doc_ids = ["-999"] similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 533375622..074401ede 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -446,8 +446,8 @@ async def agent_completions(tenant_id, agent_id): if req.get("stream", True): - def generate(): - for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + async def generate(): + async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): if isinstance(answer, str): try: ans = json.loads(answer[5:]) # remove "data:" @@ -471,7 +471,7 @@ async def agent_completions(tenant_id, agent_id): full_content = "" reference = {} final_ans = "" - for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): try: ans = json.loads(answer[5:]) @@ -873,7 +873,7 @@ async def agent_bot_completions(agent_id): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + async for answer in agent_completion(objs[0].tenant_id, agent_id, **req): return get_result(data=answer) @@ -981,8 +981,8 @@ async def retrieval_test_embedded(): doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if not doc_ids: - doc_ids = None + if meta_data_filter["manual"] and not doc_ids: + doc_ids = ["-999"] try: tenants = UserTenantService.query(user_id=tenant_id) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index db878574d..0a09ea532 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -415,9 +415,10 @@ def chat(dialog, messages, stream=True, **kwargs): if not attachments: attachments = None elif dialog.meta_data_filter.get("method") == "manual": - attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"], dialog.meta_data_filter.get("logic", "and"))) - if not attachments: - attachments = None + conds = dialog.meta_data_filter["manual"] + attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and"))) + if conds and not attachments: + attachments = ["-999"] if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) @@ -787,8 +788,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if not doc_ids: - doc_ids = None + if meta_data_filter["manual"] and not doc_ids: + doc_ids = ["-999"] kbinfos = retriever.retrieval( question=question, @@ -862,8 +863,8 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if not doc_ids: - doc_ids = None + if meta_data_filter["manual"] and not doc_ids: + doc_ids = ["-999"] ranks = settings.retriever.retrieval( question=question, diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py index 8c6a522ad..e29bbbe76 100644 --- a/common/data_source/notion_connector.py +++ b/common/data_source/notion_connector.py @@ -1,38 +1,45 @@ +import html import logging from collections.abc import Generator +from datetime import datetime, timezone +from pathlib import Path from typing import Any, Optional +from urllib.parse import urlparse + from retry import retry from common.data_source.config import ( INDEX_BATCH_SIZE, - DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP + NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP, + DocumentSource, +) +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, + CredentialExpiredError, + InsufficientPermissionsError, + UnexpectedValidationError, ) from common.data_source.interfaces import ( LoadConnector, PollConnector, - SecondsSinceUnixEpoch + SecondsSinceUnixEpoch, ) from common.data_source.models import ( Document, - TextSection, GenerateDocumentsOutput -) -from common.data_source.exceptions import ( - ConnectorValidationError, - CredentialExpiredError, - InsufficientPermissionsError, - UnexpectedValidationError, ConnectorMissingCredentialError -) -from common.data_source.models import ( - NotionPage, + GenerateDocumentsOutput, NotionBlock, - NotionSearchResponse + NotionPage, + NotionSearchResponse, + TextSection, ) from common.data_source.utils import ( - rl_requests, batch_generator, + datetime_from_string, fetch_notion_data, + filter_pages_by_time, properties_to_str, - filter_pages_by_time, datetime_from_string + rl_requests, ) @@ -61,11 +68,9 @@ class NotionConnector(LoadConnector, PollConnector): self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id) @retry(tries=3, delay=1, backoff=2) - def _fetch_child_blocks( - self, block_id: str, cursor: Optional[str] = None - ) -> dict[str, Any] | None: + def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None: """Fetch all child blocks via the Notion API.""" - logging.debug(f"Fetching children of block with ID '{block_id}'") + logging.debug(f"[Notion]: Fetching children of block with ID {block_id}") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" query_params = {"start_cursor": cursor} if cursor else None @@ -79,49 +84,42 @@ class NotionConnector(LoadConnector, PollConnector): response.raise_for_status() return response.json() except Exception as e: - if hasattr(e, 'response') and e.response.status_code == 404: - logging.error( - f"Unable to access block with ID '{block_id}'. " - f"This is likely due to the block not being shared with the integration." - ) + if hasattr(e, "response") and e.response.status_code == 404: + logging.error(f"[Notion]: Unable to access block with ID {block_id}. This is likely due to the block not being shared with the integration.") return None else: - logging.exception(f"Error fetching blocks: {e}") + logging.exception(f"[Notion]: Error fetching blocks: {e}") raise @retry(tries=3, delay=1, backoff=2) def _fetch_page(self, page_id: str) -> NotionPage: """Fetch a page from its ID via the Notion API.""" - logging.debug(f"Fetching page for ID '{page_id}'") + logging.debug(f"[Notion]: Fetching page for ID {page_id}") page_url = f"https://api.notion.com/v1/pages/{page_id}" try: data = fetch_notion_data(page_url, self.headers, "GET") return NotionPage(**data) except Exception as e: - logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}") + logging.warning(f"[Notion]: Failed to fetch page, trying database for ID {page_id}: {e}") return self._fetch_database_as_page(page_id) @retry(tries=3, delay=1, backoff=2) def _fetch_database_as_page(self, database_id: str) -> NotionPage: """Attempt to fetch a database as a page.""" - logging.debug(f"Fetching database for ID '{database_id}' as a page") + logging.debug(f"[Notion]: Fetching database for ID {database_id} as a page") database_url = f"https://api.notion.com/v1/databases/{database_id}" data = fetch_notion_data(database_url, self.headers, "GET") database_name = data.get("title") - database_name = ( - database_name[0].get("text", {}).get("content") if database_name else None - ) + database_name = database_name[0].get("text", {}).get("content") if database_name else None return NotionPage(**data, database_name=database_name) @retry(tries=3, delay=1, backoff=2) - def _fetch_database( - self, database_id: str, cursor: Optional[str] = None - ) -> dict[str, Any]: + def _fetch_database(self, database_id: str, cursor: Optional[str] = None) -> dict[str, Any]: """Fetch a database from its ID via the Notion API.""" - logging.debug(f"Fetching database for ID '{database_id}'") + logging.debug(f"[Notion]: Fetching database for ID {database_id}") block_url = f"https://api.notion.com/v1/databases/{database_id}/query" body = {"start_cursor": cursor} if cursor else None @@ -129,17 +127,12 @@ class NotionConnector(LoadConnector, PollConnector): data = fetch_notion_data(block_url, self.headers, "POST", body) return data except Exception as e: - if hasattr(e, 'response') and e.response.status_code in [404, 400]: - logging.error( - f"Unable to access database with ID '{database_id}'. " - f"This is likely due to the database not being shared with the integration." - ) + if hasattr(e, "response") and e.response.status_code in [404, 400]: + logging.error(f"[Notion]: Unable to access database with ID {database_id}. This is likely due to the database not being shared with the integration.") return {"results": [], "next_cursor": None} raise - def _read_pages_from_database( - self, database_id: str - ) -> tuple[list[NotionBlock], list[str]]: + def _read_pages_from_database(self, database_id: str) -> tuple[list[NotionBlock], list[str]]: """Returns a list of top level blocks and all page IDs in the database.""" result_blocks: list[NotionBlock] = [] result_pages: list[str] = [] @@ -158,10 +151,10 @@ class NotionConnector(LoadConnector, PollConnector): if self.recursive_index_enabled: if obj_type == "page": - logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'") + logging.debug(f"[Notion]: Found page with ID {obj_id} in database {database_id}") result_pages.append(result["id"]) elif obj_type == "database": - logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'") + logging.debug(f"[Notion]: Found database with ID {obj_id} in database {database_id}") _, child_pages = self._read_pages_from_database(obj_id) result_pages.extend(child_pages) @@ -172,44 +165,229 @@ class NotionConnector(LoadConnector, PollConnector): return result_blocks, result_pages - def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]: - """Reads all child blocks for the specified block, returns blocks and child page ids.""" + def _extract_rich_text(self, rich_text_array: list[dict[str, Any]]) -> str: + collected_text: list[str] = [] + for rich_text in rich_text_array: + content = "" + r_type = rich_text.get("type") + + if r_type == "equation": + expr = rich_text.get("equation", {}).get("expression") + if expr: + content = expr + elif r_type == "mention": + mention = rich_text.get("mention", {}) or {} + mention_type = mention.get("type") + mention_value = mention.get(mention_type, {}) if mention_type else {} + if mention_type == "date": + start = mention_value.get("start") + end = mention_value.get("end") + if start and end: + content = f"{start} - {end}" + elif start: + content = start + elif mention_type in {"page", "database"}: + content = mention_value.get("id", rich_text.get("plain_text", "")) + elif mention_type == "link_preview": + content = mention_value.get("url", rich_text.get("plain_text", "")) + else: + content = rich_text.get("plain_text", "") or str(mention_value) + else: + if rich_text.get("plain_text"): + content = rich_text["plain_text"] + elif "text" in rich_text and rich_text["text"].get("content"): + content = rich_text["text"]["content"] + + href = rich_text.get("href") + if content and href: + content = f"{content} ({href})" + + if content: + collected_text.append(content) + + return "".join(collected_text).strip() + + def _build_table_html(self, table_block_id: str) -> str | None: + rows: list[str] = [] + cursor = None + while True: + data = self._fetch_child_blocks(table_block_id, cursor) + if data is None: + break + + for result in data["results"]: + if result.get("type") != "table_row": + continue + cells_html: list[str] = [] + for cell in result["table_row"].get("cells", []): + cell_text = self._extract_rich_text(cell) + cell_html = html.escape(cell_text) if cell_text else "" + cells_html.append(f"{cell_html}") + rows.append(f"{''.join(cells_html)}") + + if data.get("next_cursor") is None: + break + cursor = data["next_cursor"] + + if not rows: + return None + return "\n" + "\n".join(rows) + "\n
" + + def _download_file(self, url: str) -> bytes | None: + try: + response = rl_requests.get(url, timeout=60) + response.raise_for_status() + return response.content + except Exception as exc: + logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}") + return None + + def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]: + file_source_type = result_obj.get("type") + file_source = result_obj.get(file_source_type, {}) if file_source_type else {} + url = file_source.get("url") + + name = result_obj.get("name") or file_source.get("name") + if url and not name: + parsed_name = Path(urlparse(url).path).name + name = parsed_name or f"notion_file_{block_id}" + elif not name: + name = f"notion_file_{block_id}" + + caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None + + return url, name, caption + + def _build_attachment_document( + self, + block_id: str, + url: str, + name: str, + caption: Optional[str], + page_last_edited_time: Optional[str], + ) -> Document | None: + file_bytes = self._download_file(url) + if file_bytes is None: + return None + + extension = Path(name).suffix or Path(urlparse(url).path).suffix or ".bin" + if extension and not extension.startswith("."): + extension = f".{extension}" + if not extension: + extension = ".bin" + + updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc) + semantic_identifier = caption or name or f"Notion file {block_id}" + + return Document( + id=block_id, + blob=file_bytes, + source=DocumentSource.NOTION, + semantic_identifier=semantic_identifier, + extension=extension, + size_bytes=len(file_bytes), + doc_updated_at=updated_at, + ) + + def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]: result_blocks: list[NotionBlock] = [] child_pages: list[str] = [] + attachments: list[Document] = [] cursor = None while True: data = self._fetch_child_blocks(base_block_id, cursor) if data is None: - return result_blocks, child_pages + return result_blocks, child_pages, attachments for result in data["results"]: - logging.debug(f"Found child block for block with ID '{base_block_id}': {result}") + logging.debug(f"[Notion]: Found child block for block with ID {base_block_id}: {result}") result_block_id = result["id"] result_type = result["type"] result_obj = result[result_type] if result_type in ["ai_block", "unsupported", "external_object_instance_page"]: - logging.warning(f"Skipping unsupported block type '{result_type}'") + logging.warning(f"[Notion]: Skipping unsupported block type {result_type}") + continue + + if result_type == "table": + table_html = self._build_table_html(result_block_id) + if table_html: + result_blocks.append( + NotionBlock( + id=result_block_id, + text=table_html, + prefix="\n\n", + ) + ) + continue + + if result_type == "equation": + expr = result_obj.get("expression") + if expr: + result_blocks.append( + NotionBlock( + id=result_block_id, + text=expr, + prefix="\n", + ) + ) continue cur_result_text_arr = [] if "rich_text" in result_obj: - for rich_text in result_obj["rich_text"]: - if "text" in rich_text: - text = rich_text["text"]["content"] - cur_result_text_arr.append(text) + text = self._extract_rich_text(result_obj["rich_text"]) + if text: + cur_result_text_arr.append(text) + + if result_type == "bulleted_list_item": + if cur_result_text_arr: + cur_result_text_arr[0] = f"- {cur_result_text_arr[0]}" + else: + cur_result_text_arr = ["- "] + + if result_type == "numbered_list_item": + if cur_result_text_arr: + cur_result_text_arr[0] = f"1. {cur_result_text_arr[0]}" + else: + cur_result_text_arr = ["1. "] + + if result_type == "to_do": + checked = result_obj.get("checked") + checkbox_prefix = "[x]" if checked else "[ ]" + if cur_result_text_arr: + cur_result_text_arr = [f"{checkbox_prefix} {cur_result_text_arr[0]}"] + cur_result_text_arr[1:] + else: + cur_result_text_arr = [checkbox_prefix] + + if result_type in {"file", "image", "pdf", "video", "audio"}: + file_url, file_name, caption = self._extract_file_metadata(result_obj, result_block_id) + if file_url: + attachment_doc = self._build_attachment_document( + block_id=result_block_id, + url=file_url, + name=file_name, + caption=caption, + page_last_edited_time=page_last_edited_time, + ) + if attachment_doc: + attachments.append(attachment_doc) + + attachment_label = caption or file_name + if attachment_label: + cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}") if result["has_children"]: if result_type == "child_page": child_pages.append(result_block_id) else: - logging.debug(f"Entering sub-block: {result_block_id}") - subblocks, subblock_child_pages = self._read_blocks(result_block_id) - logging.debug(f"Finished sub-block: {result_block_id}") + logging.debug(f"[Notion]: Entering sub-block: {result_block_id}") + subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time) + logging.debug(f"[Notion]: Finished sub-block: {result_block_id}") result_blocks.extend(subblocks) child_pages.extend(subblock_child_pages) + attachments.extend(subblock_attachments) if result_type == "child_database": inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id) @@ -231,7 +409,7 @@ class NotionConnector(LoadConnector, PollConnector): cursor = data["next_cursor"] - return result_blocks, child_pages + return result_blocks, child_pages, attachments def _read_page_title(self, page: NotionPage) -> Optional[str]: """Extracts the title from a Notion page.""" @@ -245,9 +423,7 @@ class NotionConnector(LoadConnector, PollConnector): return None - def _read_pages( - self, pages: list[NotionPage] - ) -> Generator[Document, None, None]: + def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]: """Reads pages for rich text content and generates Documents.""" all_child_page_ids: list[str] = [] @@ -255,11 +431,17 @@ class NotionConnector(LoadConnector, PollConnector): if isinstance(page, dict): page = NotionPage(**page) if page.id in self.indexed_pages: - logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.") + logging.debug(f"[Notion]: Already indexed page with ID {page.id}. Skipping.") continue - logging.info(f"Reading page with ID '{page.id}', with url {page.url}") - page_blocks, child_page_ids = self._read_blocks(page.id) + if start is not None and end is not None: + page_ts = datetime_from_string(page.last_edited_time).timestamp() + if not (page_ts > start and page_ts <= end): + logging.debug(f"[Notion]: Skipping page {page.id} outside polling window.") + continue + + logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}") + page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time) all_child_page_ids.extend(child_page_ids) self.indexed_pages.add(page.id) @@ -268,14 +450,12 @@ class NotionConnector(LoadConnector, PollConnector): if not page_blocks: if not raw_page_title: - logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.") + logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.") continue text = page_title if page.properties: - text += "\n\n" + "\n".join( - [f"{key}: {value}" for key, value in page.properties.items()] - ) + text += "\n\n" + "\n".join([f"{key}: {value}" for key, value in page.properties.items()]) sections = [TextSection(link=page.url, text=text)] else: sections = [ @@ -286,45 +466,39 @@ class NotionConnector(LoadConnector, PollConnector): for block in page_blocks ] - blob = ("\n".join([sec.text for sec in sections])).encode("utf-8") + joined_text = "\n".join(sec.text for sec in sections) + blob = joined_text.encode("utf-8") yield Document( - id=page.id, - blob=blob, - source=DocumentSource.NOTION, - semantic_identifier=page_title, - extension=".txt", - size_bytes=len(blob), - doc_updated_at=datetime_from_string(page.last_edited_time) + id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time) ) + for attachment_doc in attachment_docs: + yield attachment_doc + if self.recursive_index_enabled and all_child_page_ids: for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE): - child_page_batch = [ - self._fetch_page(page_id) - for page_id in child_page_batch_ids - if page_id not in self.indexed_pages - ] - yield from self._read_pages(child_page_batch) + child_page_batch = [self._fetch_page(page_id) for page_id in child_page_batch_ids if page_id not in self.indexed_pages] + yield from self._read_pages(child_page_batch, start, end) @retry(tries=3, delay=1, backoff=2) def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse: """Search for pages from a Notion database.""" - logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}") + logging.debug(f"[Notion]: Searching for pages in Notion with query_dict: {query_dict}") data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict) return NotionSearchResponse(**data) - def _recursive_load(self) -> Generator[list[Document], None, None]: + def _recursive_load(self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[list[Document], None, None]: """Recursively load pages starting from root page ID.""" if self.root_page_id is None or not self.recursive_index_enabled: raise RuntimeError("Recursive page lookup is not enabled") - logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}") + logging.info(f"[Notion]: Recursively loading pages from Notion based on root page with ID: {self.root_page_id}") pages = [self._fetch_page(page_id=self.root_page_id)] - yield from batch_generator(self._read_pages(pages), self.batch_size) + yield from batch_generator(self._read_pages(pages, start, end), self.batch_size) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Applies integration token to headers.""" - self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}' + self.headers["Authorization"] = f"Bearer {credentials['notion_integration_token']}" return None def load_from_state(self) -> GenerateDocumentsOutput: @@ -348,12 +522,10 @@ class NotionConnector(LoadConnector, PollConnector): else: break - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll Notion for updated pages within a time period.""" if self.recursive_index_enabled and self.root_page_id: - yield from self._recursive_load() + yield from self._recursive_load(start, end) return query_dict = { @@ -367,7 +539,7 @@ class NotionConnector(LoadConnector, PollConnector): pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time") if pages: - yield from batch_generator(self._read_pages(pages), self.batch_size) + yield from batch_generator(self._read_pages(pages, start, end), self.batch_size) if db_res.has_more: query_dict["start_cursor"] = db_res.next_cursor else: diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py index 59fec9250..965f82265 100644 --- a/deepdoc/parser/docling_parser.py +++ b/deepdoc/parser/docling_parser.py @@ -187,7 +187,7 @@ class DoclingParser(RAGFlowPdfParser): bbox = _BBox(int(pn), bb[0], bb[1], bb[2], bb[3]) yield (DoclingContentType.EQUATION.value, text, bbox) - def _transfer_to_sections(self, doc) -> list[tuple[str, str]]: + def _transfer_to_sections(self, doc, parse_method: str) -> list[tuple[str, str]]: sections: list[tuple[str, str]] = [] for typ, payload, bbox in self._iter_doc_items(doc): if typ == DoclingContentType.TEXT.value: @@ -200,7 +200,12 @@ class DoclingParser(RAGFlowPdfParser): continue tag = self._make_line_tag(bbox) if isinstance(bbox,_BBox) else "" - sections.append((section, tag)) + if parse_method == "manual": + sections.append((section, typ, tag)) + elif parse_method == "paper": + sections.append((section + tag, typ)) + else: + sections.append((section, tag)) return sections def cropout_docling_table(self, page_no: int, bbox: tuple[float, float, float, float], zoomin: int = 1): @@ -282,7 +287,8 @@ class DoclingParser(RAGFlowPdfParser): output_dir: Optional[str] = None, lang: Optional[str] = None, method: str = "auto", - delete_output: bool = True, + delete_output: bool = True, + parse_method: str = "raw" ): if not self.check_installation(): @@ -318,7 +324,7 @@ class DoclingParser(RAGFlowPdfParser): if callback: callback(0.7, f"[Docling] Parsed doc: {getattr(doc, 'num_pages', 'n/a')} pages") - sections = self._transfer_to_sections(doc) + sections = self._transfer_to_sections(doc, parse_method=parse_method) tables = self._transfer_to_tables(doc) if callback: diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index d2b694188..d4834de39 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -476,7 +476,7 @@ class MinerUParser(RAGFlowPdfParser): item[key] = str((subdir / item[key]).resolve()) return data - def _transfer_to_sections(self, outputs: list[dict[str, Any]]): + def _transfer_to_sections(self, outputs: list[dict[str, Any]], parse_method: str = None): sections = [] for output in outputs: match output["type"]: @@ -497,7 +497,11 @@ class MinerUParser(RAGFlowPdfParser): case MinerUContentType.DISCARDED: pass - if section: + if section and parse_method == "manual": + sections.append((section, output["type"], self._line_tag(output))) + elif section and parse_method == "paper": + sections.append((section + self._line_tag(output), output["type"])) + else: sections.append((section, self._line_tag(output))) return sections @@ -516,6 +520,7 @@ class MinerUParser(RAGFlowPdfParser): method: str = "auto", server_url: Optional[str] = None, delete_output: bool = True, + parse_method: str = "raw" ) -> tuple: import shutil @@ -565,7 +570,8 @@ class MinerUParser(RAGFlowPdfParser): self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.") if callback: callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.") - return self._transfer_to_sections(outputs), self._transfer_to_tables(outputs) + + return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs) finally: if temp_pdf and temp_pdf.exists(): try: diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 5bc877a6a..6d8431c82 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -33,6 +33,8 @@ import xgboost as xgb from huggingface_hub import snapshot_download from PIL import Image from pypdf import PdfReader as pdf2_read +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score from common.file_utils import get_project_base_directory from common.misc_utils import pip_install_torch @@ -353,7 +355,6 @@ class RAGFlowPdfParser: def _assign_column(self, boxes, zoomin=3): if not boxes: return boxes - if all("col_id" in b for b in boxes): return boxes @@ -361,61 +362,80 @@ class RAGFlowPdfParser: for b in boxes: by_page[b["page_number"]].append(b) - page_info = {} # pg -> dict(page_w, left_edge, cand_cols) - counter = Counter() + page_cols = {} for pg, bxs in by_page.items(): if not bxs: - page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1} - counter[1] += 1 + page_cols[pg] = 1 continue - if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg: - page_w = self.page_images[pg - 1].size[0] / max(1, zoomin) - left_edge = 0.0 - else: - xs0 = [box["x0"] for box in bxs] - xs1 = [box["x1"] for box in bxs] - left_edge = float(min(xs0)) - page_w = max(1.0, float(max(xs1) - left_edge)) + x0s_raw = np.array([b["x0"] for b in bxs], dtype=float) - widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs] - median_w = float(np.median(widths)) if widths else 1.0 + min_x0 = np.min(x0s_raw) + max_x1 = np.max([b["x1"] for b in bxs]) + width = max_x1 - min_x0 - raw_cols = int(page_w / max(1.0, median_w)) + INDENT_TOL = width * 0.12 + x0s = [] + for x in x0s_raw: + if abs(x - min_x0) < INDENT_TOL: + x0s.append([min_x0]) + else: + x0s.append([x]) + x0s = np.array(x0s, dtype=float) + + max_try = min(4, len(bxs)) + if max_try < 2: + max_try = 1 + best_k = 1 + best_score = -1 - # cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1 - cand = raw_cols + for k in range(1, max_try + 1): + km = KMeans(n_clusters=k, n_init="auto") + labels = km.fit_predict(x0s) - page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand} - counter[cand] += 1 + centers = np.sort(km.cluster_centers_.flatten()) + if len(centers) > 1: + try: + score = silhouette_score(x0s, labels) + except ValueError: + continue + else: + score = 0 + print(f"{k=},{score=}",flush=True) + if score > best_score: + best_score = score + best_k = k - logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}") + page_cols[pg] = best_k + logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}") - global_cols = counter.most_common(1)[0][0] + + global_cols = Counter(page_cols.values()).most_common(1)[0][0] logging.info(f"Global column_num decided by majority: {global_cols}") + for pg, bxs in by_page.items(): if not bxs: continue + k = page_cols[pg] + if len(bxs) < k: + k = 1 + x0s = np.array([[b["x0"]] for b in bxs], dtype=float) + km = KMeans(n_clusters=k, n_init="auto") + labels = km.fit_predict(x0s) - page_w = page_info[pg]["page_w"] - left_edge = page_info[pg]["left_edge"] + centers = km.cluster_centers_.flatten() + order = np.argsort(centers) - if global_cols == 1: - for box in bxs: - box["col_id"] = 0 - continue + remap = {orig: new for new, orig in enumerate(order)} - for box in bxs: - w = box["x1"] - box["x0"] - if w >= 0.8 * page_w: - box["col_id"] = 0 - continue - cx = 0.5 * (box["x0"] + box["x1"]) - norm_cx = (cx - left_edge) / page_w - norm_cx = max(0.0, min(norm_cx, 0.999999)) - box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols)) + for b, lb in zip(bxs, labels): + b["col_id"] = remap[lb] + + grouped = defaultdict(list) + for b in bxs: + grouped[b["col_id"]].append(b) return boxes @@ -1303,7 +1323,10 @@ class RAGFlowPdfParser: positions = [] for ii, (pns, left, right, top, bottom) in enumerate(poss): - right = left + max_width + if 0 < ii < len(poss) - 1: + right = max(left + 10, right) + else: + right = left + max_width bottom *= ZM for pn in pns[1:]: if 0 <= pn - 1 < page_count: diff --git a/docker/.env b/docker/.env index d7e4b025f..6423b7824 100644 --- a/docker/.env +++ b/docker/.env @@ -230,9 +230,16 @@ REGISTER_ENABLED=1 # SANDBOX_MAX_MEMORY=256m # b, k, m, g # SANDBOX_TIMEOUT=10s # s, m, 1m30s -# Enable DocLing and Mineru +# Enable DocLing USE_DOCLING=false + +# Enable Mineru USE_MINERU=false +MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru" +MINERU_DELETE_OUTPUT=0 # keep output directory +MINERU_BACKEND=pipeline # or another backend you prefer + + # pptx support DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 \ No newline at end of file diff --git a/rag/app/manual.py b/rag/app/manual.py index 4f9de40c7..b3a4ae38d 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -214,6 +214,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback = callback, pdf_cls = Pdf, layout_recognizer = layout_recognizer, + parse_method = "manual", **kwargs ) @@ -226,7 +227,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, elif len(section) != 3: raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})") - txt, sec_id, poss = section + txt, layoutno, poss = section if isinstance(poss, str): poss = pdf_parser.extract_positions(poss) first = poss[0] # tuple: ([pn], x1, x2, y1, y2) @@ -236,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, pn = pn[0] # [pn] -> pn poss[0] = (pn, *first[1:]) - return (txt, sec_id, poss) + return (txt, layoutno, poss) sections = [_normalize_section(sec) for sec in sections] diff --git a/rag/app/naive.py b/rag/app/naive.py index 49dca17af..562336d7f 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -59,6 +59,7 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese" mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru") mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987") pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api) + parse_method = kwargs.get("parse_method", "raw") if not pdf_parser.check_installation(): callback(-1, "MinerU not found.") @@ -72,12 +73,14 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese" backend=os.environ.get("MINERU_BACKEND", "pipeline"), server_url=os.environ.get("MINERU_SERVER_URL", ""), delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))), + parse_method=parse_method ) return sections, tables, pdf_parser def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs): pdf_parser = DoclingParser() + parse_method = kwargs.get("parse_method", "raw") if not pdf_parser.check_installation(): callback(-1, "Docling not found.") @@ -89,6 +92,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese callback=callback, output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))), + parse_method=parse_method ) return sections, tables, pdf_parser diff --git a/rag/app/paper.py b/rag/app/paper.py index d95976c9f..222be0762 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -21,8 +21,10 @@ import re from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper from common.constants import ParserType from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks -from deepdoc.parser import PdfParser, PlainParser +from deepdoc.parser import PdfParser import numpy as np +from rag.app.naive import by_plaintext, PARSERS + class Pdf(PdfParser): def __init__(self): @@ -147,19 +149,40 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, "parser_config", { "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) if re.search(r"\.pdf$", filename, re.IGNORECASE): - if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text": - pdf_parser = PlainParser() + layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") + + if isinstance(layout_recognizer, bool): + layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" + + name = layout_recognizer.strip().lower() + pdf_parser = PARSERS.get(name, by_plaintext) + callback(0.1, "Start to parse.") + + if name == "deepdoc": + pdf_parser = Pdf() + paper = pdf_parser(filename if not binary else binary, + from_page=from_page, to_page=to_page, callback=callback) + else: + sections, tables, pdf_parser = pdf_parser( + filename=filename, + binary=binary, + from_page=from_page, + to_page=to_page, + lang=lang, + callback=callback, + pdf_cls=Pdf, + parse_method="paper", + **kwargs + ) + paper = { "title": filename, "authors": " ", "abstract": "", - "sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page)[0], - "tables": [] + "sections": sections, + "tables": tables } - else: - pdf_parser = Pdf() - paper = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) + tbls=paper["tables"] tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs) paper["tables"] = tbls diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 4dbd9945c..a479e5d3f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -355,75 +355,102 @@ class Dealer: rag_tokenizer.tokenize(ans).split(), rag_tokenizer.tokenize(inst).split()) - def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2, - vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, - rerank_mdl=None, highlight=False, - rank_feature: dict | None = {PAGERANK_FLD: 10}): + def retrieval( + self, + question, + embd_mdl, + tenant_ids, + kb_ids, + page, + page_size, + similarity_threshold=0.2, + vector_similarity_weight=0.3, + top=1024, + doc_ids=None, + aggs=True, + rerank_mdl=None, + highlight=False, + rank_feature: dict | None = {PAGERANK_FLD: 10}, + ): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks # Ensure RERANK_LIMIT is multiple of page_size - RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1 - req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT, - "question": question, "vector": True, "topk": top, - "similarity": similarity_threshold, - "available_int": 1} - + RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1 + req = { + "kb_ids": kb_ids, + "doc_ids": doc_ids, + "page": math.ceil(page_size * page / RERANK_LIMIT), + "size": RERANK_LIMIT, + "question": question, + "vector": True, + "topk": top, + "similarity": similarity_threshold, + "available_int": 1, + } if isinstance(tenant_ids, str): tenant_ids = tenant_ids.split(",") - sres = self.search(req, [index_name(tid) for tid in tenant_ids], - kb_ids, embd_mdl, highlight, rank_feature=rank_feature) + sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature) if rerank_mdl and sres.total > 0: - sim, tsim, vsim = self.rerank_by_model(rerank_mdl, - sres, question, 1 - vector_similarity_weight, - vector_similarity_weight, - rank_feature=rank_feature) + sim, tsim, vsim = self.rerank_by_model( + rerank_mdl, + sres, + question, + 1 - vector_similarity_weight, + vector_similarity_weight, + rank_feature=rank_feature, + ) else: - lower_case_doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') - if lower_case_doc_engine in ["elasticsearch","opensearch"]: + lower_case_doc_engine = os.getenv("DOC_ENGINE", "elasticsearch") + if lower_case_doc_engine in ["elasticsearch", "opensearch"]: # ElasticSearch doesn't normalize each way score before fusion. sim, tsim, vsim = self.rerank( - sres, question, 1 - vector_similarity_weight, vector_similarity_weight, - rank_feature=rank_feature) + sres, + question, + 1 - vector_similarity_weight, + vector_similarity_weight, + rank_feature=rank_feature, + ) else: # Don't need rerank here since Infinity normalizes each way score before fusion. sim = [sres.field[id].get("_score", 0.0) for id in sres.ids] - sim = [s if s is not None else 0. for s in sim] + sim = [s if s is not None else 0.0 for s in sim] tsim = sim vsim = sim - # Already paginated in search function - max_pages = RERANK_LIMIT // page_size - page_index = (page % max_pages) - 1 - begin = max(page_index * page_size, 0) - sim = sim[begin : begin + page_size] + sim_np = np.array(sim, dtype=np.float64) - idx = np.argsort(sim_np * -1) + if sim_np.size == 0: + return ranks + + sorted_idx = np.argsort(sim_np * -1) + + valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= similarity_threshold] + filtered_count = len(valid_idx) + ranks["total"] = int(filtered_count) + + if filtered_count == 0: + return ranks + + max_pages = max(RERANK_LIMIT // max(page_size, 1), 1) + page_index = (page - 1) % max_pages + begin = page_index * page_size + end = begin + page_size + page_idx = valid_idx[begin:end] + dim = len(sres.query_vector) vector_column = f"q_{dim}_vec" zero_vector = [0.0] * dim - filtered_count = (sim_np >= similarity_threshold).sum() - ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error - for i in idx: - if np.float64(sim[i]) < similarity_threshold: - break + for i in page_idx: id = sres.ids[i] chunk = sres.field[id] dnm = chunk.get("docnm_kwd", "") did = chunk.get("doc_id", "") - if len(ranks["chunks"]) >= page_size: - if aggs: - if dnm not in ranks["doc_aggs"]: - ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} - ranks["doc_aggs"][dnm]["count"] += 1 - continue - break - position_int = chunk.get("position_int", []) d = { "chunk_id": id, @@ -434,12 +461,12 @@ class Dealer: "kb_id": chunk["kb_id"], "important_kwd": chunk.get("important_kwd", []), "image_id": chunk.get("img_id", ""), - "similarity": sim[i], - "vector_similarity": vsim[i], - "term_similarity": tsim[i], + "similarity": float(sim_np[i]), + "vector_similarity": float(vsim[i]), + "term_similarity": float(tsim[i]), "vector": chunk.get(vector_column, zero_vector), "positions": position_int, - "doc_type_kwd": chunk.get("doc_type_kwd", "") + "doc_type_kwd": chunk.get("doc_type_kwd", ""), } if highlight and sres.highlight: if id in sres.highlight: @@ -447,15 +474,30 @@ class Dealer: else: d["highlight"] = d["content_with_weight"] ranks["chunks"].append(d) - if dnm not in ranks["doc_aggs"]: - ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} - ranks["doc_aggs"][dnm]["count"] += 1 - ranks["doc_aggs"] = [{"doc_name": k, - "doc_id": v["doc_id"], - "count": v["count"]} for k, - v in sorted(ranks["doc_aggs"].items(), - key=lambda x: x[1]["count"] * -1)] - ranks["chunks"] = ranks["chunks"][:page_size] + + if aggs: + for i in valid_idx: + id = sres.ids[i] + chunk = sres.field[id] + dnm = chunk.get("docnm_kwd", "") + did = chunk.get("doc_id", "") + if dnm not in ranks["doc_aggs"]: + ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} + ranks["doc_aggs"][dnm]["count"] += 1 + + ranks["doc_aggs"] = [ + { + "doc_name": k, + "doc_id": v["doc_id"], + "count": v["count"], + } + for k, v in sorted( + ranks["doc_aggs"].items(), + key=lambda x: x[1]["count"] * -1, + ) + ] + else: + ranks["doc_aggs"] = [] return ranks @@ -564,7 +606,7 @@ class Dealer: ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2) if not ids: return chunks - + vector_size = 1024 id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)} for cid, sim in ids: diff --git a/web/package-lock.json b/web/package-lock.json index e8d87301b..7bae0ed0f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -66,6 +66,7 @@ "input-otp": "^1.4.1", "js-base64": "^3.7.5", "jsencrypt": "^3.3.2", + "jsoneditor": "^10.4.2", "lexical": "^0.23.1", "lodash": "^4.17.21", "lucide-react": "^0.546.0", @@ -8998,6 +8999,12 @@ "@sinonjs/commons": "^3.0.0" } }, + "node_modules/@sphinxxxx/color-conversion": { + "version": "2.2.2", + "resolved": "https://registry.npmmirror.com/@sphinxxxx/color-conversion/-/color-conversion-2.2.2.tgz", + "integrity": "sha512-XExJS3cLqgrmNBIP3bBw6+1oQ1ksGjFh0+oClDKFYpCCqx/hlqwWO5KO/S63fzUo67SxI9dMrF0y5T/Ey7h8Zw==", + "license": "ISC" + }, "node_modules/@storybook/addon-docs": { "version": "9.1.4", "resolved": "https://registry.npmmirror.com/@storybook/addon-docs/-/addon-docs-9.1.4.tgz", @@ -12962,6 +12969,12 @@ "node": ">= 0.6" } }, + "node_modules/ace-builds": { + "version": "1.43.4", + "resolved": "https://registry.npmmirror.com/ace-builds/-/ace-builds-1.43.4.tgz", + "integrity": "sha512-8hAxVfo2ImICd69BWlZwZlxe9rxDGDjuUhh+WeWgGDvfBCE+r3lkynkQvIovDz4jcMi8O7bsEaFygaDT+h9sBA==", + "license": "BSD-3-Clause" + }, "node_modules/acorn": { "version": "8.15.0", "resolved": "https://registry.npmmirror.com/acorn/-/acorn-8.15.0.tgz", @@ -21894,6 +21907,12 @@ "@pkgjs/parseargs": "^0.11.0" } }, + "node_modules/javascript-natural-sort": { + "version": "0.7.1", + "resolved": "https://registry.npmmirror.com/javascript-natural-sort/-/javascript-natural-sort-0.7.1.tgz", + "integrity": "sha512-nO6jcEfZWQXDhOiBtG2KvKyEptz7RVbpGP4vTD2hLBdmNQSsCiicO2Ioinv6UI4y9ukqnBpy+XZ9H6uLNgJTlw==", + "license": "MIT" + }, "node_modules/javascript-stringify": { "version": "2.1.0", "resolved": "https://registry.npmmirror.com/javascript-stringify/-/javascript-stringify-2.1.0.tgz", @@ -24253,6 +24272,15 @@ "jiti": "bin/jiti.js" } }, + "node_modules/jmespath": { + "version": "0.16.0", + "resolved": "https://registry.npmmirror.com/jmespath/-/jmespath-0.16.0.tgz", + "integrity": "sha512-9FzQjJ7MATs1tSpnco1K6ayiYE3figslrXA72G2HQ/n76RzvYlofyi5QM+iX4YRs/pu3yzxlVQSST23+dMDknw==", + "license": "Apache-2.0", + "engines": { + "node": ">= 0.6.0" + } + }, "node_modules/js-base64": { "version": "3.7.5", "resolved": "https://registry.npmmirror.com/js-base64/-/js-base64-3.7.5.tgz", @@ -24357,6 +24385,12 @@ "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", "license": "MIT" }, + "node_modules/json-source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmmirror.com/json-source-map/-/json-source-map-0.6.1.tgz", + "integrity": "sha512-1QoztHPsMQqhDq0hlXY5ZqcEdUzxQEIxgFkKl4WUp2pgShObl+9ovi4kRh2TfvAfxAoHOJ9vIMEqk3k4iex7tg==", + "license": "MIT" + }, "node_modules/json-stable-stringify-without-jsonify": { "version": "1.0.1", "resolved": "https://registry.npmmirror.com/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", @@ -24393,6 +24427,44 @@ "node": ">=6" } }, + "node_modules/jsoneditor": { + "version": "10.4.2", + "resolved": "https://registry.npmmirror.com/jsoneditor/-/jsoneditor-10.4.2.tgz", + "integrity": "sha512-SQPCXlanU4PqdVsYuj2X7yfbLiiJYjklbksGfMKPsuwLhAIPxDlG43jYfXieGXvxpuq1fkw08YoRbkKXKabcLA==", + "license": "Apache-2.0", + "dependencies": { + "ace-builds": "^1.36.2", + "ajv": "^6.12.6", + "javascript-natural-sort": "^0.7.1", + "jmespath": "^0.16.0", + "json-source-map": "^0.6.1", + "jsonrepair": "^3.8.1", + "picomodal": "^3.0.0", + "vanilla-picker": "^2.12.3" + } + }, + "node_modules/jsoneditor/node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmmirror.com/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/jsoneditor/node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmmirror.com/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "license": "MIT" + }, "node_modules/jsonfile": { "version": "6.1.0", "resolved": "https://registry.npmmirror.com/jsonfile/-/jsonfile-6.1.0.tgz", @@ -24404,6 +24476,15 @@ "graceful-fs": "^4.1.6" } }, + "node_modules/jsonrepair": { + "version": "3.13.1", + "resolved": "https://registry.npmmirror.com/jsonrepair/-/jsonrepair-3.13.1.tgz", + "integrity": "sha512-WJeiE0jGfxYmtLwBTEk8+y/mYcaleyLXWaqp5bJu0/ZTSeG0KQq/wWQ8pmnkKenEdN6pdnn6QtcoSUkbqDHWNw==", + "license": "ISC", + "bin": { + "jsonrepair": "bin/cli.js" + } + }, "node_modules/jsx-ast-utils": { "version": "3.3.5", "resolved": "https://registry.npmmirror.com/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz", @@ -27499,6 +27580,12 @@ "node": ">=8.6" } }, + "node_modules/picomodal": { + "version": "3.0.0", + "resolved": "https://registry.npmmirror.com/picomodal/-/picomodal-3.0.0.tgz", + "integrity": "sha512-FoR3TDfuLlqUvcEeK5ifpKSVVns6B4BQvc8SDF6THVMuadya6LLtji0QgUDSStw0ZR2J7I6UGi5V2V23rnPWTw==", + "license": "MIT" + }, "node_modules/pidtree": { "version": "0.6.0", "resolved": "https://registry.npmmirror.com/pidtree/-/pidtree-0.6.0.tgz", @@ -36235,6 +36322,15 @@ "dev": true, "peer": true }, + "node_modules/vanilla-picker": { + "version": "2.12.3", + "resolved": "https://registry.npmmirror.com/vanilla-picker/-/vanilla-picker-2.12.3.tgz", + "integrity": "sha512-qVkT1E7yMbUsB2mmJNFmaXMWE2hF8ffqzMMwe9zdAikd8u2VfnsVY2HQcOUi2F38bgbxzlJBEdS1UUhOXdF9GQ==", + "license": "ISC", + "dependencies": { + "@sphinxxxx/color-conversion": "^2.2.2" + } + }, "node_modules/vary": { "version": "1.1.2", "resolved": "https://registry.npmmirror.com/vary/-/vary-1.1.2.tgz", diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx index 2b23010cd..38c49b666 100644 --- a/web/src/pages/agent/form/agent-form/index.tsx +++ b/web/src/pages/agent/form/agent-form/index.tsx @@ -22,7 +22,8 @@ import { Switch } from '@/components/ui/switch'; import { LlmModelType } from '@/constants/knowledge'; import { useFindLlmByUuid } from '@/hooks/use-llm-request'; import { zodResolver } from '@hookform/resolvers/zod'; -import { memo, useCallback, useEffect, useMemo } from 'react'; +import { get } from 'lodash'; +import { memo, useEffect, useMemo } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; @@ -45,7 +46,10 @@ import { AgentTools, Agents } from './agent-tools'; import { StructuredOutputDialog } from './structured-output-dialog'; import { StructuredOutputPanel } from './structured-output-panel'; import { useBuildPromptExtraPromptOptions } from './use-build-prompt-options'; -import { useShowStructuredOutputDialog } from './use-show-structured-output-dialog'; +import { + useHandleShowStructuredOutput, + useShowStructuredOutputDialog, +} from './use-show-structured-output-dialog'; import { useValues } from './use-values'; import { useWatchFormChange } from './use-watch-change'; @@ -121,22 +125,19 @@ function AgentForm({ node }: INextOperatorForm) { }); const { - initialStructuredOutput, showStructuredOutputDialog, structuredOutputDialogVisible, hideStructuredOutputDialog, handleStructuredOutputDialogOk, } = useShowStructuredOutputDialog(node?.id); - const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + const structuredOutput = get( + node, + `data.form.outputs.${AgentStructuredOutputField}`, + ); - const handleShowStructuredOutput = useCallback( - (val: boolean) => { - if (node?.id && val) { - updateNodeForm(node?.id, {}, ['outputs', AgentStructuredOutputField]); - } - }, - [node?.id, updateNodeForm], + const { handleShowStructuredOutput } = useHandleShowStructuredOutput( + node?.id, ); useEffect(() => { @@ -327,7 +328,7 @@ function AgentForm({ node }: INextOperatorForm) { )} @@ -337,7 +338,7 @@ function AgentForm({ node }: INextOperatorForm) { )} diff --git a/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts b/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts index 19e38cefe..d66fcfb45 100644 --- a/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts +++ b/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts @@ -1,6 +1,8 @@ import { JSONSchema } from '@/components/jsonjoy-builder'; +import { AgentStructuredOutputField } from '@/constants/agent'; import { useSetModalState } from '@/hooks/common-hooks'; import { useCallback } from 'react'; +import { initialAgentValues } from '../../constant'; import useGraphStore from '../../store'; export function useShowStructuredOutputDialog(nodeId?: string) { @@ -9,15 +11,13 @@ export function useShowStructuredOutputDialog(nodeId?: string) { showModal: showStructuredOutputDialog, hideModal: hideStructuredOutputDialog, } = useSetModalState(); - const { updateNodeForm, getNode } = useGraphStore((state) => state); - - const initialStructuredOutput = getNode(nodeId)?.data.form.outputs.structured; + const { updateNodeForm } = useGraphStore((state) => state); const handleStructuredOutputDialogOk = useCallback( (values: JSONSchema) => { // Sync data to canvas if (nodeId) { - updateNodeForm(nodeId, values, ['outputs', 'structured']); + updateNodeForm(nodeId, values, ['outputs', AgentStructuredOutputField]); } hideStructuredOutputDialog(); }, @@ -25,10 +25,30 @@ export function useShowStructuredOutputDialog(nodeId?: string) { ); return { - initialStructuredOutput, structuredOutputDialogVisible, showStructuredOutputDialog, hideStructuredOutputDialog, handleStructuredOutputDialogOk, }; } + +export function useHandleShowStructuredOutput(nodeId?: string) { + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + const handleShowStructuredOutput = useCallback( + (val: boolean) => { + if (nodeId) { + if (val) { + updateNodeForm(nodeId, {}, ['outputs', AgentStructuredOutputField]); + } else { + updateNodeForm(nodeId, initialAgentValues.outputs, ['outputs']); + } + } + }, + [nodeId, updateNodeForm], + ); + + return { + handleShowStructuredOutput, + }; +} diff --git a/web/src/pages/agent/form/agent-form/use-values.ts b/web/src/pages/agent/form/agent-form/use-values.ts index f8747e4b4..fb7f94861 100644 --- a/web/src/pages/agent/form/agent-form/use-values.ts +++ b/web/src/pages/agent/form/agent-form/use-values.ts @@ -6,8 +6,10 @@ import { initialAgentValues } from '../../constant'; // You need to exclude the mcp and tools fields that are not in the form, // otherwise the form data update will reset the tools or mcp data to an array +// Exclude data that is not in the form to avoid writing this data to the canvas when using useWatch. +// Outputs, tools, and MCP data are directly synchronized to the canvas without going through the form. function omitToolsAndMcp(values: Record) { - return omit(values, ['mcp', 'tools']); + return omit(values, ['mcp', 'tools', 'outputs']); } export function useValues(node?: RAGFlowNodeType) { diff --git a/web/src/pages/agent/form/agent-form/use-watch-change.ts b/web/src/pages/agent/form/agent-form/use-watch-change.ts index 7c53a8d40..98b0ecf31 100644 --- a/web/src/pages/agent/form/agent-form/use-watch-change.ts +++ b/web/src/pages/agent/form/agent-form/use-watch-change.ts @@ -1,7 +1,6 @@ -import { omit } from 'lodash'; import { useEffect } from 'react'; import { UseFormReturn, useWatch } from 'react-hook-form'; -import { AgentStructuredOutputField, PromptRole } from '../../constant'; +import { PromptRole } from '../../constant'; import useGraphStore from '../../store'; export function useWatchFormChange(id?: string, form?: UseFormReturn) { @@ -17,14 +16,6 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn) { prompts: [{ role: PromptRole.User, content: values.prompts }], }; - if (!values.showStructuredOutput) { - nextValues = { - ...nextValues, - outputs: omit(values.outputs, [AgentStructuredOutputField]), - }; - } else { - nextValues = omit(nextValues, 'outputs'); - } updateNodeForm(id, nextValues); } }, [form?.formState.isDirty, id, updateNodeForm, values]);