diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py
index 660530c82..838246731 100644
--- a/api/db/services/connector_service.py
+++ b/api/db/services/connector_service.py
@@ -194,7 +194,21 @@ class SyncLogsService(CommonService):
             update_date=timestamp_to_date(current_timestamp())
         )\
             .where(cls.model.id == id).execute()
-
+
+    @classmethod
+    def increase_deleted_docs(cls, id, doc_num, err_msg="", error_count=0):
+        # Record that `doc_num` documents were removed from the index for the
+        # sync-log row `id` (signature matches the caller in
+        # rag/svr/sync_data_source.py: increase_deleted_docs(task["id"], n)).
+        # NOTE: peewee builds queries lazily — .where().execute() is required
+        # or the UPDATE is silently never run.
+        cls.model.update(
+            docs_removed_from_index=cls.model.docs_removed_from_index + doc_num,
+            total_docs_indexed=cls.model.total_docs_indexed - doc_num,
+            error_msg=cls.model.error_msg + err_msg,
+            error_count=cls.model.error_count + error_count,
+            update_time=current_timestamp(),
+            update_date=timestamp_to_date(current_timestamp())
+        ).where(cls.model.id == id).execute()
+
     @classmethod
     def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
         from api.db.services.file_service import FileService
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 7b7ef53ec..beedc48d8 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -750,6 +750,18 @@ class DocumentService(CommonService):
         for row in rows:
             result[row.kb_id] = row.count
         return result
+
+    @classmethod
+    @DB.connection_context()
+    def list_documents_by_source(cls, tenant_id, kb, src):
+        # Return the ids of all documents in knowledge base `kb` whose
+        # source_type is `src`. Named to match the call site in
+        # rag/svr/sync_data_source.py (DocumentService.list_documents_by_source).
+        # `tenant_id` is kept for interface parity with the caller.
+        docs = cls.model.select(cls.model.id).where(
+            (cls.model.kb_id == kb),
+            (cls.model.source_type == src)
+        )
+        return [doc.id for doc in docs]
 
     @classmethod
     @DB.connection_context()
@@ -1028,3 +1040,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
 
     return [d["id"] for d, _ in files]
+
diff --git
a/common/settings.py b/common/settings.py
index 81a2c19a4..e5da8cf85 100644
--- a/common/settings.py
+++ b/common/settings.py
@@ -85,6 +85,9 @@ kg_retriever = None
 # user registration switch
 REGISTER_ENABLED = 1
 
+# os.getenv returns a *string* when the variable is set, so comparing its
+# truthiness would make "false"/"0" enable the feature. Parse explicitly.
+ENABLE_SYNC_DELETED_CHANGE = os.getenv('ENABLE_SYNC_DELETED_CHANGE', 'false').lower() in ('1', 'true', 'yes')
+
 # sandbox-executor-manager
 SANDBOX_HOST = None
 
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index a5c4c4cfb..f9a67be4f 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -35,6 +35,9 @@ import trio
 
 from api.db.services.connector_service import ConnectorService, SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.document_service import DocumentService
+from api.db.services.file_service import FileService
+import copy  # used by the deleted-doc detection below — TODO confirm not already imported upstream
 from common import settings
 from common.config_utils import show_configs
 from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector, MoodleConnector, JiraConnector, DropboxConnector, WebDAVConnector
@@ -62,6 +64,13 @@ class SyncBase:
         SyncLogsService.start(task["id"], task["connector_id"])
         try:
             async with task_limiter:
+                synced_doc_ids = set()  # document ids observed during this sync task
+                source_type = f"{self.SOURCE_NAME}/{task['connector_id']}"
+                existing_doc_ids = []
+                with trio.fail_after(task["timeout_secs"]):
+                    # ids of docs indexed by previous syncs of this connector
+                    existing_doc_ids = DocumentService.list_documents_by_source(task["tenant_id"], task["kb_id"], source_type)
+
                 with trio.fail_after(task["timeout_secs"]):
                     document_batch_generator = await self._generate(task)
                     doc_num = 0
@@ -112,6 +121,40 @@ class SyncBase:
                         failed_docs += len(docs)
                         continue
 
+                if settings.ENABLE_SYNC_DELETED_CHANGE:
+                    # Re-enumerate the source from the very beginning (drop
+                    # poll_range_start) so upstream deletions can be detected.
+                    # The copy MUST be passed to _generate — using the original
+                    # `task` would only re-scan the incremental window.
+                    task_copy = copy.deepcopy(task)
+                    task_copy.pop("poll_range_start", None)
+                    document_batch_generator = await self._generate(task_copy)
+                    for document_batch in document_batch_generator:
+                        if not document_batch:
+                            continue
+                        docs = [
+                            {
+                                "id": doc.id,
+ "connector_id": task["connector_id"], + "source": self.SOURCE_NAME, + "semantic_identifier": doc.semantic_identifier, + "extension": doc.extension, + "size_bytes": doc.size_bytes, + "doc_updated_at": doc.doc_updated_at, + "blob": doc.blob, + } + for doc in document_batch + ] + + for doc in docs: + synced_doc_ids.add(doc["id"]) + + # delete removed docs + if not existing_doc_ids: + to_delete_ids = [] + for doc_id in existing_doc_ids: + if doc_id not in synced_doc_ids: + to_delete_ids.append(doc_id) + + if to_delete_ids: + FileService.delete_docs(to_delete_ids, task["tenant_id"]) + SyncLogsService.increase_deleted_docs(task["id"], len(to_delete_ids)) + logging.info(f"Deleted {len(to_delete_ids)} documents from knowledge base {task['kb_id']} for connector {task['connector_id']}") + prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else "" if failed_docs > 0: logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")