diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index c3a01e517..fd9096cb6 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -137,7 +137,7 @@ class Retrieval(ToolBase, ABC):
if not doc_ids:
doc_ids = None
elif self._param.meta_data_filter.get("method") == "manual":
- filters=self._param.meta_data_filter["manual"]
+ filters = self._param.meta_data_filter["manual"]
for flt in filters:
pat = re.compile(self.variable_ref_patt)
s = flt["value"]
@@ -166,8 +166,8 @@ class Retrieval(ToolBase, ABC):
out_parts.append(s[last:])
flt["value"] = "".join(out_parts)
doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
- if not doc_ids:
- doc_ids = None
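+        # The manual filter matched no documents: use a sentinel doc ID that matches nothing,
+        # so retrieval returns empty instead of falling back to searching all documents.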
+ if filters and not doc_ids:
+ doc_ids = ["-999"]
if self._param.cross_languages:
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index e121bcba7..b43fb9af1 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -311,8 +311,8 @@ async def retrieval_test():
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
- if not doc_ids:
- doc_ids = None
+ if meta_data_filter["manual"] and not doc_ids:
+ doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=current_user.id)
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index 84300ac3c..30fbd835e 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -1445,6 +1445,8 @@ async def retrieval_test(tenant_id):
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+ if metadata_condition and not doc_ids:
+ doc_ids = ["-999"]
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 533375622..074401ede 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -446,8 +446,8 @@ async def agent_completions(tenant_id, agent_id):
if req.get("stream", True):
- def generate():
- for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
+ async def generate():
+ async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
if isinstance(answer, str):
try:
ans = json.loads(answer[5:]) # remove "data:"
@@ -471,7 +471,7 @@ async def agent_completions(tenant_id, agent_id):
full_content = ""
reference = {}
final_ans = ""
- for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
+ async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
try:
ans = json.loads(answer[5:])
@@ -873,7 +873,7 @@ async def agent_bot_completions(agent_id):
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
- for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
+ async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
return get_result(data=answer)
@@ -981,8 +981,8 @@ async def retrieval_test_embedded():
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
- if not doc_ids:
- doc_ids = None
+ if meta_data_filter["manual"] and not doc_ids:
+ doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=tenant_id)
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index db878574d..0a09ea532 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -415,9 +415,10 @@ def chat(dialog, messages, stream=True, **kwargs):
if not attachments:
attachments = None
elif dialog.meta_data_filter.get("method") == "manual":
- attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"], dialog.meta_data_filter.get("logic", "and")))
- if not attachments:
- attachments = None
+ conds = dialog.meta_data_filter["manual"]
+ attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
+ if conds and not attachments:
+ attachments = ["-999"]
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
@@ -787,8 +788,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
- if not doc_ids:
- doc_ids = None
+ if meta_data_filter["manual"] and not doc_ids:
+ doc_ids = ["-999"]
kbinfos = retriever.retrieval(
question=question,
@@ -862,8 +863,8 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
- if not doc_ids:
- doc_ids = None
+ if meta_data_filter["manual"] and not doc_ids:
+ doc_ids = ["-999"]
ranks = settings.retriever.retrieval(
question=question,
diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py
index 8c6a522ad..e29bbbe76 100644
--- a/common/data_source/notion_connector.py
+++ b/common/data_source/notion_connector.py
@@ -1,38 +1,45 @@
+import html
import logging
from collections.abc import Generator
+from datetime import datetime, timezone
+from pathlib import Path
from typing import Any, Optional
+from urllib.parse import urlparse
+
from retry import retry
from common.data_source.config import (
INDEX_BATCH_SIZE,
- DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
+ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
+ DocumentSource,
+)
+from common.data_source.exceptions import (
+ ConnectorMissingCredentialError,
+ ConnectorValidationError,
+ CredentialExpiredError,
+ InsufficientPermissionsError,
+ UnexpectedValidationError,
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
- SecondsSinceUnixEpoch
+ SecondsSinceUnixEpoch,
)
from common.data_source.models import (
Document,
- TextSection, GenerateDocumentsOutput
-)
-from common.data_source.exceptions import (
- ConnectorValidationError,
- CredentialExpiredError,
- InsufficientPermissionsError,
- UnexpectedValidationError, ConnectorMissingCredentialError
-)
-from common.data_source.models import (
- NotionPage,
+ GenerateDocumentsOutput,
NotionBlock,
- NotionSearchResponse
+ NotionPage,
+ NotionSearchResponse,
+ TextSection,
)
from common.data_source.utils import (
- rl_requests,
batch_generator,
+ datetime_from_string,
fetch_notion_data,
+ filter_pages_by_time,
properties_to_str,
- filter_pages_by_time, datetime_from_string
+ rl_requests,
)
@@ -61,11 +68,9 @@ class NotionConnector(LoadConnector, PollConnector):
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
@retry(tries=3, delay=1, backoff=2)
- def _fetch_child_blocks(
- self, block_id: str, cursor: Optional[str] = None
- ) -> dict[str, Any] | None:
+ def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
"""Fetch all child blocks via the Notion API."""
- logging.debug(f"Fetching children of block with ID '{block_id}'")
+ logging.debug(f"[Notion]: Fetching children of block with ID {block_id}")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_params = {"start_cursor": cursor} if cursor else None
@@ -79,49 +84,42 @@ class NotionConnector(LoadConnector, PollConnector):
response.raise_for_status()
return response.json()
except Exception as e:
- if hasattr(e, 'response') and e.response.status_code == 404:
- logging.error(
- f"Unable to access block with ID '{block_id}'. "
- f"This is likely due to the block not being shared with the integration."
- )
+ if hasattr(e, "response") and e.response.status_code == 404:
+ logging.error(f"[Notion]: Unable to access block with ID {block_id}. This is likely due to the block not being shared with the integration.")
return None
else:
- logging.exception(f"Error fetching blocks: {e}")
+ logging.exception(f"[Notion]: Error fetching blocks: {e}")
raise
@retry(tries=3, delay=1, backoff=2)
def _fetch_page(self, page_id: str) -> NotionPage:
"""Fetch a page from its ID via the Notion API."""
- logging.debug(f"Fetching page for ID '{page_id}'")
+ logging.debug(f"[Notion]: Fetching page for ID {page_id}")
page_url = f"https://api.notion.com/v1/pages/{page_id}"
try:
data = fetch_notion_data(page_url, self.headers, "GET")
return NotionPage(**data)
except Exception as e:
- logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
+ logging.warning(f"[Notion]: Failed to fetch page, trying database for ID {page_id}: {e}")
return self._fetch_database_as_page(page_id)
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page."""
- logging.debug(f"Fetching database for ID '{database_id}' as a page")
+ logging.debug(f"[Notion]: Fetching database for ID {database_id} as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
data = fetch_notion_data(database_url, self.headers, "GET")
database_name = data.get("title")
- database_name = (
- database_name[0].get("text", {}).get("content") if database_name else None
- )
+ database_name = database_name[0].get("text", {}).get("content") if database_name else None
return NotionPage(**data, database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
- def _fetch_database(
- self, database_id: str, cursor: Optional[str] = None
- ) -> dict[str, Any]:
+ def _fetch_database(self, database_id: str, cursor: Optional[str] = None) -> dict[str, Any]:
"""Fetch a database from its ID via the Notion API."""
- logging.debug(f"Fetching database for ID '{database_id}'")
+ logging.debug(f"[Notion]: Fetching database for ID {database_id}")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = {"start_cursor": cursor} if cursor else None
@@ -129,17 +127,12 @@ class NotionConnector(LoadConnector, PollConnector):
data = fetch_notion_data(block_url, self.headers, "POST", body)
return data
except Exception as e:
- if hasattr(e, 'response') and e.response.status_code in [404, 400]:
- logging.error(
- f"Unable to access database with ID '{database_id}'. "
- f"This is likely due to the database not being shared with the integration."
- )
+ if hasattr(e, "response") and e.response.status_code in [404, 400]:
+ logging.error(f"[Notion]: Unable to access database with ID {database_id}. This is likely due to the database not being shared with the integration.")
return {"results": [], "next_cursor": None}
raise
- def _read_pages_from_database(
- self, database_id: str
- ) -> tuple[list[NotionBlock], list[str]]:
+ def _read_pages_from_database(self, database_id: str) -> tuple[list[NotionBlock], list[str]]:
"""Returns a list of top level blocks and all page IDs in the database."""
result_blocks: list[NotionBlock] = []
result_pages: list[str] = []
@@ -158,10 +151,10 @@ class NotionConnector(LoadConnector, PollConnector):
if self.recursive_index_enabled:
if obj_type == "page":
- logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
+ logging.debug(f"[Notion]: Found page with ID {obj_id} in database {database_id}")
result_pages.append(result["id"])
elif obj_type == "database":
- logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
+ logging.debug(f"[Notion]: Found database with ID {obj_id} in database {database_id}")
_, child_pages = self._read_pages_from_database(obj_id)
result_pages.extend(child_pages)
@@ -172,44 +165,229 @@ class NotionConnector(LoadConnector, PollConnector):
return result_blocks, result_pages
- def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
- """Reads all child blocks for the specified block, returns blocks and child page ids."""
+ def _extract_rich_text(self, rich_text_array: list[dict[str, Any]]) -> str:
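+        """Flatten a Notion rich_text array into plain text, keeping equation expressions, mention values, and link hrefs."""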
+ collected_text: list[str] = []
+ for rich_text in rich_text_array:
+ content = ""
+ r_type = rich_text.get("type")
+
+ if r_type == "equation":
+ expr = rich_text.get("equation", {}).get("expression")
+ if expr:
+ content = expr
+ elif r_type == "mention":
+ mention = rich_text.get("mention", {}) or {}
+ mention_type = mention.get("type")
+ mention_value = mention.get(mention_type, {}) if mention_type else {}
+ if mention_type == "date":
+ start = mention_value.get("start")
+ end = mention_value.get("end")
+ if start and end:
+ content = f"{start} - {end}"
+ elif start:
+ content = start
+ elif mention_type in {"page", "database"}:
+ content = mention_value.get("id", rich_text.get("plain_text", ""))
+ elif mention_type == "link_preview":
+ content = mention_value.get("url", rich_text.get("plain_text", ""))
+ else:
+ content = rich_text.get("plain_text", "") or str(mention_value)
+ else:
+ if rich_text.get("plain_text"):
+ content = rich_text["plain_text"]
+ elif "text" in rich_text and rich_text["text"].get("content"):
+ content = rich_text["text"]["content"]
+
+ href = rich_text.get("href")
+ if content and href:
+ content = f"{content} ({href})"
+
+ if content:
+ collected_text.append(content)
+
+ return "".join(collected_text).strip()
+
+ def _build_table_html(self, table_block_id: str) -> str | None:
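+        """Render a table block's rows as an HTML <table> string, paging through its table_row children."""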
+ rows: list[str] = []
+ cursor = None
+ while True:
+ data = self._fetch_child_blocks(table_block_id, cursor)
+ if data is None:
+ break
+
+ for result in data["results"]:
+ if result.get("type") != "table_row":
+ continue
+ cells_html: list[str] = []
+ for cell in result["table_row"].get("cells", []):
+ cell_text = self._extract_rich_text(cell)
+ cell_html = html.escape(cell_text) if cell_text else ""
+                    cells_html.append(f"<td>{cell_html}</td>")
+                rows.append(f"<tr>{''.join(cells_html)}</tr>")
+
+ if data.get("next_cursor") is None:
+ break
+ cursor = data["next_cursor"]
+
+ if not rows:
+ return None
+        return "<table>\n" + "\n".join(rows) + "\n</table>"
+
+ def _download_file(self, url: str) -> bytes | None:
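+        """Download a file's bytes, returning None on failure."""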
+ try:
+ response = rl_requests.get(url, timeout=60)
+ response.raise_for_status()
+ return response.content
+ except Exception as exc:
+ logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
+ return None
+
+ def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
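+        """Return (url, name, caption) for a file-like block, deriving a fallback name from the URL or block ID."""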
+ file_source_type = result_obj.get("type")
+ file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
+ url = file_source.get("url")
+
+ name = result_obj.get("name") or file_source.get("name")
+ if url and not name:
+ parsed_name = Path(urlparse(url).path).name
+ name = parsed_name or f"notion_file_{block_id}"
+ elif not name:
+ name = f"notion_file_{block_id}"
+
+ caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None
+
+ return url, name, caption
+
+ def _build_attachment_document(
+ self,
+ block_id: str,
+ url: str,
+ name: str,
+ caption: Optional[str],
+ page_last_edited_time: Optional[str],
+ ) -> Document | None:
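+        """Download an attachment and wrap it as a standalone Document; returns None if the download fails."""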
+ file_bytes = self._download_file(url)
+ if file_bytes is None:
+ return None
+
+ extension = Path(name).suffix or Path(urlparse(url).path).suffix or ".bin"
+ if extension and not extension.startswith("."):
+ extension = f".{extension}"
+ if not extension:
+ extension = ".bin"
+
+ updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
+ semantic_identifier = caption or name or f"Notion file {block_id}"
+
+ return Document(
+ id=block_id,
+ blob=file_bytes,
+ source=DocumentSource.NOTION,
+ semantic_identifier=semantic_identifier,
+ extension=extension,
+ size_bytes=len(file_bytes),
+ doc_updated_at=updated_at,
+ )
+
+ def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
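+        """Recursively read child blocks; returns (blocks, child page IDs, attachment Documents)."""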
result_blocks: list[NotionBlock] = []
child_pages: list[str] = []
+ attachments: list[Document] = []
cursor = None
while True:
data = self._fetch_child_blocks(base_block_id, cursor)
if data is None:
- return result_blocks, child_pages
+ return result_blocks, child_pages, attachments
for result in data["results"]:
- logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
+ logging.debug(f"[Notion]: Found child block for block with ID {base_block_id}: {result}")
result_block_id = result["id"]
result_type = result["type"]
result_obj = result[result_type]
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
- logging.warning(f"Skipping unsupported block type '{result_type}'")
+ logging.warning(f"[Notion]: Skipping unsupported block type {result_type}")
+ continue
+
+ if result_type == "table":
+ table_html = self._build_table_html(result_block_id)
+ if table_html:
+ result_blocks.append(
+ NotionBlock(
+ id=result_block_id,
+ text=table_html,
+ prefix="\n\n",
+ )
+ )
+ continue
+
+ if result_type == "equation":
+ expr = result_obj.get("expression")
+ if expr:
+ result_blocks.append(
+ NotionBlock(
+ id=result_block_id,
+ text=expr,
+ prefix="\n",
+ )
+ )
continue
cur_result_text_arr = []
if "rich_text" in result_obj:
- for rich_text in result_obj["rich_text"]:
- if "text" in rich_text:
- text = rich_text["text"]["content"]
- cur_result_text_arr.append(text)
+ text = self._extract_rich_text(result_obj["rich_text"])
+ if text:
+ cur_result_text_arr.append(text)
+
+ if result_type == "bulleted_list_item":
+ if cur_result_text_arr:
+ cur_result_text_arr[0] = f"- {cur_result_text_arr[0]}"
+ else:
+ cur_result_text_arr = ["- "]
+
+ if result_type == "numbered_list_item":
+ if cur_result_text_arr:
+ cur_result_text_arr[0] = f"1. {cur_result_text_arr[0]}"
+ else:
+ cur_result_text_arr = ["1. "]
+
+ if result_type == "to_do":
+ checked = result_obj.get("checked")
+ checkbox_prefix = "[x]" if checked else "[ ]"
+ if cur_result_text_arr:
+ cur_result_text_arr = [f"{checkbox_prefix} {cur_result_text_arr[0]}"] + cur_result_text_arr[1:]
+ else:
+ cur_result_text_arr = [checkbox_prefix]
+
+ if result_type in {"file", "image", "pdf", "video", "audio"}:
+ file_url, file_name, caption = self._extract_file_metadata(result_obj, result_block_id)
+ if file_url:
+ attachment_doc = self._build_attachment_document(
+ block_id=result_block_id,
+ url=file_url,
+ name=file_name,
+ caption=caption,
+ page_last_edited_time=page_last_edited_time,
+ )
+ if attachment_doc:
+ attachments.append(attachment_doc)
+
+ attachment_label = caption or file_name
+ if attachment_label:
+ cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")
if result["has_children"]:
if result_type == "child_page":
child_pages.append(result_block_id)
else:
- logging.debug(f"Entering sub-block: {result_block_id}")
- subblocks, subblock_child_pages = self._read_blocks(result_block_id)
- logging.debug(f"Finished sub-block: {result_block_id}")
+ logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
+ subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
+ logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
result_blocks.extend(subblocks)
child_pages.extend(subblock_child_pages)
+ attachments.extend(subblock_attachments)
if result_type == "child_database":
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
@@ -231,7 +409,7 @@ class NotionConnector(LoadConnector, PollConnector):
cursor = data["next_cursor"]
- return result_blocks, child_pages
+ return result_blocks, child_pages, attachments
def _read_page_title(self, page: NotionPage) -> Optional[str]:
"""Extracts the title from a Notion page."""
@@ -245,9 +423,7 @@ class NotionConnector(LoadConnector, PollConnector):
return None
- def _read_pages(
- self, pages: list[NotionPage]
- ) -> Generator[Document, None, None]:
+ def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents."""
all_child_page_ids: list[str] = []
@@ -255,11 +431,17 @@ class NotionConnector(LoadConnector, PollConnector):
if isinstance(page, dict):
page = NotionPage(**page)
if page.id in self.indexed_pages:
- logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
+ logging.debug(f"[Notion]: Already indexed page with ID {page.id}. Skipping.")
continue
- logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
- page_blocks, child_page_ids = self._read_blocks(page.id)
+ if start is not None and end is not None:
+ page_ts = datetime_from_string(page.last_edited_time).timestamp()
+ if not (page_ts > start and page_ts <= end):
+ logging.debug(f"[Notion]: Skipping page {page.id} outside polling window.")
+ continue
+
+ logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
+ page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
all_child_page_ids.extend(child_page_ids)
self.indexed_pages.add(page.id)
@@ -268,14 +450,12 @@ class NotionConnector(LoadConnector, PollConnector):
if not page_blocks:
if not raw_page_title:
- logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
+ logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
continue
text = page_title
if page.properties:
- text += "\n\n" + "\n".join(
- [f"{key}: {value}" for key, value in page.properties.items()]
- )
+ text += "\n\n" + "\n".join([f"{key}: {value}" for key, value in page.properties.items()])
sections = [TextSection(link=page.url, text=text)]
else:
sections = [
@@ -286,45 +466,39 @@ class NotionConnector(LoadConnector, PollConnector):
for block in page_blocks
]
- blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
+ joined_text = "\n".join(sec.text for sec in sections)
+ blob = joined_text.encode("utf-8")
yield Document(
- id=page.id,
- blob=blob,
- source=DocumentSource.NOTION,
- semantic_identifier=page_title,
- extension=".txt",
- size_bytes=len(blob),
- doc_updated_at=datetime_from_string(page.last_edited_time)
+ id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
)
+ for attachment_doc in attachment_docs:
+ yield attachment_doc
+
if self.recursive_index_enabled and all_child_page_ids:
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
- child_page_batch = [
- self._fetch_page(page_id)
- for page_id in child_page_batch_ids
- if page_id not in self.indexed_pages
- ]
- yield from self._read_pages(child_page_batch)
+ child_page_batch = [self._fetch_page(page_id) for page_id in child_page_batch_ids if page_id not in self.indexed_pages]
+ yield from self._read_pages(child_page_batch, start, end)
@retry(tries=3, delay=1, backoff=2)
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
"""Search for pages from a Notion database."""
- logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
+ logging.debug(f"[Notion]: Searching for pages in Notion with query_dict: {query_dict}")
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
return NotionSearchResponse(**data)
- def _recursive_load(self) -> Generator[list[Document], None, None]:
+ def _recursive_load(self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[list[Document], None, None]:
"""Recursively load pages starting from root page ID."""
if self.root_page_id is None or not self.recursive_index_enabled:
raise RuntimeError("Recursive page lookup is not enabled")
- logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
+ logging.info(f"[Notion]: Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
pages = [self._fetch_page(page_id=self.root_page_id)]
- yield from batch_generator(self._read_pages(pages), self.batch_size)
+ yield from batch_generator(self._read_pages(pages, start, end), self.batch_size)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Applies integration token to headers."""
- self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
+ self.headers["Authorization"] = f"Bearer {credentials['notion_integration_token']}"
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -348,12 +522,10 @@ class NotionConnector(LoadConnector, PollConnector):
else:
break
- def poll_source(
- self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
- ) -> GenerateDocumentsOutput:
+ def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
"""Poll Notion for updated pages within a time period."""
if self.recursive_index_enabled and self.root_page_id:
- yield from self._recursive_load()
+ yield from self._recursive_load(start, end)
return
query_dict = {
@@ -367,7 +539,7 @@ class NotionConnector(LoadConnector, PollConnector):
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
if pages:
- yield from batch_generator(self._read_pages(pages), self.batch_size)
+ yield from batch_generator(self._read_pages(pages, start, end), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py
index 59fec9250..965f82265 100644
--- a/deepdoc/parser/docling_parser.py
+++ b/deepdoc/parser/docling_parser.py
@@ -187,7 +187,7 @@ class DoclingParser(RAGFlowPdfParser):
bbox = _BBox(int(pn), bb[0], bb[1], bb[2], bb[3])
yield (DoclingContentType.EQUATION.value, text, bbox)
- def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
+    def _transfer_to_sections(self, doc, parse_method: str) -> list[tuple]:
sections: list[tuple[str, str]] = []
for typ, payload, bbox in self._iter_doc_items(doc):
if typ == DoclingContentType.TEXT.value:
@@ -200,7 +200,12 @@ class DoclingParser(RAGFlowPdfParser):
continue
tag = self._make_line_tag(bbox) if isinstance(bbox,_BBox) else ""
- sections.append((section, tag))
+ if parse_method == "manual":
+ sections.append((section, typ, tag))
+ elif parse_method == "paper":
+ sections.append((section + tag, typ))
+ else:
+ sections.append((section, tag))
return sections
def cropout_docling_table(self, page_no: int, bbox: tuple[float, float, float, float], zoomin: int = 1):
@@ -282,7 +287,8 @@ class DoclingParser(RAGFlowPdfParser):
output_dir: Optional[str] = None,
lang: Optional[str] = None,
method: str = "auto",
- delete_output: bool = True,
+ delete_output: bool = True,
+ parse_method: str = "raw"
):
if not self.check_installation():
@@ -318,7 +324,7 @@ class DoclingParser(RAGFlowPdfParser):
if callback:
callback(0.7, f"[Docling] Parsed doc: {getattr(doc, 'num_pages', 'n/a')} pages")
- sections = self._transfer_to_sections(doc)
+ sections = self._transfer_to_sections(doc, parse_method=parse_method)
tables = self._transfer_to_tables(doc)
if callback:
diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py
index d2b694188..d4834de39 100644
--- a/deepdoc/parser/mineru_parser.py
+++ b/deepdoc/parser/mineru_parser.py
@@ -476,7 +476,7 @@ class MinerUParser(RAGFlowPdfParser):
item[key] = str((subdir / item[key]).resolve())
return data
- def _transfer_to_sections(self, outputs: list[dict[str, Any]]):
+    def _transfer_to_sections(self, outputs: list[dict[str, Any]], parse_method: str | None = None):
sections = []
for output in outputs:
match output["type"]:
@@ -497,7 +497,11 @@ class MinerUParser(RAGFlowPdfParser):
case MinerUContentType.DISCARDED:
pass
- if section:
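+            # "manual" emits (text, layout_type, position_tag); "paper" folds the tag into
+            # the text and keeps layout_type; other modes keep the original (text, tag) shape.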
+ if section and parse_method == "manual":
+ sections.append((section, output["type"], self._line_tag(output)))
+ elif section and parse_method == "paper":
+ sections.append((section + self._line_tag(output), output["type"]))
+            elif section:
sections.append((section, self._line_tag(output)))
return sections
@@ -516,6 +520,7 @@ class MinerUParser(RAGFlowPdfParser):
method: str = "auto",
server_url: Optional[str] = None,
delete_output: bool = True,
+ parse_method: str = "raw"
) -> tuple:
import shutil
@@ -565,7 +570,8 @@ class MinerUParser(RAGFlowPdfParser):
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
- return self._transfer_to_sections(outputs), self._transfer_to_tables(outputs)
+
+ return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs)
finally:
if temp_pdf and temp_pdf.exists():
try:
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index 5bc877a6a..6d8431c82 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -33,6 +33,8 @@ import xgboost as xgb
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader as pdf2_read
+from sklearn.cluster import KMeans
+from sklearn.metrics import silhouette_score
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch
@@ -353,7 +355,6 @@ class RAGFlowPdfParser:
def _assign_column(self, boxes, zoomin=3):
if not boxes:
return boxes
-
if all("col_id" in b for b in boxes):
return boxes
@@ -361,61 +362,80 @@ class RAGFlowPdfParser:
for b in boxes:
by_page[b["page_number"]].append(b)
- page_info = {} # pg -> dict(page_w, left_edge, cand_cols)
- counter = Counter()
+ page_cols = {}
for pg, bxs in by_page.items():
if not bxs:
- page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1}
- counter[1] += 1
+ page_cols[pg] = 1
continue
- if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg:
- page_w = self.page_images[pg - 1].size[0] / max(1, zoomin)
- left_edge = 0.0
- else:
- xs0 = [box["x0"] for box in bxs]
- xs1 = [box["x1"] for box in bxs]
- left_edge = float(min(xs0))
- page_w = max(1.0, float(max(xs1) - left_edge))
+ x0s_raw = np.array([b["x0"] for b in bxs], dtype=float)
- widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs]
- median_w = float(np.median(widths)) if widths else 1.0
+ min_x0 = np.min(x0s_raw)
+ max_x1 = np.max([b["x1"] for b in bxs])
+ width = max_x1 - min_x0
- raw_cols = int(page_w / max(1.0, median_w))
+ INDENT_TOL = width * 0.12
+ x0s = []
+ for x in x0s_raw:
+ if abs(x - min_x0) < INDENT_TOL:
+ x0s.append([min_x0])
+ else:
+ x0s.append([x])
+ x0s = np.array(x0s, dtype=float)
+
+ max_try = min(4, len(bxs))
+ if max_try < 2:
+ max_try = 1
+ best_k = 1
+ best_score = -1
- # cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1
- cand = raw_cols
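+            # Pick the column count k whose KMeans clustering of left edges (x0) maximizes
+            # the silhouette score; k=1 scores 0, so a multi-column split must clearly beat it.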
+ for k in range(1, max_try + 1):
+ km = KMeans(n_clusters=k, n_init="auto")
+ labels = km.fit_predict(x0s)
- page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand}
- counter[cand] += 1
+ centers = np.sort(km.cluster_centers_.flatten())
+ if len(centers) > 1:
+ try:
+ score = silhouette_score(x0s, labels)
+ except ValueError:
+ continue
+ else:
+ score = 0
+                logging.debug(f"[Page {pg}] k={k}, silhouette={score:.4f}")
+ if score > best_score:
+ best_score = score
+ best_k = k
- logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}")
+ page_cols[pg] = best_k
+ logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
- global_cols = counter.most_common(1)[0][0]
+
+ global_cols = Counter(page_cols.values()).most_common(1)[0][0]
logging.info(f"Global column_num decided by majority: {global_cols}")
+
for pg, bxs in by_page.items():
if not bxs:
continue
+ k = page_cols[pg]
+ if len(bxs) < k:
+ k = 1
+ x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
+ km = KMeans(n_clusters=k, n_init="auto")
+ labels = km.fit_predict(x0s)
- page_w = page_info[pg]["page_w"]
- left_edge = page_info[pg]["left_edge"]
+ centers = km.cluster_centers_.flatten()
+ order = np.argsort(centers)
- if global_cols == 1:
- for box in bxs:
- box["col_id"] = 0
- continue
+ remap = {orig: new for new, orig in enumerate(order)}
- for box in bxs:
- w = box["x1"] - box["x0"]
- if w >= 0.8 * page_w:
- box["col_id"] = 0
- continue
- cx = 0.5 * (box["x0"] + box["x1"])
- norm_cx = (cx - left_edge) / page_w
- norm_cx = max(0.0, min(norm_cx, 0.999999))
- box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols))
+ for b, lb in zip(bxs, labels):
+ b["col_id"] = remap[lb]
+
return boxes
@@ -1303,7 +1323,10 @@ class RAGFlowPdfParser:
positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss):
- right = left + max_width
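+            # Stretch only the first and last spans to max_width; interior spans keep their
+            # measured right edge, floored at left + 10.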
+ if 0 < ii < len(poss) - 1:
+ right = max(left + 10, right)
+ else:
+ right = left + max_width
bottom *= ZM
for pn in pns[1:]:
if 0 <= pn - 1 < page_count:
diff --git a/docker/.env b/docker/.env
index d7e4b025f..6423b7824 100644
--- a/docker/.env
+++ b/docker/.env
@@ -230,9 +230,16 @@ REGISTER_ENABLED=1
# SANDBOX_MAX_MEMORY=256m # b, k, m, g
# SANDBOX_TIMEOUT=10s # s, m, 1m30s
-# Enable DocLing and Mineru
+# Enable DocLing
USE_DOCLING=false
+
+# Enable Mineru
USE_MINERU=false
+MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru"
+MINERU_DELETE_OUTPUT=0 # keep output directory
+MINERU_BACKEND=pipeline # or another backend you prefer
+
+
# pptx support
DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
\ No newline at end of file
diff --git a/rag/app/manual.py b/rag/app/manual.py
index 4f9de40c7..b3a4ae38d 100644
--- a/rag/app/manual.py
+++ b/rag/app/manual.py
@@ -214,6 +214,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback = callback,
pdf_cls = Pdf,
layout_recognizer = layout_recognizer,
+ parse_method = "manual",
**kwargs
)
@@ -226,7 +227,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif len(section) != 3:
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})")
- txt, sec_id, poss = section
+ txt, layoutno, poss = section
if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss)
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
@@ -236,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:])
- return (txt, sec_id, poss)
+ return (txt, layoutno, poss)
sections = [_normalize_section(sec) for sec in sections]
diff --git a/rag/app/naive.py b/rag/app/naive.py
index 49dca17af..562336d7f 100644
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -59,6 +59,7 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
+ parse_method = kwargs.get("parse_method", "raw")
if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
@@ -72,12 +73,14 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
server_url=os.environ.get("MINERU_SERVER_URL", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+ parse_method=parse_method
)
return sections, tables, pdf_parser
def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, pdf_cls = None ,**kwargs):
pdf_parser = DoclingParser()
+ parse_method = kwargs.get("parse_method", "raw")
if not pdf_parser.check_installation():
callback(-1, "Docling not found.")
@@ -89,6 +92,7 @@ def by_docling(filename, binary=None, from_page=0, to_page=100000, lang="Chinese
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+ parse_method=parse_method
)
return sections, tables, pdf_parser
diff --git a/rag/app/paper.py b/rag/app/paper.py
index d95976c9f..222be0762 100644
--- a/rag/app/paper.py
+++ b/rag/app/paper.py
@@ -21,8 +21,10 @@ import re
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from common.constants import ParserType
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
-from deepdoc.parser import PdfParser, PlainParser
+from deepdoc.parser import PdfParser
import numpy as np
+from rag.app.naive import by_plaintext, PARSERS
+
class Pdf(PdfParser):
def __init__(self):
@@ -147,19 +149,40 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
if re.search(r"\.pdf$", filename, re.IGNORECASE):
- if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
- pdf_parser = PlainParser()
+ layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+
+ if isinstance(layout_recognizer, bool):
+ layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+
+ name = layout_recognizer.strip().lower()
+ pdf_parser = PARSERS.get(name, by_plaintext)
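+    # Unknown recognizer names fall back to plain-text parsing via the PARSERS registry from naive.py.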
+ callback(0.1, "Start to parse.")
+
+ if name == "deepdoc":
+ pdf_parser = Pdf()
+ paper = pdf_parser(filename if not binary else binary,
+ from_page=from_page, to_page=to_page, callback=callback)
+ else:
+ sections, tables, pdf_parser = pdf_parser(
+ filename=filename,
+ binary=binary,
+ from_page=from_page,
+ to_page=to_page,
+ lang=lang,
+ callback=callback,
+ pdf_cls=Pdf,
+ parse_method="paper",
+ **kwargs
+ )
+
paper = {
"title": filename,
"authors": " ",
"abstract": "",
- "sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page)[0],
- "tables": []
+ "sections": sections,
+ "tables": tables
}
- else:
- pdf_parser = Pdf()
- paper = pdf_parser(filename if not binary else binary,
- from_page=from_page, to_page=to_page, callback=callback)
+
tbls=paper["tables"]
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
paper["tables"] = tbls
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 4dbd9945c..a479e5d3f 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -355,75 +355,102 @@ class Dealer:
rag_tokenizer.tokenize(ans).split(),
rag_tokenizer.tokenize(inst).split())
- def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
- vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
- rerank_mdl=None, highlight=False,
- rank_feature: dict | None = {PAGERANK_FLD: 10}):
+ def retrieval(
+ self,
+ question,
+ embd_mdl,
+ tenant_ids,
+ kb_ids,
+ page,
+ page_size,
+ similarity_threshold=0.2,
+ vector_similarity_weight=0.3,
+ top=1024,
+ doc_ids=None,
+ aggs=True,
+ rerank_mdl=None,
+ highlight=False,
+ rank_feature: dict | None = {PAGERANK_FLD: 10},
+ ):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
# Ensure RERANK_LIMIT is multiple of page_size
- RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
- req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
- "question": question, "vector": True, "topk": top,
- "similarity": similarity_threshold,
- "available_int": 1}
-
+ RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
+ req = {
+ "kb_ids": kb_ids,
+ "doc_ids": doc_ids,
+ "page": math.ceil(page_size * page / RERANK_LIMIT),
+ "size": RERANK_LIMIT,
+ "question": question,
+ "vector": True,
+ "topk": top,
+ "similarity": similarity_threshold,
+ "available_int": 1,
+ }
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
- sres = self.search(req, [index_name(tid) for tid in tenant_ids],
- kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
+ sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
if rerank_mdl and sres.total > 0:
- sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
- sres, question, 1 - vector_similarity_weight,
- vector_similarity_weight,
- rank_feature=rank_feature)
+ sim, tsim, vsim = self.rerank_by_model(
+ rerank_mdl,
+ sres,
+ question,
+ 1 - vector_similarity_weight,
+ vector_similarity_weight,
+ rank_feature=rank_feature,
+ )
else:
- lower_case_doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
- if lower_case_doc_engine in ["elasticsearch","opensearch"]:
+ lower_case_doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
+ if lower_case_doc_engine in ["elasticsearch", "opensearch"]:
# ElasticSearch doesn't normalize each way score before fusion.
sim, tsim, vsim = self.rerank(
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
- rank_feature=rank_feature)
+ sres,
+ question,
+ 1 - vector_similarity_weight,
+ vector_similarity_weight,
+ rank_feature=rank_feature,
+ )
else:
# Don't need rerank here since Infinity normalizes each way score before fusion.
sim = [sres.field[id].get("_score", 0.0) for id in sres.ids]
- sim = [s if s is not None else 0. for s in sim]
+ sim = [s if s is not None else 0.0 for s in sim]
tsim = sim
vsim = sim
- # Already paginated in search function
- max_pages = RERANK_LIMIT // page_size
- page_index = (page % max_pages) - 1
- begin = max(page_index * page_size, 0)
- sim = sim[begin : begin + page_size]
+
sim_np = np.array(sim, dtype=np.float64)
- idx = np.argsort(sim_np * -1)
+ if sim_np.size == 0:
+ return ranks
+
+ sorted_idx = np.argsort(sim_np * -1)
+
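+        # Keep only hits at or above the similarity threshold, then paginate within the
+        # reranked window; "total" reports the filtered count.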
+ valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= similarity_threshold]
+ filtered_count = len(valid_idx)
+ ranks["total"] = int(filtered_count)
+
+ if filtered_count == 0:
+ return ranks
+
+ max_pages = max(RERANK_LIMIT // max(page_size, 1), 1)
+ page_index = (page - 1) % max_pages
+ begin = page_index * page_size
+ end = begin + page_size
+ page_idx = valid_idx[begin:end]
+
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
- filtered_count = (sim_np >= similarity_threshold).sum()
- ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
- for i in idx:
- if np.float64(sim[i]) < similarity_threshold:
- break
+ for i in page_idx:
id = sres.ids[i]
chunk = sres.field[id]
dnm = chunk.get("docnm_kwd", "")
did = chunk.get("doc_id", "")
- if len(ranks["chunks"]) >= page_size:
- if aggs:
- if dnm not in ranks["doc_aggs"]:
- ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
- ranks["doc_aggs"][dnm]["count"] += 1
- continue
- break
-
position_int = chunk.get("position_int", [])
d = {
"chunk_id": id,
@@ -434,12 +461,12 @@ class Dealer:
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
- "similarity": sim[i],
- "vector_similarity": vsim[i],
- "term_similarity": tsim[i],
+ "similarity": float(sim_np[i]),
+ "vector_similarity": float(vsim[i]),
+ "term_similarity": float(tsim[i]),
"vector": chunk.get(vector_column, zero_vector),
"positions": position_int,
- "doc_type_kwd": chunk.get("doc_type_kwd", "")
+ "doc_type_kwd": chunk.get("doc_type_kwd", ""),
}
if highlight and sres.highlight:
if id in sres.highlight:
@@ -447,15 +474,30 @@ class Dealer:
else:
d["highlight"] = d["content_with_weight"]
ranks["chunks"].append(d)
- if dnm not in ranks["doc_aggs"]:
- ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
- ranks["doc_aggs"][dnm]["count"] += 1
- ranks["doc_aggs"] = [{"doc_name": k,
- "doc_id": v["doc_id"],
- "count": v["count"]} for k,
- v in sorted(ranks["doc_aggs"].items(),
- key=lambda x: x[1]["count"] * -1)]
- ranks["chunks"] = ranks["chunks"][:page_size]
+
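+        # Aggregate per-document counts over every hit above the threshold, not just the current page.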
+ if aggs:
+ for i in valid_idx:
+ id = sres.ids[i]
+ chunk = sres.field[id]
+ dnm = chunk.get("docnm_kwd", "")
+ did = chunk.get("doc_id", "")
+ if dnm not in ranks["doc_aggs"]:
+ ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
+ ranks["doc_aggs"][dnm]["count"] += 1
+
+ ranks["doc_aggs"] = [
+ {
+ "doc_name": k,
+ "doc_id": v["doc_id"],
+ "count": v["count"],
+ }
+ for k, v in sorted(
+ ranks["doc_aggs"].items(),
+ key=lambda x: x[1]["count"] * -1,
+ )
+ ]
+ else:
+ ranks["doc_aggs"] = []
return ranks
@@ -564,7 +606,7 @@ class Dealer:
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
if not ids:
return chunks
-
+
vector_size = 1024
id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
for cid, sim in ids:
diff --git a/web/package-lock.json b/web/package-lock.json
index e8d87301b..7bae0ed0f 100644
--- a/web/package-lock.json
+++ b/web/package-lock.json
@@ -66,6 +66,7 @@
"input-otp": "^1.4.1",
"js-base64": "^3.7.5",
"jsencrypt": "^3.3.2",
+ "jsoneditor": "^10.4.2",
"lexical": "^0.23.1",
"lodash": "^4.17.21",
"lucide-react": "^0.546.0",
@@ -8998,6 +8999,12 @@
"@sinonjs/commons": "^3.0.0"
}
},
+ "node_modules/@sphinxxxx/color-conversion": {
+ "version": "2.2.2",
+ "resolved": "https://registry.npmmirror.com/@sphinxxxx/color-conversion/-/color-conversion-2.2.2.tgz",
+ "integrity": "sha512-XExJS3cLqgrmNBIP3bBw6+1oQ1ksGjFh0+oClDKFYpCCqx/hlqwWO5KO/S63fzUo67SxI9dMrF0y5T/Ey7h8Zw==",
+ "license": "ISC"
+ },
"node_modules/@storybook/addon-docs": {
"version": "9.1.4",
"resolved": "https://registry.npmmirror.com/@storybook/addon-docs/-/addon-docs-9.1.4.tgz",
@@ -12962,6 +12969,12 @@
"node": ">= 0.6"
}
},
+ "node_modules/ace-builds": {
+ "version": "1.43.4",
+ "resolved": "https://registry.npmmirror.com/ace-builds/-/ace-builds-1.43.4.tgz",
+ "integrity": "sha512-8hAxVfo2ImICd69BWlZwZlxe9rxDGDjuUhh+WeWgGDvfBCE+r3lkynkQvIovDz4jcMi8O7bsEaFygaDT+h9sBA==",
+ "license": "BSD-3-Clause"
+ },
"node_modules/acorn": {
"version": "8.15.0",
"resolved": "https://registry.npmmirror.com/acorn/-/acorn-8.15.0.tgz",
@@ -21894,6 +21907,12 @@
"@pkgjs/parseargs": "^0.11.0"
}
},
+ "node_modules/javascript-natural-sort": {
+ "version": "0.7.1",
+ "resolved": "https://registry.npmmirror.com/javascript-natural-sort/-/javascript-natural-sort-0.7.1.tgz",
+ "integrity": "sha512-nO6jcEfZWQXDhOiBtG2KvKyEptz7RVbpGP4vTD2hLBdmNQSsCiicO2Ioinv6UI4y9ukqnBpy+XZ9H6uLNgJTlw==",
+ "license": "MIT"
+ },
"node_modules/javascript-stringify": {
"version": "2.1.0",
"resolved": "https://registry.npmmirror.com/javascript-stringify/-/javascript-stringify-2.1.0.tgz",
@@ -24253,6 +24272,15 @@
"jiti": "bin/jiti.js"
}
},
+ "node_modules/jmespath": {
+ "version": "0.16.0",
+ "resolved": "https://registry.npmmirror.com/jmespath/-/jmespath-0.16.0.tgz",
+ "integrity": "sha512-9FzQjJ7MATs1tSpnco1K6ayiYE3figslrXA72G2HQ/n76RzvYlofyi5QM+iX4YRs/pu3yzxlVQSST23+dMDknw==",
+ "license": "Apache-2.0",
+ "engines": {
+ "node": ">= 0.6.0"
+ }
+ },
"node_modules/js-base64": {
"version": "3.7.5",
"resolved": "https://registry.npmmirror.com/js-base64/-/js-base64-3.7.5.tgz",
@@ -24357,6 +24385,12 @@
"integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==",
"license": "MIT"
},
+ "node_modules/json-source-map": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmmirror.com/json-source-map/-/json-source-map-0.6.1.tgz",
+ "integrity": "sha512-1QoztHPsMQqhDq0hlXY5ZqcEdUzxQEIxgFkKl4WUp2pgShObl+9ovi4kRh2TfvAfxAoHOJ9vIMEqk3k4iex7tg==",
+ "license": "MIT"
+ },
"node_modules/json-stable-stringify-without-jsonify": {
"version": "1.0.1",
"resolved": "https://registry.npmmirror.com/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz",
@@ -24393,6 +24427,44 @@
"node": ">=6"
}
},
+ "node_modules/jsoneditor": {
+ "version": "10.4.2",
+ "resolved": "https://registry.npmmirror.com/jsoneditor/-/jsoneditor-10.4.2.tgz",
+ "integrity": "sha512-SQPCXlanU4PqdVsYuj2X7yfbLiiJYjklbksGfMKPsuwLhAIPxDlG43jYfXieGXvxpuq1fkw08YoRbkKXKabcLA==",
+ "license": "Apache-2.0",
+ "dependencies": {
+ "ace-builds": "^1.36.2",
+ "ajv": "^6.12.6",
+ "javascript-natural-sort": "^0.7.1",
+ "jmespath": "^0.16.0",
+ "json-source-map": "^0.6.1",
+ "jsonrepair": "^3.8.1",
+ "picomodal": "^3.0.0",
+ "vanilla-picker": "^2.12.3"
+ }
+ },
+ "node_modules/jsoneditor/node_modules/ajv": {
+ "version": "6.12.6",
+ "resolved": "https://registry.npmmirror.com/ajv/-/ajv-6.12.6.tgz",
+ "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
+ "license": "MIT",
+ "dependencies": {
+ "fast-deep-equal": "^3.1.1",
+ "fast-json-stable-stringify": "^2.0.0",
+ "json-schema-traverse": "^0.4.1",
+ "uri-js": "^4.2.2"
+ },
+ "funding": {
+ "type": "github",
+ "url": "https://github.com/sponsors/epoberezkin"
+ }
+ },
+ "node_modules/jsoneditor/node_modules/json-schema-traverse": {
+ "version": "0.4.1",
+ "resolved": "https://registry.npmmirror.com/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz",
+ "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==",
+ "license": "MIT"
+ },
"node_modules/jsonfile": {
"version": "6.1.0",
"resolved": "https://registry.npmmirror.com/jsonfile/-/jsonfile-6.1.0.tgz",
@@ -24404,6 +24476,15 @@
"graceful-fs": "^4.1.6"
}
},
+ "node_modules/jsonrepair": {
+ "version": "3.13.1",
+ "resolved": "https://registry.npmmirror.com/jsonrepair/-/jsonrepair-3.13.1.tgz",
+ "integrity": "sha512-WJeiE0jGfxYmtLwBTEk8+y/mYcaleyLXWaqp5bJu0/ZTSeG0KQq/wWQ8pmnkKenEdN6pdnn6QtcoSUkbqDHWNw==",
+ "license": "ISC",
+ "bin": {
+ "jsonrepair": "bin/cli.js"
+ }
+ },
"node_modules/jsx-ast-utils": {
"version": "3.3.5",
"resolved": "https://registry.npmmirror.com/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz",
@@ -27499,6 +27580,12 @@
"node": ">=8.6"
}
},
+ "node_modules/picomodal": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmmirror.com/picomodal/-/picomodal-3.0.0.tgz",
+ "integrity": "sha512-FoR3TDfuLlqUvcEeK5ifpKSVVns6B4BQvc8SDF6THVMuadya6LLtji0QgUDSStw0ZR2J7I6UGi5V2V23rnPWTw==",
+ "license": "MIT"
+ },
"node_modules/pidtree": {
"version": "0.6.0",
"resolved": "https://registry.npmmirror.com/pidtree/-/pidtree-0.6.0.tgz",
@@ -36235,6 +36322,15 @@
"dev": true,
"peer": true
},
+ "node_modules/vanilla-picker": {
+ "version": "2.12.3",
+ "resolved": "https://registry.npmmirror.com/vanilla-picker/-/vanilla-picker-2.12.3.tgz",
+ "integrity": "sha512-qVkT1E7yMbUsB2mmJNFmaXMWE2hF8ffqzMMwe9zdAikd8u2VfnsVY2HQcOUi2F38bgbxzlJBEdS1UUhOXdF9GQ==",
+ "license": "ISC",
+ "dependencies": {
+ "@sphinxxxx/color-conversion": "^2.2.2"
+ }
+ },
"node_modules/vary": {
"version": "1.1.2",
"resolved": "https://registry.npmmirror.com/vary/-/vary-1.1.2.tgz",
diff --git a/web/src/pages/agent/form/agent-form/index.tsx b/web/src/pages/agent/form/agent-form/index.tsx
index 2b23010cd..38c49b666 100644
--- a/web/src/pages/agent/form/agent-form/index.tsx
+++ b/web/src/pages/agent/form/agent-form/index.tsx
@@ -22,7 +22,8 @@ import { Switch } from '@/components/ui/switch';
import { LlmModelType } from '@/constants/knowledge';
import { useFindLlmByUuid } from '@/hooks/use-llm-request';
import { zodResolver } from '@hookform/resolvers/zod';
-import { memo, useCallback, useEffect, useMemo } from 'react';
+import { get } from 'lodash';
+import { memo, useEffect, useMemo } from 'react';
import { useForm, useWatch } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
@@ -45,7 +46,10 @@ import { AgentTools, Agents } from './agent-tools';
import { StructuredOutputDialog } from './structured-output-dialog';
import { StructuredOutputPanel } from './structured-output-panel';
import { useBuildPromptExtraPromptOptions } from './use-build-prompt-options';
-import { useShowStructuredOutputDialog } from './use-show-structured-output-dialog';
+import {
+ useHandleShowStructuredOutput,
+ useShowStructuredOutputDialog,
+} from './use-show-structured-output-dialog';
import { useValues } from './use-values';
import { useWatchFormChange } from './use-watch-change';
@@ -121,22 +125,19 @@ function AgentForm({ node }: INextOperatorForm) {
});
const {
- initialStructuredOutput,
showStructuredOutputDialog,
structuredOutputDialogVisible,
hideStructuredOutputDialog,
handleStructuredOutputDialogOk,
} = useShowStructuredOutputDialog(node?.id);
- const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
+ const structuredOutput = get(
+ node,
+ `data.form.outputs.${AgentStructuredOutputField}`,
+ );
- const handleShowStructuredOutput = useCallback(
- (val: boolean) => {
- if (node?.id && val) {
- updateNodeForm(node?.id, {}, ['outputs', AgentStructuredOutputField]);
- }
- },
- [node?.id, updateNodeForm],
+ const { handleShowStructuredOutput } = useHandleShowStructuredOutput(
+ node?.id,
);
useEffect(() => {
@@ -327,7 +328,7 @@ function AgentForm({ node }: INextOperatorForm) {
)}
@@ -337,7 +338,7 @@ function AgentForm({ node }: INextOperatorForm) {
)}
>
diff --git a/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts b/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts
index 19e38cefe..d66fcfb45 100644
--- a/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts
+++ b/web/src/pages/agent/form/agent-form/use-show-structured-output-dialog.ts
@@ -1,6 +1,8 @@
import { JSONSchema } from '@/components/jsonjoy-builder';
+import { AgentStructuredOutputField } from '@/constants/agent';
import { useSetModalState } from '@/hooks/common-hooks';
import { useCallback } from 'react';
+import { initialAgentValues } from '../../constant';
import useGraphStore from '../../store';
export function useShowStructuredOutputDialog(nodeId?: string) {
@@ -9,15 +11,13 @@ export function useShowStructuredOutputDialog(nodeId?: string) {
showModal: showStructuredOutputDialog,
hideModal: hideStructuredOutputDialog,
} = useSetModalState();
- const { updateNodeForm, getNode } = useGraphStore((state) => state);
-
- const initialStructuredOutput = getNode(nodeId)?.data.form.outputs.structured;
+ const { updateNodeForm } = useGraphStore((state) => state);
const handleStructuredOutputDialogOk = useCallback(
(values: JSONSchema) => {
// Sync data to canvas
if (nodeId) {
- updateNodeForm(nodeId, values, ['outputs', 'structured']);
+ updateNodeForm(nodeId, values, ['outputs', AgentStructuredOutputField]);
}
hideStructuredOutputDialog();
},
@@ -25,10 +25,30 @@ export function useShowStructuredOutputDialog(nodeId?: string) {
);
return {
- initialStructuredOutput,
structuredOutputDialogVisible,
showStructuredOutputDialog,
hideStructuredOutputDialog,
handleStructuredOutputDialogOk,
};
}
+
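+// Toggling structured output on seeds an empty schema under outputs; toggling it off
+// resets outputs to the node's initial values.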
+export function useHandleShowStructuredOutput(nodeId?: string) {
+ const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
+
+ const handleShowStructuredOutput = useCallback(
+ (val: boolean) => {
+ if (nodeId) {
+ if (val) {
+ updateNodeForm(nodeId, {}, ['outputs', AgentStructuredOutputField]);
+ } else {
+ updateNodeForm(nodeId, initialAgentValues.outputs, ['outputs']);
+ }
+ }
+ },
+ [nodeId, updateNodeForm],
+ );
+
+ return {
+ handleShowStructuredOutput,
+ };
+}
diff --git a/web/src/pages/agent/form/agent-form/use-values.ts b/web/src/pages/agent/form/agent-form/use-values.ts
index f8747e4b4..fb7f94861 100644
--- a/web/src/pages/agent/form/agent-form/use-values.ts
+++ b/web/src/pages/agent/form/agent-form/use-values.ts
@@ -6,8 +6,10 @@ import { initialAgentValues } from '../../constant';
// You need to exclude the mcp and tools fields that are not in the form,
// otherwise the form data update will reset the tools or mcp data to an array
+// Exclude data that is not in the form to avoid writing this data to the canvas when using useWatch.
+// Outputs, tools, and MCP data are directly synchronized to the canvas without going through the form.
function omitToolsAndMcp(values: Record) {
- return omit(values, ['mcp', 'tools']);
+ return omit(values, ['mcp', 'tools', 'outputs']);
}
export function useValues(node?: RAGFlowNodeType) {
diff --git a/web/src/pages/agent/form/agent-form/use-watch-change.ts b/web/src/pages/agent/form/agent-form/use-watch-change.ts
index 7c53a8d40..98b0ecf31 100644
--- a/web/src/pages/agent/form/agent-form/use-watch-change.ts
+++ b/web/src/pages/agent/form/agent-form/use-watch-change.ts
@@ -1,7 +1,6 @@
-import { omit } from 'lodash';
import { useEffect } from 'react';
import { UseFormReturn, useWatch } from 'react-hook-form';
-import { AgentStructuredOutputField, PromptRole } from '../../constant';
+import { PromptRole } from '../../constant';
import useGraphStore from '../../store';
export function useWatchFormChange(id?: string, form?: UseFormReturn) {
@@ -17,14 +16,6 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn) {
prompts: [{ role: PromptRole.User, content: values.prompts }],
};
- if (!values.showStructuredOutput) {
- nextValues = {
- ...nextValues,
- outputs: omit(values.outputs, [AgentStructuredOutputField]),
- };
- } else {
- nextValues = omit(nextValues, 'outputs');
- }
updateNodeForm(id, nextValues);
}
}, [form?.formState.isDirty, id, updateNodeForm, values]);