Merge e86fff19a8 into a2e080c2d3
This commit is contained in:
commit
9508ffbec7
6 changed files with 232 additions and 65 deletions
|
|
@ -120,55 +120,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
|||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
# Collect all objects first to count filename occurrences
|
||||
all_objects = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
if start < last_modified <= end:
|
||||
all_objects.append(obj)
|
||||
|
||||
# Count filename occurrences to determine which need full paths
|
||||
filename_counts: dict[str, int] = {}
|
||||
for obj in all_objects:
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
|
||||
|
||||
if not (start < last_modified <= end):
|
||||
batch: list[Document] = []
|
||||
for obj in all_objects:
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(file_name, 0) > 1:
|
||||
relative_path = key
|
||||
if self.prefix and key.startswith(self.prefix):
|
||||
relative_path = key[len(self.prefix):]
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
|
||||
else:
|
||||
semantic_id = file_name
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
continue
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ _PAGE_EXPANSION_FIELDS = [
|
|||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
"ancestors",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1311,6 +1311,9 @@ class ConfluenceConnector(
|
|||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
# Track document names to detect duplicates
|
||||
self._document_name_counts: dict[str, int] = {}
|
||||
self._document_name_paths: dict[str, list[str]] = {}
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
|
@ -1513,6 +1516,40 @@ class ConfluenceConnector(
|
|||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build hierarchical path for semantic identifier
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
|
||||
# Build path from ancestors
|
||||
path_parts = []
|
||||
if space_name:
|
||||
path_parts.append(space_name)
|
||||
|
||||
# Add ancestor pages to path if available
|
||||
if "ancestors" in page and page["ancestors"]:
|
||||
for ancestor in page["ancestors"]:
|
||||
ancestor_title = ancestor.get("title", "")
|
||||
if ancestor_title:
|
||||
path_parts.append(ancestor_title)
|
||||
|
||||
# Add current page title
|
||||
path_parts.append(page_title)
|
||||
|
||||
# Track page names for duplicate detection
|
||||
full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
|
||||
|
||||
# Count occurrences of this page title
|
||||
if page_title not in self._document_name_counts:
|
||||
self._document_name_counts[page_title] = 0
|
||||
self._document_name_paths[page_title] = []
|
||||
self._document_name_counts[page_title] += 1
|
||||
self._document_name_paths[page_title].append(full_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[page_title] == 1:
|
||||
semantic_identifier = page_title
|
||||
else:
|
||||
semantic_identifier = full_path
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
self.confluence_client, page, self._fetched_titles
|
||||
|
|
@ -1559,7 +1596,7 @@ class ConfluenceConnector(
|
|||
return Document(
|
||||
id=page_url,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=".html", # Confluence pages are HTML
|
||||
blob=page_content.encode("utf-8"), # Encode page content as bytes
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
|
|
@ -1601,7 +1638,6 @@ class ConfluenceConnector(
|
|||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
|
|
@ -1669,6 +1705,34 @@ class ConfluenceConnector(
|
|||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build semantic identifier with space and page context
|
||||
attachment_title = attachment.get("title", object_url)
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
page_title = page.get("title", "")
|
||||
|
||||
# Create hierarchical name: Space / Page / Attachment
|
||||
attachment_path_parts = []
|
||||
if space_name:
|
||||
attachment_path_parts.append(space_name)
|
||||
if page_title:
|
||||
attachment_path_parts.append(page_title)
|
||||
attachment_path_parts.append(attachment_title)
|
||||
|
||||
full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
|
||||
|
||||
# Track attachment names for duplicate detection
|
||||
if attachment_title not in self._document_name_counts:
|
||||
self._document_name_counts[attachment_title] = 0
|
||||
self._document_name_paths[attachment_title] = []
|
||||
self._document_name_counts[attachment_title] += 1
|
||||
self._document_name_paths[attachment_title].append(full_attachment_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[attachment_title] == 1:
|
||||
attachment_semantic_identifier = attachment_title
|
||||
else:
|
||||
attachment_semantic_identifier = full_attachment_path
|
||||
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
|
|
@ -1680,11 +1744,12 @@ class ConfluenceConnector(
|
|||
|
||||
extension = Path(attachment.get("title", "")).suffix or ".unknown"
|
||||
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
# sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
semantic_identifier=attachment_semantic_identifier,
|
||||
extension=extension,
|
||||
blob=file_blob,
|
||||
size_bytes=len(file_blob),
|
||||
|
|
@ -1741,7 +1806,7 @@ class ConfluenceConnector(
|
|||
start_ts, end, self.batch_size
|
||||
)
|
||||
logging.debug(f"page_query_url: {page_query_url}")
|
||||
|
||||
|
||||
# store the next page start for confluence server, cursor for confluence cloud
|
||||
def store_next_page_url(next_page_url: str) -> None:
|
||||
checkpoint.next_page_url = next_page_url
|
||||
|
|
|
|||
|
|
@ -87,15 +87,69 @@ class DropboxConnector(LoadConnector, PollConnector):
|
|||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
# Collect all files first to count filename occurrences
|
||||
all_files = []
|
||||
self._collect_files_recursive(path, start, end, all_files)
|
||||
|
||||
# Count filename occurrences
|
||||
filename_counts: dict[str, int] = {}
|
||||
for entry, _ in all_files:
|
||||
filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
|
||||
|
||||
# Process files in batches
|
||||
batch: list[Document] = []
|
||||
for entry, downloaded_file in all_files:
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(entry.name, 0) > 1:
|
||||
# Remove leading slash and replace slashes with ' / '
|
||||
relative_path = entry.path_display.lstrip('/')
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
|
||||
else:
|
||||
semantic_id = entry.name
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def _collect_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
all_files: list,
|
||||
) -> None:
|
||||
"""Recursively collect all files matching time criteria."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
|
|
@ -112,27 +166,13 @@ class DropboxConnector(LoadConnector, PollConnector):
|
|||
|
||||
try:
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
all_files.append((entry, downloaded_file))
|
||||
except Exception:
|
||||
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
self._collect_files_recursive(entry.path_lower, start, end, all_files)
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -447,6 +447,17 @@ class NotionConnector(LoadConnector, PollConnector):
|
|||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
# Count attachment semantic_identifier occurrences within this page
|
||||
attachment_name_counts: dict[str, int] = {}
|
||||
for att_doc in attachment_docs:
|
||||
name = att_doc.semantic_identifier
|
||||
attachment_name_counts[name] = attachment_name_counts.get(name, 0) + 1
|
||||
|
||||
# Update semantic identifiers for duplicate attachments
|
||||
for att_doc in attachment_docs:
|
||||
if attachment_name_counts.get(att_doc.semantic_identifier, 0) > 1:
|
||||
att_doc.semantic_identifier = f"{page_title} / {att_doc.semantic_identifier}"
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ from common.data_source.confluence_connector import ConfluenceConnector
|
|||
from common.data_source.gmail_connector import GmailConnector
|
||||
from common.data_source.box_connector import BoxConnector
|
||||
from common.data_source.interfaces import CheckpointOutputWrapper
|
||||
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
||||
from common.log_utils import init_root_logger
|
||||
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||
from common.versions import get_ragflow_version
|
||||
|
|
@ -226,14 +225,48 @@ class Confluence(SyncBase):
|
|||
|
||||
end_time = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
document_generator = load_all_docs_from_checkpoint_connector(
|
||||
connector=self.connector,
|
||||
start=start_time,
|
||||
end=end_time,
|
||||
)
|
||||
raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
|
||||
try:
|
||||
batch_size = int(raw_batch_size)
|
||||
except (TypeError, ValueError):
|
||||
batch_size = INDEX_BATCH_SIZE
|
||||
if batch_size <= 0:
|
||||
batch_size = INDEX_BATCH_SIZE
|
||||
|
||||
def document_batches():
|
||||
checkpoint = self.connector.build_dummy_checkpoint()
|
||||
pending_docs = []
|
||||
iterations = 0
|
||||
iteration_limit = 100_000
|
||||
|
||||
while checkpoint.has_more:
|
||||
wrapper = CheckpointOutputWrapper()
|
||||
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
|
||||
for document, failure, next_checkpoint in doc_generator:
|
||||
if failure is not None:
|
||||
logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
|
||||
continue
|
||||
if document is not None:
|
||||
pending_docs.append(document)
|
||||
if len(pending_docs) >= batch_size:
|
||||
yield pending_docs
|
||||
pending_docs = []
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
iterations += 1
|
||||
if iterations > iteration_limit:
|
||||
raise RuntimeError("Too many iterations while loading Confluence documents.")
|
||||
|
||||
if pending_docs:
|
||||
yield pending_docs
|
||||
|
||||
async def async_wrapper():
|
||||
for batch in document_batches():
|
||||
yield batch
|
||||
|
||||
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
|
||||
return [document_generator]
|
||||
return async_wrapper()
|
||||
|
||||
|
||||
class Notion(SyncBase):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue