From 94dccd827054f7ca639a7c4f0d2a5b4750b27a88 Mon Sep 17 00:00:00 2001 From: buua436 Date: Tue, 9 Dec 2025 13:03:37 +0800 Subject: [PATCH] Refa:replace trio with asyncio in agent/deepdoc/rag --- agent/component/base.py | 3 +- deepdoc/parser/pdf_parser.py | 45 ++-- deepdoc/vision/t_ocr.py | 34 +-- rag/flow/base.py | 10 +- .../hierarchical_merger.py | 17 +- rag/flow/parser/parser.py | 18 +- rag/flow/pipeline.py | 7 +- rag/flow/splitter/splitter.py | 15 +- rag/flow/tests/client.py | 5 +- rag/flow/tokenizer/tokenizer.py | 4 +- rag/prompts/generator.py | 20 +- rag/svr/sync_data_source.py | 197 +++++++++++------- rag/svr/task_executor.py | 151 ++++++++++---- rag/utils/base64_image.py | 59 ++++-- rag/utils/redis_conn.py | 4 +- 15 files changed, 382 insertions(+), 207 deletions(-) diff --git a/agent/component/base.py b/agent/component/base.py index 6ac95e09a..81d3fac56 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -24,7 +24,6 @@ import os import logging from typing import Any, List, Union import pandas as pd -import trio from agent import settings from common.connection_utils import timeout @@ -393,7 +392,7 @@ class ComponentParamBase(ABC): class ComponentBase(ABC): component_name: str - thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) + thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" def __str__(self): diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index a6c370a7e..6b8a75a8d 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import logging import math import os @@ -28,7 +29,6 @@ from timeit import default_timer as timer import numpy as np import pdfplumber -import trio import xgboost as xgb from huggingface_hub import snapshot_download from PIL import Image @@ -66,7 +66,7 @@ class RAGFlowPdfParser: self.ocr = OCR() self.parallel_limiter = None if settings.PARALLEL_DEVICES > 1: - self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)] + self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)] layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower() if layout_recognizer_type not in ["onnx", "ascend"]: @@ -383,7 +383,7 @@ class RAGFlowPdfParser: else: x0s.append([x]) x0s = np.array(x0s, dtype=float) - + max_try = min(4, len(bxs)) if max_try < 2: max_try = 1 @@ -417,7 +417,7 @@ class RAGFlowPdfParser: for pg, bxs in by_page.items(): if not bxs: continue - k = page_cols[pg] + k = page_cols[pg] if len(bxs) < k: k = 1 x0s = np.array([[b["x0"]] for b in bxs], dtype=float) @@ -431,7 +431,7 @@ class RAGFlowPdfParser: for b, lb in zip(bxs, labels): b["col_id"] = remap[lb] - + grouped = defaultdict(list) for b in bxs: grouped[b["col_id"]].append(b) @@ -1112,7 +1112,7 @@ class RAGFlowPdfParser: if limiter: async with limiter: - await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id)) + await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id) else: self.__ocr(i + 1, img, chars, zoomin, id) @@ -1128,12 +1128,33 @@ class RAGFlowPdfParser: return chars if self.parallel_limiter: - async with trio.open_nursery() as nursery: - for i, img in enumerate(self.page_images): - chars = __ocr_preprocess() + tasks = [] + + for i, img in enumerate(self.page_images): + chars = __ocr_preprocess() + + semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES] + + async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore): + await __img_ocr( + i, + i % settings.PARALLEL_DEVICES, + img, + chars, + semaphore, + ) + + tasks.append(asyncio.create_task(wrapper())) + await asyncio.sleep(0) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES]) - await trio.sleep(0.1) else: for i, img in enumerate(self.page_images): chars = __ocr_preprocess() @@ -1141,7 +1162,7 @@ class RAGFlowPdfParser: start = timer() - trio.run(__img_ocr_launcher) + asyncio.run(__img_ocr_launcher()) logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py index ccc96538b..347a3e9a8 100644 --- a/deepdoc/vision/t_ocr.py +++ b/deepdoc/vision/t_ocr.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import os import sys sys.path.insert( @@ -28,7 +29,6 @@ from deepdoc.vision.seeit import draw_box from deepdoc.vision import OCR, init_in_out import argparse import numpy as np -import trio # os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu @@ -39,7 +39,7 @@ def main(args): import torch.cuda cuda_devices = torch.cuda.device_count() - limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None + limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None ocr = OCR() images, outputs = init_in_out(args) @@ -62,22 +62,28 @@ def main(args): async def __ocr_thread(i, id, img, limiter = None): if limiter: async with limiter: - print("Task {} use device {}".format(i, id)) - await trio.to_thread.run_sync(lambda: __ocr(i, id, img)) + print(f"Task {i} use device {id}") + await asyncio.to_thread(__ocr, i, id, img) else: - __ocr(i, id, img) + await asyncio.to_thread(__ocr, i, id, img) + async def __ocr_launcher(): - if cuda_devices > 1: - async with trio.open_nursery() as nursery: - for i, img in enumerate(images): - nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices]) - await trio.sleep(0.1) - else: - for i, img in enumerate(images): - await __ocr_thread(i, 0, img) + tasks = [] + for i, img in enumerate(images): + dev_id = i % cuda_devices if cuda_devices > 1 else 0 + semaphore = limiter[dev_id] if limiter else None + tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore))) - trio.run(__ocr_launcher) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + asyncio.run(__ocr_launcher()) print("OCR tasks are all done") diff --git a/rag/flow/base.py b/rag/flow/base.py index 4b256e78f..03005dc03 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import time from functools import partial from typing import Any -import trio from agent.component.base import ComponentBase, ComponentParamBase from common.connection_utils import timeout @@ -43,9 +43,11 @@ class ProcessBase(ComponentBase): for k, v in kwargs.items(): self.set_output(k, v) try: - with trio.fail_after(self._param.timeout): - await self._invoke(**kwargs) - self.callback(1, "Done") + await asyncio.wait_for( + self._invoke(**kwargs), + timeout=self._param.timeout + ) + self.callback(1, "Done") except Exception as e: if self.get_exception_default_value(): self.set_exception_default_value() diff --git a/rag/flow/hierarchical_merger/hierarchical_merger.py b/rag/flow/hierarchical_merger/hierarchical_merger.py index ca0400a34..5b8e3483a 100644 --- a/rag/flow/hierarchical_merger/hierarchical_merger.py +++ b/rag/flow/hierarchical_merger/hierarchical_merger.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import random import re from copy import deepcopy from functools import partial -import trio - from common.misc_utils import get_uuid from rag.utils.base64_image import id2image, image2id from deepdoc.parser.pdf_parser import RAGFlowPdfParser @@ -178,9 +177,17 @@ class HierarchicalMerger(ProcessBase): } for c, img in zip(cks, images) ] - async with trio.open_nursery() as nursery: - for d in cks: - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in cks: + tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + self.set_output("chunks", cks) self.callback(1, "Done.") diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 7747448ad..1d1c199aa 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import io import json import os @@ -20,7 +21,6 @@ import re from functools import partial import numpy as np -import trio from PIL import Image from api.db.services.file2document_service import File2DocumentService @@ -804,7 +804,7 @@ class Parser(ProcessBase): for p_type, conf in self._param.setups.items(): if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): continue - await trio.to_thread.run_sync(function_map[p_type], name, blob) + await asyncio.to_thread(function_map[p_type], name, blob) done = True break @@ -812,6 +812,14 @@ class Parser(ProcessBase): raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower()) outs = self.output() - async with trio.open_nursery() as nursery: - for d in outs.get("json", []): - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in outs.get("json", []): + tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid()))) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise \ No newline at end of file diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index b44c77bd4..cc4bed0fa 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import datetime import json import logging import random from timeit import default_timer as timer -import trio from agent.canvas import Graph from api.db.services.document_service import DocumentService from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID @@ -152,8 +152,9 @@ class Pipeline(Graph): #else: # cpn_obj.invoke(**last_cpn.output()) - async with trio.open_nursery() as nursery: - nursery.start_soon(invoke) + tasks = [] + tasks.append(asyncio.create_task(invoke())) + await asyncio.gather(*tasks) if cpn_obj.error(): self.error = "[ERROR]" + cpn_obj.error() diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index 1ef06839d..851d880d4 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import random import re from copy import deepcopy from functools import partial -import trio from common.misc_utils import get_uuid from rag.utils.base64_image import id2image, image2id from deepdoc.parser.pdf_parser import RAGFlowPdfParser @@ -129,9 +129,16 @@ class Splitter(ProcessBase): } for c, img in zip(chunks, images) if c.strip() ] - async with trio.open_nursery() as nursery: - for d in cks: - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in cks: + tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if custom_pattern: docs = [] diff --git a/rag/flow/tests/client.py b/rag/flow/tests/client.py index 0b7612816..16d6fd0bf 100644 --- a/rag/flow/tests/client.py +++ b/rag/flow/tests/client.py @@ -14,13 +14,12 @@ # limitations under the License. # import argparse +import asyncio import json import os import time from concurrent.futures import ThreadPoolExecutor -import trio - from common import settings from rag.flow.pipeline import Pipeline @@ -57,5 +56,5 @@ if __name__ == "__main__": # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0) - trio.run(pipeline.run) + asyncio.run(pipeline.run()) thr.result() diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index 965cb4c1e..a13d95c0a 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import random import re import numpy as np -import trio from common.constants import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService @@ -84,7 +84,7 @@ class Tokenizer(ProcessBase): cnts_ = np.array([]) for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],) if len(cnts_) == 0: cnts_ = vts else: diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index e8ad77032..9abaa8a94 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -22,7 +22,6 @@ from copy import deepcopy from typing import Tuple import jinja2 import json_repair -import trio from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer from rag.prompts.template import load_prompt @@ -744,12 +743,19 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): titles = [] chunks_res = [] - async with trio.open_nursery() as nursery: - for i, chunk in enumerate(chunk_sections): - if not chunk: - continue - chunks_res.append({"chunks": chunk}) - nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback) + tasks = [] + for i, chunk in enumerate(chunk_sections): + if not chunk: + continue + chunks_res.append({"chunks": chunk}) + tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise for chunk in chunks_res: titles.extend(chunk.get("toc", [])) diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 4349b6f55..bd9c42f85 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -19,6 +19,7 @@ # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code +import asyncio import copy import faulthandler import logging @@ -31,8 +32,6 @@ import traceback from datetime import datetime, timezone from typing import Any -import trio - from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.knowledgebase_service import KnowledgebaseService from common import settings @@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) -task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) +task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) class SyncBase: @@ -60,75 +59,102 @@ class SyncBase: async def __call__(self, task: dict): SyncLogsService.start(task["id"], task["connector_id"]) - try: - async with task_limiter: - with trio.fail_after(task["timeout_secs"]): - document_batch_generator = await self._generate(task) - doc_num = 0 - next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) - if task["poll_range_start"]: - next_update = task["poll_range_start"] - - failed_docs = 0 - for document_batch in document_batch_generator: - if not document_batch: - continue - min_update = min([doc.doc_updated_at for doc in document_batch]) - max_update = max([doc.doc_updated_at for doc in document_batch]) - next_update = max([next_update, max_update]) - docs = [] - for doc in document_batch: - doc_dict = { - "id": doc.id, - "connector_id": task["connector_id"], - "source": self.SOURCE_NAME, - "semantic_identifier": doc.semantic_identifier, - "extension": doc.extension, - "size_bytes": doc.size_bytes, - "doc_updated_at": doc.doc_updated_at, - "blob": doc.blob, - } - # Add metadata if present - if doc.metadata: - doc_dict["metadata"] = doc.metadata - docs.append(doc_dict) - try: - e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) - err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"]) - SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err)) - doc_num += len(docs) - except Exception as batch_ex: - error_msg = str(batch_ex) - error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None - - if error_code == 1267 or "collation" in error_msg.lower(): - logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)") - for doc in docs: - logging.debug(f"Skipped: {doc['semantic_identifier']}") - else: - logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}") - - failed_docs += len(docs) - continue + async with task_limiter: + try: + await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"]) - prefix = self._get_source_prefix() - if failed_docs > 0: - logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") - else: - logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}") - SyncLogsService.done(task["id"], task["connector_id"]) - task["poll_range_start"] = next_update + except asyncio.TimeoutError: + msg = f"Task timeout after {task['timeout_secs']} seconds" + SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg}) + return - except Exception as ex: - msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()]) - SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)}) + except Exception as ex: + msg = "\n".join([ + "".join(traceback.format_exception_only(None, ex)).strip(), + "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(), + ]) + SyncLogsService.update_by_id(task["id"], { + "status": TaskStatus.FAIL, + "full_exception_trace": msg, + "error_msg": str(ex) + }) + return SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) + async def _run_task_logic(self, task: dict): + document_batch_generator = await self._generate(task) + + doc_num = 0 + failed_docs = 0 + next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) + + if task["poll_range_start"]: + next_update = task["poll_range_start"] + + async for document_batch in document_batch_generator: # 如果是 async generator + if not document_batch: + continue + + min_update = min(doc.doc_updated_at for doc in document_batch) + max_update = max(doc.doc_updated_at for doc in document_batch) + next_update = max(next_update, max_update) + + docs = [] + for doc in document_batch: + d = { + "id": doc.id, + "connector_id": task["connector_id"], + "source": self.SOURCE_NAME, + "semantic_identifier": doc.semantic_identifier, + "extension": doc.extension, + "size_bytes": doc.size_bytes, + "doc_updated_at": doc.doc_updated_at, + "blob": doc.blob, + } + if doc.metadata: + d["metadata"] = doc.metadata + docs.append(d) + + try: + e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) + err, dids = SyncLogsService.duplicate_and_parse( + kb, docs, task["tenant_id"], + f"{self.SOURCE_NAME}/{task['connector_id']}", + task["auto_parse"] + ) + SyncLogsService.increase_docs( + task["id"], min_update, max_update, + len(docs), "\n".join(err), len(err) + ) + + doc_num += len(docs) + + except Exception as batch_ex: + msg = str(batch_ex) + code = getattr(batch_ex, "args", [None])[0] + + if code == 1267 or "collation" in msg.lower(): + logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict") + else: + logging.error(f"Error processing batch: {msg}") + + failed_docs += len(docs) + continue + + prefix = self._get_source_prefix() + if failed_docs > 0: + logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") + else: + logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}") + + SyncLogsService.done(task["id"], task["connector_id"]) + task["poll_range_start"] = next_update + async def _generate(self, task: dict): raise NotImplementedError - + def _get_source_prefix(self): return "" @@ -617,23 +643,32 @@ func_factory = { async def dispatch_tasks(): - async with trio.open_nursery() as nursery: - while True: - try: - list(SyncLogsService.list_sync_tasks()[0]) - break - except Exception as e: - logging.warning(f"DB is not ready yet: {e}") - await trio.sleep(3) + while True: + try: + list(SyncLogsService.list_sync_tasks()[0]) + break + except Exception as e: + logging.warning(f"DB is not ready yet: {e}") + await asyncio.sleep(3) - for task in SyncLogsService.list_sync_tasks()[0]: - if task["poll_range_start"]: - task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) - if task["poll_range_end"]: - task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) - func = func_factory[task["source"]](task["config"]) - nursery.start_soon(func, task) - await trio.sleep(1) + tasks = [] + for task in SyncLogsService.list_sync_tasks()[0]: + if task["poll_range_start"]: + task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) + if task["poll_range_end"]: + task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) + + func = func_factory[task["source"]](task["config"]) + tasks.append(asyncio.create_task(func(task))) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + await asyncio.sleep(1) stop_event = threading.Event() @@ -678,4 +713,4 @@ async def main(): if __name__ == "__main__": faulthandler.enable() init_root_logger(CONSUMER_NAME) - trio.run(main) + asyncio.run(main) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 62693f24f..c1776c5c3 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import socket import concurrent # from beartype import BeartypeConf @@ -46,7 +47,6 @@ from functools import partial from multiprocessing.context import TimeoutError from timeit import default_timer as timer import signal -import trio import exceptiongroup import faulthandler import numpy as np @@ -114,11 +114,11 @@ CURRENT_TASKS = {} MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) -task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) -chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) -embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) -minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) -kg_limiter = trio.CapacityLimiter(2) +task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) +chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS) +embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS) +minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO) +kg_limiter = asyncio.Semaphore(2) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) stop_event = threading.Event() @@ -219,7 +219,7 @@ async def collect(): async def get_storage_binary(bucket, name): - return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name)) + return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name) @timeout(60*80, 1) @@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback): try: async with chunk_limiter: - cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], - to_page=task["to_page"], lang=task["language"], callback=progress_callback, - kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])) + cks = await asyncio.to_thread( + chunker.chunk, + task["name"], + binary=binary, + from_page=task["from_page"], + to_page=task["to_page"], + lang=task["language"], + callback=progress_callback, + kb_id=task["kb_id"], + parser_config=task["parser_config"], + tenant_id=task["tenant_id"], + ) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) except TaskCanceledException: raise @@ -290,9 +299,16 @@ async def build_chunks(task, progress_callback): "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) raise - async with trio.open_nursery() as nursery: - for ck in cks: - nursery.start_soon(upload_to_minio, doc, ck) + tasks = [] + for ck in cks: + tasks.append(asyncio.create_task(upload_to_minio(doc, ck))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise el = timer() - st logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el)) @@ -306,15 +322,27 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) if not cached: async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn)) + cached = await asyncio.to_thread( + keyword_extraction, + chat_mdl, + d["content_with_weight"], + topn, + ) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) if cached: d["important_kwd"] = cached.split(",") d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) return - async with trio.open_nursery() as nursery: - for d in docs: - nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"]) + tasks = [] + for d in docs: + tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) if task["parser_config"].get("auto_questions", 0): @@ -326,14 +354,26 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) if not cached: async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn)) + cached = await asyncio.to_thread( + question_proposal, + chat_mdl, + d["content_with_weight"], + topn, + ) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) if cached: d["question_kwd"] = cached.split("\n") d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) - async with trio.open_nursery() as nursery: - for d in docs: - nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"]) + tasks = [] + for d in docs: + tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) if task["kb_parser_config"].get("tag_kb_ids", []): @@ -371,15 +411,29 @@ async def build_chunks(task, progress_callback): if not picked_examples: picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)) + cached = await asyncio.to_thread( + content_tagging, + chat_mdl, + d["content_with_weight"], + all_tags, + picked_examples, + topn_tags, + ) if cached: cached = json.dumps(cached) if cached: set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) d[TAG_FLD] = json.loads(cached) - async with trio.open_nursery() as nursery: - for d in docs_to_tag: - nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags) + tasks = [] + for d in docs_to_tag: + tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) return docs @@ -392,7 +446,7 @@ def build_TOC(task, docs, progress_callback): d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) )) - toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback) + toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)) logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' ')) ii = 0 while ii < len(toc): @@ -440,7 +494,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tk_count = 0 if len(tts) == len(cnts): - vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) + vts, c = await asyncio.to_thread(mdl.encode, tts[0:1]) tts = np.tile(vts[0], (len(cnts), 1)) tk_count += c @@ -452,7 +506,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): cnts_ = np.array([]) for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE]) if len(cnts_) == 0: cnts_ = vts else: @@ -535,7 +589,7 @@ async def run_dataflow(task: dict): prog = 0.8 for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE]) if len(vects) == 0: vects = vts else: @@ -742,14 +796,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c mothers.append(mom_ck) for b in range(0, len(mothers), settings.DOC_BULK_SIZE): - await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") return False for b in range(0, len(chunks), settings.DOC_BULK_SIZE): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -766,10 +820,17 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c TaskService.update_chunk_ids(task_id, chunk_ids_str) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) - async with trio.open_nursery() as nursery: - for chunk_id in chunk_ids: - nursery.start_soon(delete_image, task_dataset_id, chunk_id) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,) + tasks = [] + for chunk_id in chunk_ids: + tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") return False return True @@ -994,7 +1055,7 @@ async def handle_task(): global DONE_TASKS, FAILED_TASKS redis_msg, task = await collect() if not task: - await trio.sleep(5) + await asyncio.sleep(5) return task_type = task["task_type"] @@ -1091,7 +1152,7 @@ async def report_status(): logging.exception("report_status got exception") finally: redis_lock.release() - await trio.sleep(30) + await asyncio.sleep(30) async def task_manager(): @@ -1127,14 +1188,22 @@ async def main(): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - async with trio.open_nursery() as nursery: - nursery.start_soon(report_status) + report_task = asyncio.create_task(report_status()) + tasks = [] + try: while not stop_event.is_set(): await task_limiter.acquire() - nursery.start_soon(task_manager) + t = asyncio.create_task(task_manager()) + tasks.append(t) + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + report_task.cancel() + await asyncio.gather(report_task, return_exceptions=True) logging.error("BUG!!! You should not reach here!!!") if __name__ == "__main__": faulthandler.enable() init_root_logger(CONSUMER_NAME) - trio.run(main) + asyncio.run(main()) diff --git a/rag/utils/base64_image.py b/rag/utils/base64_image.py index 15794944c..66c90dfa5 100644 --- a/rag/utils/base64_image.py +++ b/rag/utils/base64_image.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import base64 import logging from functools import partial @@ -24,39 +25,53 @@ from PIL import Image test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image = base64.b64decode(test_image_base64) - async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"): import logging from io import BytesIO - import trio from rag.svr.task_executor import minio_limiter + if "image" not in d: return if not d["image"]: del d["image"] return - with BytesIO() as output_buffer: - if isinstance(d["image"], bytes): - output_buffer.write(d["image"]) - output_buffer.seek(0) - else: - # If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format. - if d["image"].mode in ("RGBA", "P"): - converted_image = d["image"].convert("RGB") - d["image"] = converted_image - try: - d["image"].save(output_buffer, format='JPEG') - except OSError as e: - logging.warning( - "Saving image exception, ignore: {}".format(str(e))) + def encode_image(): + with BytesIO() as buf: + img = d["image"] - async with minio_limiter: - await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue())) - d["img_id"] = f"{bucket}-{objname}" - if not isinstance(d["image"], bytes): - d["image"].close() - del d["image"] # Remove image reference + if isinstance(img, bytes): + buf.write(img) + buf.seek(0) + return buf.getvalue() + + if img.mode in ("RGBA", "P"): + img = img.convert("RGB") + + try: + img.save(buf, format="JPEG") + except OSError as e: + logging.warning(f"Saving image exception: {e}") + return None + + buf.seek(0) + return buf.getvalue() + + jpeg_binary = await asyncio.to_thread(encode_image) + if jpeg_binary is None: + del d["image"] + return + + async with minio_limiter: + await asyncio.to_thread( + lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary) + ) + + d["img_id"] = f"{bucket}-{objname}" + + if not isinstance(d["image"], bytes): + d["image"].close() + del d["image"] def id2image(image_id:str|None, storage_get_func: partial): diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index b7cc15c63..5a8aece1d 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import logging import json import uuid @@ -22,7 +23,6 @@ import valkey as redis from common.decorator import singleton from common import settings from valkey.lock import Lock -import trio REDIS = {} try: @@ -405,7 +405,7 @@ class RedisDistributedLock: while True: if self.lock.acquire(token=self.lock_value): break - await trio.sleep(10) + await asyncio.sleep(10) def release(self): REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)