Refa:replace trio with asyncio in agent/deepdoc/rag

This commit is contained in:
buua436 2025-12-09 13:03:37 +08:00
parent 1777620ea5
commit 94dccd8270
15 changed files with 382 additions and 207 deletions

View file

@ -24,7 +24,6 @@ import os
import logging import logging
from typing import Any, List, Union from typing import Any, List, Union
import pandas as pd import pandas as pd
import trio
from agent import settings from agent import settings
from common.connection_utils import timeout from common.connection_utils import timeout
@ -393,7 +392,7 @@ class ComponentParamBase(ABC):
class ComponentBase(ABC): class ComponentBase(ABC):
component_name: str component_name: str
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
def __str__(self): def __str__(self):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import math import math
import os import os
@ -28,7 +29,6 @@ from timeit import default_timer as timer
import numpy as np import numpy as np
import pdfplumber import pdfplumber
import trio
import xgboost as xgb import xgboost as xgb
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image from PIL import Image
@ -66,7 +66,7 @@ class RAGFlowPdfParser:
self.ocr = OCR() self.ocr = OCR()
self.parallel_limiter = None self.parallel_limiter = None
if settings.PARALLEL_DEVICES > 1: if settings.PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)] self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)]
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower() layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
if layout_recognizer_type not in ["onnx", "ascend"]: if layout_recognizer_type not in ["onnx", "ascend"]:
@ -383,7 +383,7 @@ class RAGFlowPdfParser:
else: else:
x0s.append([x]) x0s.append([x])
x0s = np.array(x0s, dtype=float) x0s = np.array(x0s, dtype=float)
max_try = min(4, len(bxs)) max_try = min(4, len(bxs))
if max_try < 2: if max_try < 2:
max_try = 1 max_try = 1
@ -417,7 +417,7 @@ class RAGFlowPdfParser:
for pg, bxs in by_page.items(): for pg, bxs in by_page.items():
if not bxs: if not bxs:
continue continue
k = page_cols[pg] k = page_cols[pg]
if len(bxs) < k: if len(bxs) < k:
k = 1 k = 1
x0s = np.array([[b["x0"]] for b in bxs], dtype=float) x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
@ -431,7 +431,7 @@ class RAGFlowPdfParser:
for b, lb in zip(bxs, labels): for b, lb in zip(bxs, labels):
b["col_id"] = remap[lb] b["col_id"] = remap[lb]
grouped = defaultdict(list) grouped = defaultdict(list)
for b in bxs: for b in bxs:
grouped[b["col_id"]].append(b) grouped[b["col_id"]].append(b)
@ -1112,7 +1112,7 @@ class RAGFlowPdfParser:
if limiter: if limiter:
async with limiter: async with limiter:
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id)) await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
else: else:
self.__ocr(i + 1, img, chars, zoomin, id) self.__ocr(i + 1, img, chars, zoomin, id)
@ -1128,12 +1128,33 @@ class RAGFlowPdfParser:
return chars return chars
if self.parallel_limiter: if self.parallel_limiter:
async with trio.open_nursery() as nursery: tasks = []
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess() for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES]
async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
await __img_ocr(
i,
i % settings.PARALLEL_DEVICES,
img,
chars,
semaphore,
)
tasks.append(asyncio.create_task(wrapper()))
await asyncio.sleep(0)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
await trio.sleep(0.1)
else: else:
for i, img in enumerate(self.page_images): for i, img in enumerate(self.page_images):
chars = __ocr_preprocess() chars = __ocr_preprocess()
@ -1141,7 +1162,7 @@ class RAGFlowPdfParser:
start = timer() start = timer()
trio.run(__img_ocr_launcher) asyncio.run(__img_ocr_launcher())
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import os import os
import sys import sys
sys.path.insert( sys.path.insert(
@ -28,7 +29,6 @@ from deepdoc.vision.seeit import draw_box
from deepdoc.vision import OCR, init_in_out from deepdoc.vision import OCR, init_in_out
import argparse import argparse
import numpy as np import numpy as np
import trio
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous # os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
@ -39,7 +39,7 @@ def main(args):
import torch.cuda import torch.cuda
cuda_devices = torch.cuda.device_count() cuda_devices = torch.cuda.device_count()
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
ocr = OCR() ocr = OCR()
images, outputs = init_in_out(args) images, outputs = init_in_out(args)
@ -62,22 +62,28 @@ def main(args):
async def __ocr_thread(i, id, img, limiter = None): async def __ocr_thread(i, id, img, limiter = None):
if limiter: if limiter:
async with limiter: async with limiter:
print("Task {} use device {}".format(i, id)) print(f"Task {i} use device {id}")
await trio.to_thread.run_sync(lambda: __ocr(i, id, img)) await asyncio.to_thread(__ocr, i, id, img)
else: else:
__ocr(i, id, img) await asyncio.to_thread(__ocr, i, id, img)
async def __ocr_launcher(): async def __ocr_launcher():
if cuda_devices > 1: tasks = []
async with trio.open_nursery() as nursery: for i, img in enumerate(images):
for i, img in enumerate(images): dev_id = i % cuda_devices if cuda_devices > 1 else 0
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices]) semaphore = limiter[dev_id] if limiter else None
await trio.sleep(0.1) tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore)))
else:
for i, img in enumerate(images):
await __ocr_thread(i, 0, img)
trio.run(__ocr_launcher) try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
asyncio.run(__ocr_launcher())
print("OCR tasks are all done") print("OCR tasks are all done")

View file

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import os import os
import time import time
from functools import partial from functools import partial
from typing import Any from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
from common.connection_utils import timeout from common.connection_utils import timeout
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
for k, v in kwargs.items(): for k, v in kwargs.items():
self.set_output(k, v) self.set_output(k, v)
try: try:
with trio.fail_after(self._param.timeout): await asyncio.wait_for(
await self._invoke(**kwargs) self._invoke(**kwargs),
self.callback(1, "Done") timeout=self._param.timeout
)
self.callback(1, "Done")
except Exception as e: except Exception as e:
if self.get_exception_default_value(): if self.get_exception_default_value():
self.set_exception_default_value() self.set_exception_default_value()

View file

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import random import random
import re import re
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
import trio
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -178,9 +177,17 @@ class HierarchicalMerger(ProcessBase):
} }
for c, img in zip(cks, images) for c, img in zip(cks, images)
] ]
async with trio.open_nursery() as nursery: tasks = []
for d in cks: for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
self.set_output("chunks", cks) self.set_output("chunks", cks)
self.callback(1, "Done.") self.callback(1, "Done.")

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import io import io
import json import json
import os import os
@ -20,7 +21,6 @@ import re
from functools import partial from functools import partial
import numpy as np import numpy as np
import trio
from PIL import Image from PIL import Image
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
@ -804,7 +804,7 @@ class Parser(ProcessBase):
for p_type, conf in self._param.setups.items(): for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue continue
await trio.to_thread.run_sync(function_map[p_type], name, blob) await asyncio.to_thread(function_map[p_type], name, blob)
done = True done = True
break break
@ -812,6 +812,14 @@ class Parser(ProcessBase):
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower()) raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
outs = self.output() outs = self.output()
async with trio.open_nursery() as nursery: tasks = []
for d in outs.get("json", []): for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise

View file

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import datetime import datetime
import json import json
import logging import logging
import random import random
from timeit import default_timer as timer from timeit import default_timer as timer
import trio
from agent.canvas import Graph from agent.canvas import Graph
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
@ -152,8 +152,9 @@ class Pipeline(Graph):
#else: #else:
# cpn_obj.invoke(**last_cpn.output()) # cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery: tasks = []
nursery.start_soon(invoke) tasks.append(asyncio.create_task(invoke()))
await asyncio.gather(*tasks)
if cpn_obj.error(): if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error() self.error = "[ERROR]" + cpn_obj.error()

View file

@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import random import random
import re import re
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
import trio
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -129,9 +129,16 @@ class Splitter(ProcessBase):
} }
for c, img in zip(chunks, images) if c.strip() for c, img in zip(chunks, images) if c.strip()
] ]
async with trio.open_nursery() as nursery: tasks = []
for d in cks: for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if custom_pattern: if custom_pattern:
docs = [] docs = []

View file

@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
# #
import argparse import argparse
import asyncio
import json import json
import os import os
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import trio
from common import settings from common import settings
from rag.flow.pipeline import Pipeline from rag.flow.pipeline import Pipeline
@ -57,5 +56,5 @@ if __name__ == "__main__":
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0) # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
trio.run(pipeline.run) asyncio.run(pipeline.run())
thr.result() thr.result()

View file

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import random import random
import re import re
import numpy as np import numpy as np
import trio
from common.constants import LLMType from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
if len(cnts_) == 0: if len(cnts_) == 0:
cnts_ = vts cnts_ = vts
else: else:

View file

@ -22,7 +22,6 @@ from copy import deepcopy
from typing import Tuple from typing import Tuple
import jinja2 import jinja2
import json_repair import json_repair
import trio
from common.misc_utils import hash_str2int from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
@ -744,12 +743,19 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
titles = [] titles = []
chunks_res = [] chunks_res = []
async with trio.open_nursery() as nursery: tasks = []
for i, chunk in enumerate(chunk_sections): for i, chunk in enumerate(chunk_sections):
if not chunk: if not chunk:
continue continue
chunks_res.append({"chunks": chunk}) chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback) tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
for chunk in chunks_res: for chunk in chunks_res:
titles.extend(chunk.get("toc", [])) titles.extend(chunk.get("toc", []))

View file

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import asyncio
import copy import copy
import faulthandler import faulthandler
import logging import logging
@ -31,8 +32,6 @@ import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
import trio
from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings from common import settings
@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version from common.versions import get_ragflow_version
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
class SyncBase: class SyncBase:
@ -60,75 +59,102 @@ class SyncBase:
async def __call__(self, task: dict): async def __call__(self, task: dict):
SyncLogsService.start(task["id"], task["connector_id"]) SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
document_batch_generator = await self._generate(task)
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
failed_docs = 0
for document_batch in document_batch_generator:
if not document_batch:
continue
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = []
for doc in document_batch:
doc_dict = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
# Add metadata if present
if doc.metadata:
doc_dict["metadata"] = doc.metadata
docs.append(doc_dict)
try: async with task_limiter:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) try:
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"]) await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
except Exception as batch_ex:
error_msg = str(batch_ex)
error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
if error_code == 1267 or "collation" in error_msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
for doc in docs:
logging.debug(f"Skipped: {doc['semantic_identifier']}")
else:
logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix() except asyncio.TimeoutError:
if failed_docs > 0: msg = f"Task timeout after {task['timeout_secs']} seconds"
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
else: return
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except Exception as ex: except Exception as ex:
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()]) msg = "\n".join([
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)}) "".join(traceback.format_exception_only(None, ex)).strip(),
"".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
])
SyncLogsService.update_by_id(task["id"], {
"status": TaskStatus.FAIL,
"full_exception_trace": msg,
"error_msg": str(ex)
})
return
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
async def _run_task_logic(self, task: dict):
document_batch_generator = await self._generate(task)
doc_num = 0
failed_docs = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
async for document_batch in document_batch_generator: # 如果是 async generator
if not document_batch:
continue
min_update = min(doc.doc_updated_at for doc in document_batch)
max_update = max(doc.doc_updated_at for doc in document_batch)
next_update = max(next_update, max_update)
docs = []
for doc in document_batch:
d = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
if doc.metadata:
d["metadata"] = doc.metadata
docs.append(d)
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(
kb, docs, task["tenant_id"],
f"{self.SOURCE_NAME}/{task['connector_id']}",
task["auto_parse"]
)
SyncLogsService.increase_docs(
task["id"], min_update, max_update,
len(docs), "\n".join(err), len(err)
)
doc_num += len(docs)
except Exception as batch_ex:
msg = str(batch_ex)
code = getattr(batch_ex, "args", [None])[0]
if code == 1267 or "collation" in msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
else:
logging.error(f"Error processing batch: {msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
async def _generate(self, task: dict): async def _generate(self, task: dict):
raise NotImplementedError raise NotImplementedError
def _get_source_prefix(self): def _get_source_prefix(self):
return "" return ""
@ -617,23 +643,32 @@ func_factory = {
async def dispatch_tasks(): async def dispatch_tasks():
async with trio.open_nursery() as nursery: while True:
while True: try:
try: list(SyncLogsService.list_sync_tasks()[0])
list(SyncLogsService.list_sync_tasks()[0]) break
break except Exception as e:
except Exception as e: logging.warning(f"DB is not ready yet: {e}")
logging.warning(f"DB is not ready yet: {e}") await asyncio.sleep(3)
await trio.sleep(3)
for task in SyncLogsService.list_sync_tasks()[0]: tasks = []
if task["poll_range_start"]: for task in SyncLogsService.list_sync_tasks()[0]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) if task["poll_range_start"]:
if task["poll_range_end"]: task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) if task["poll_range_end"]:
func = func_factory[task["source"]](task["config"]) task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
nursery.start_soon(func, task)
await trio.sleep(1) func = func_factory[task["source"]](task["config"])
tasks.append(asyncio.create_task(func(task)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
await asyncio.sleep(1)
stop_event = threading.Event() stop_event = threading.Event()
@ -678,4 +713,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)
trio.run(main) asyncio.run(main)

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import socket import socket
import concurrent import concurrent
# from beartype import BeartypeConf # from beartype import BeartypeConf
@ -46,7 +47,6 @@ from functools import partial
from multiprocessing.context import TimeoutError from multiprocessing.context import TimeoutError
from timeit import default_timer as timer from timeit import default_timer as timer
import signal import signal
import trio
import exceptiongroup import exceptiongroup
import faulthandler import faulthandler
import numpy as np import numpy as np
@ -114,11 +114,11 @@ CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
kg_limiter = trio.CapacityLimiter(2) kg_limiter = asyncio.Semaphore(2)
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
stop_event = threading.Event() stop_event = threading.Event()
@ -219,7 +219,7 @@ async def collect():
async def get_storage_binary(bucket, name): async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name)) return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1) @timeout(60*80, 1)
@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
try: try:
async with chunk_limiter: async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], cks = await asyncio.to_thread(
to_page=task["to_page"], lang=task["language"], callback=progress_callback, chunker.chunk,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])) task["name"],
binary=binary,
from_page=task["from_page"],
to_page=task["to_page"],
lang=task["language"],
callback=progress_callback,
kb_id=task["kb_id"],
parser_config=task["parser_config"],
tenant_id=task["tenant_id"],
)
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException: except TaskCanceledException:
raise raise
@ -290,9 +299,16 @@ async def build_chunks(task, progress_callback):
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise raise
async with trio.open_nursery() as nursery: tasks = []
for ck in cks: for ck in cks:
nursery.start_soon(upload_to_minio, doc, ck) tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
el = timer() - st el = timer() - st
logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el)) logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
@ -306,15 +322,27 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached: if not cached:
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn)) cached = await asyncio.to_thread(
keyword_extraction,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached: if cached:
d["important_kwd"] = cached.split(",") d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return return
async with trio.open_nursery() as nursery: tasks = []
for d in docs: for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"]) tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0): if task["parser_config"].get("auto_questions", 0):
@ -326,14 +354,26 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached: if not cached:
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn)) cached = await asyncio.to_thread(
question_proposal,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached: if cached:
d["question_kwd"] = cached.split("\n") d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
async with trio.open_nursery() as nursery: tasks = []
for d in docs: for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"]) tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []): if task["kb_parser_config"].get("tag_kb_ids", []):
@ -371,15 +411,29 @@ async def build_chunks(task, progress_callback):
if not picked_examples: if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)) cached = await asyncio.to_thread(
content_tagging,
chat_mdl,
d["content_with_weight"],
all_tags,
picked_examples,
topn_tags,
)
if cached: if cached:
cached = json.dumps(cached) cached = json.dumps(cached)
if cached: if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached) d[TAG_FLD] = json.loads(cached)
async with trio.open_nursery() as nursery: tasks = []
for d in docs_to_tag: for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags) tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs return docs
@ -392,7 +446,7 @@ def build_TOC(task, docs, progress_callback):
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
)) ))
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback) toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' ')) logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0 ii = 0
while ii < len(toc): while ii < len(toc):
@ -440,7 +494,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0 tk_count = 0
if len(tts) == len(cnts): if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1)) tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c tk_count += c
@ -452,7 +506,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0: if len(cnts_) == 0:
cnts_ = vts cnts_ = vts
else: else:
@ -535,7 +589,7 @@ async def run_dataflow(task: dict):
prog = 0.8 prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0: if len(vects) == 0:
vects = vts vects = vts
else: else:
@ -742,14 +796,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mothers.append(mom_ck) mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE): for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return False return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE): for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
@ -766,10 +820,17 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str) TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist: except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
async with trio.open_nursery() as nursery: tasks = []
for chunk_id in chunk_ids: for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id) tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return False return False
return True return True
@ -994,7 +1055,7 @@ async def handle_task():
global DONE_TASKS, FAILED_TASKS global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect() redis_msg, task = await collect()
if not task: if not task:
await trio.sleep(5) await asyncio.sleep(5)
return return
task_type = task["task_type"] task_type = task["task_type"]
@ -1091,7 +1152,7 @@ async def report_status():
logging.exception("report_status got exception") logging.exception("report_status got exception")
finally: finally:
redis_lock.release() redis_lock.release()
await trio.sleep(30) await asyncio.sleep(30)
async def task_manager(): async def task_manager():
@ -1127,14 +1188,22 @@ async def main():
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
async with trio.open_nursery() as nursery: report_task = asyncio.create_task(report_status())
nursery.start_soon(report_status) tasks = []
try:
while not stop_event.is_set(): while not stop_event.is_set():
await task_limiter.acquire() await task_limiter.acquire()
nursery.start_soon(task_manager) t = asyncio.create_task(task_manager())
tasks.append(t)
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
report_task.cancel()
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!") logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__": if __name__ == "__main__":
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)
trio.run(main) asyncio.run(main())

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import base64 import base64
import logging import logging
from functools import partial from functools import partial
@ -24,39 +25,53 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64) test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"): async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
import logging import logging
from io import BytesIO from io import BytesIO
import trio
from rag.svr.task_executor import minio_limiter from rag.svr.task_executor import minio_limiter
if "image" not in d: if "image" not in d:
return return
if not d["image"]: if not d["image"]:
del d["image"] del d["image"]
return return
with BytesIO() as output_buffer: def encode_image():
if isinstance(d["image"], bytes): with BytesIO() as buf:
output_buffer.write(d["image"]) img = d["image"]
output_buffer.seek(0)
else:
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
if d["image"].mode in ("RGBA", "P"):
converted_image = d["image"].convert("RGB")
d["image"] = converted_image
try:
d["image"].save(output_buffer, format='JPEG')
except OSError as e:
logging.warning(
"Saving image exception, ignore: {}".format(str(e)))
async with minio_limiter: if isinstance(img, bytes):
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue())) buf.write(img)
d["img_id"] = f"{bucket}-{objname}" buf.seek(0)
if not isinstance(d["image"], bytes): return buf.getvalue()
d["image"].close()
del d["image"] # Remove image reference if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
try:
img.save(buf, format="JPEG")
except OSError as e:
logging.warning(f"Saving image exception: {e}")
return None
buf.seek(0)
return buf.getvalue()
jpeg_binary = await asyncio.to_thread(encode_image)
if jpeg_binary is None:
del d["image"]
return
async with minio_limiter:
await asyncio.to_thread(
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
)
d["img_id"] = f"{bucket}-{objname}"
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"]
def id2image(image_id:str|None, storage_get_func: partial): def id2image(image_id:str|None, storage_get_func: partial):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import json import json
import uuid import uuid
@ -22,7 +23,6 @@ import valkey as redis
from common.decorator import singleton from common.decorator import singleton
from common import settings from common import settings
from valkey.lock import Lock from valkey.lock import Lock
import trio
REDIS = {} REDIS = {}
try: try:
@ -405,7 +405,7 @@ class RedisDistributedLock:
while True: while True:
if self.lock.acquire(token=self.lock_value): if self.lock.acquire(token=self.lock_value):
break break
await trio.sleep(10) await asyncio.sleep(10)
def release(self): def release(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value) REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)