Refa:replace trio with asyncio in agent/deepdoc/rag
This commit is contained in:
parent
1777620ea5
commit
94dccd8270
15 changed files with 382 additions and 207 deletions
|
|
@ -24,7 +24,6 @@ import os
|
|||
import logging
|
||||
from typing import Any, List, Union
|
||||
import pandas as pd
|
||||
import trio
|
||||
from agent import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
|
@ -393,7 +392,7 @@ class ComponentParamBase(ABC):
|
|||
|
||||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
|
@ -28,7 +29,6 @@ from timeit import default_timer as timer
|
|||
|
||||
import numpy as np
|
||||
import pdfplumber
|
||||
import trio
|
||||
import xgboost as xgb
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
|
@ -66,7 +66,7 @@ class RAGFlowPdfParser:
|
|||
self.ocr = OCR()
|
||||
self.parallel_limiter = None
|
||||
if settings.PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
|
||||
self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)]
|
||||
|
||||
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
|
||||
if layout_recognizer_type not in ["onnx", "ascend"]:
|
||||
|
|
@ -383,7 +383,7 @@ class RAGFlowPdfParser:
|
|||
else:
|
||||
x0s.append([x])
|
||||
x0s = np.array(x0s, dtype=float)
|
||||
|
||||
|
||||
max_try = min(4, len(bxs))
|
||||
if max_try < 2:
|
||||
max_try = 1
|
||||
|
|
@ -417,7 +417,7 @@ class RAGFlowPdfParser:
|
|||
for pg, bxs in by_page.items():
|
||||
if not bxs:
|
||||
continue
|
||||
k = page_cols[pg]
|
||||
k = page_cols[pg]
|
||||
if len(bxs) < k:
|
||||
k = 1
|
||||
x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
|
||||
|
|
@ -431,7 +431,7 @@ class RAGFlowPdfParser:
|
|||
|
||||
for b, lb in zip(bxs, labels):
|
||||
b["col_id"] = remap[lb]
|
||||
|
||||
|
||||
grouped = defaultdict(list)
|
||||
for b in bxs:
|
||||
grouped[b["col_id"]].append(b)
|
||||
|
|
@ -1112,7 +1112,7 @@ class RAGFlowPdfParser:
|
|||
|
||||
if limiter:
|
||||
async with limiter:
|
||||
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
|
||||
await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
|
||||
else:
|
||||
self.__ocr(i + 1, img, chars, zoomin, id)
|
||||
|
||||
|
|
@ -1128,12 +1128,33 @@ class RAGFlowPdfParser:
|
|||
return chars
|
||||
|
||||
if self.parallel_limiter:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
tasks = []
|
||||
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES]
|
||||
|
||||
async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
|
||||
await __img_ocr(
|
||||
i,
|
||||
i % settings.PARALLEL_DEVICES,
|
||||
img,
|
||||
chars,
|
||||
semaphore,
|
||||
)
|
||||
|
||||
tasks.append(asyncio.create_task(wrapper()))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
|
@ -1141,7 +1162,7 @@ class RAGFlowPdfParser:
|
|||
|
||||
start = timer()
|
||||
|
||||
trio.run(__img_ocr_launcher)
|
||||
asyncio.run(__img_ocr_launcher())
|
||||
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(
|
||||
|
|
@ -28,7 +29,6 @@ from deepdoc.vision.seeit import draw_box
|
|||
from deepdoc.vision import OCR, init_in_out
|
||||
import argparse
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
||||
|
|
@ -39,7 +39,7 @@ def main(args):
|
|||
import torch.cuda
|
||||
|
||||
cuda_devices = torch.cuda.device_count()
|
||||
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||
limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||
ocr = OCR()
|
||||
images, outputs = init_in_out(args)
|
||||
|
||||
|
|
@ -62,22 +62,28 @@ def main(args):
|
|||
async def __ocr_thread(i, id, img, limiter = None):
|
||||
if limiter:
|
||||
async with limiter:
|
||||
print("Task {} use device {}".format(i, id))
|
||||
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
|
||||
print(f"Task {i} use device {id}")
|
||||
await asyncio.to_thread(__ocr, i, id, img)
|
||||
else:
|
||||
__ocr(i, id, img)
|
||||
await asyncio.to_thread(__ocr, i, id, img)
|
||||
|
||||
|
||||
async def __ocr_launcher():
|
||||
if cuda_devices > 1:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, img in enumerate(images):
|
||||
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(images):
|
||||
await __ocr_thread(i, 0, img)
|
||||
tasks = []
|
||||
for i, img in enumerate(images):
|
||||
dev_id = i % cuda_devices if cuda_devices > 1 else 0
|
||||
semaphore = limiter[dev_id] if limiter else None
|
||||
tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore)))
|
||||
|
||||
trio.run(__ocr_launcher)
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
asyncio.run(__ocr_launcher())
|
||||
|
||||
print("OCR tasks are all done")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,12 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import trio
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
|
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
|
|||
for k, v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
await asyncio.wait_for(
|
||||
self._invoke(**kwargs),
|
||||
timeout=self._param.timeout
|
||||
)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
|
|
|
|||
|
|
@ -13,13 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
|
|
@ -178,9 +177,17 @@ class HierarchicalMerger(ProcessBase):
|
|||
}
|
||||
for c, img in zip(cks, images)
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in cks:
|
||||
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
self.set_output("chunks", cks)
|
||||
|
||||
self.callback(1, "Done.")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
|
|
@ -20,7 +21,6 @@ import re
|
|||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
from PIL import Image
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
|
|
@ -804,7 +804,7 @@ class Parser(ProcessBase):
|
|||
for p_type, conf in self._param.setups.items():
|
||||
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await trio.to_thread.run_sync(function_map[p_type], name, blob)
|
||||
await asyncio.to_thread(function_map[p_type], name, blob)
|
||||
done = True
|
||||
break
|
||||
|
||||
|
|
@ -812,6 +812,14 @@ class Parser(ProcessBase):
|
|||
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
|
||||
|
||||
outs = self.output()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in outs.get("json", []):
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in outs.get("json", []):
|
||||
tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid())))
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
|
@ -13,12 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
|
|
@ -152,8 +152,9 @@ class Pipeline(Graph):
|
|||
#else:
|
||||
# cpn_obj.invoke(**last_cpn.output())
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
tasks = []
|
||||
tasks.append(asyncio.create_task(invoke()))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
import trio
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
|
|
@ -129,9 +129,16 @@ class Splitter(ProcessBase):
|
|||
}
|
||||
for c, img in zip(chunks, images) if c.strip()
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
|
||||
tasks = []
|
||||
for d in cks:
|
||||
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if custom_pattern:
|
||||
docs = []
|
||||
|
|
|
|||
|
|
@ -14,13 +14,12 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import trio
|
||||
|
||||
from common import settings
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
|
@ -57,5 +56,5 @@ if __name__ == "__main__":
|
|||
|
||||
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
|
||||
|
||||
trio.run(pipeline.run)
|
||||
asyncio.run(pipeline.run())
|
||||
thr.result()
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
from common.constants import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
|
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
|
|||
cnts_ = np.array([])
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
|
||||
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from copy import deepcopy
|
|||
from typing import Tuple
|
||||
import jinja2
|
||||
import json_repair
|
||||
import trio
|
||||
from common.misc_utils import hash_str2int
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.template import load_prompt
|
||||
|
|
@ -744,12 +743,19 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
|||
titles = []
|
||||
|
||||
chunks_res = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, chunk in enumerate(chunk_sections):
|
||||
if not chunk:
|
||||
continue
|
||||
chunks_res.append({"chunks": chunk})
|
||||
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
|
||||
tasks = []
|
||||
for i, chunk in enumerate(chunk_sections):
|
||||
if not chunk:
|
||||
continue
|
||||
chunks_res.append({"chunks": chunk})
|
||||
tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
for chunk in chunks_res:
|
||||
titles.extend(chunk.get("toc", []))
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import faulthandler
|
||||
import logging
|
||||
|
|
@ -31,8 +32,6 @@ import traceback
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from common import settings
|
||||
|
|
@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
|||
from common.versions import get_ragflow_version
|
||||
|
||||
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
|
||||
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||
|
||||
|
||||
class SyncBase:
|
||||
|
|
@ -60,75 +59,102 @@ class SyncBase:
|
|||
|
||||
async def __call__(self, task: dict):
|
||||
SyncLogsService.start(task["id"], task["connector_id"])
|
||||
try:
|
||||
async with task_limiter:
|
||||
with trio.fail_after(task["timeout_secs"]):
|
||||
document_batch_generator = await self._generate(task)
|
||||
doc_num = 0
|
||||
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
if task["poll_range_start"]:
|
||||
next_update = task["poll_range_start"]
|
||||
|
||||
failed_docs = 0
|
||||
for document_batch in document_batch_generator:
|
||||
if not document_batch:
|
||||
continue
|
||||
min_update = min([doc.doc_updated_at for doc in document_batch])
|
||||
max_update = max([doc.doc_updated_at for doc in document_batch])
|
||||
next_update = max([next_update, max_update])
|
||||
docs = []
|
||||
for doc in document_batch:
|
||||
doc_dict = {
|
||||
"id": doc.id,
|
||||
"connector_id": task["connector_id"],
|
||||
"source": self.SOURCE_NAME,
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"extension": doc.extension,
|
||||
"size_bytes": doc.size_bytes,
|
||||
"doc_updated_at": doc.doc_updated_at,
|
||||
"blob": doc.blob,
|
||||
}
|
||||
# Add metadata if present
|
||||
if doc.metadata:
|
||||
doc_dict["metadata"] = doc.metadata
|
||||
docs.append(doc_dict)
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
|
||||
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
|
||||
doc_num += len(docs)
|
||||
except Exception as batch_ex:
|
||||
error_msg = str(batch_ex)
|
||||
error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
|
||||
|
||||
if error_code == 1267 or "collation" in error_msg.lower():
|
||||
logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
|
||||
for doc in docs:
|
||||
logging.debug(f"Skipped: {doc['semantic_identifier']}")
|
||||
else:
|
||||
logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
|
||||
|
||||
failed_docs += len(docs)
|
||||
continue
|
||||
async with task_limiter:
|
||||
try:
|
||||
await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
|
||||
|
||||
prefix = self._get_source_prefix()
|
||||
if failed_docs > 0:
|
||||
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
|
||||
else:
|
||||
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
|
||||
SyncLogsService.done(task["id"], task["connector_id"])
|
||||
task["poll_range_start"] = next_update
|
||||
except asyncio.TimeoutError:
|
||||
msg = f"Task timeout after {task['timeout_secs']} seconds"
|
||||
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
|
||||
return
|
||||
|
||||
except Exception as ex:
|
||||
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
|
||||
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
|
||||
except Exception as ex:
|
||||
msg = "\n".join([
|
||||
"".join(traceback.format_exception_only(None, ex)).strip(),
|
||||
"".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
|
||||
])
|
||||
SyncLogsService.update_by_id(task["id"], {
|
||||
"status": TaskStatus.FAIL,
|
||||
"full_exception_trace": msg,
|
||||
"error_msg": str(ex)
|
||||
})
|
||||
return
|
||||
|
||||
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
|
||||
|
||||
async def _run_task_logic(self, task: dict):
|
||||
document_batch_generator = await self._generate(task)
|
||||
|
||||
doc_num = 0
|
||||
failed_docs = 0
|
||||
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
if task["poll_range_start"]:
|
||||
next_update = task["poll_range_start"]
|
||||
|
||||
async for document_batch in document_batch_generator: # 如果是 async generator
|
||||
if not document_batch:
|
||||
continue
|
||||
|
||||
min_update = min(doc.doc_updated_at for doc in document_batch)
|
||||
max_update = max(doc.doc_updated_at for doc in document_batch)
|
||||
next_update = max(next_update, max_update)
|
||||
|
||||
docs = []
|
||||
for doc in document_batch:
|
||||
d = {
|
||||
"id": doc.id,
|
||||
"connector_id": task["connector_id"],
|
||||
"source": self.SOURCE_NAME,
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"extension": doc.extension,
|
||||
"size_bytes": doc.size_bytes,
|
||||
"doc_updated_at": doc.doc_updated_at,
|
||||
"blob": doc.blob,
|
||||
}
|
||||
if doc.metadata:
|
||||
d["metadata"] = doc.metadata
|
||||
docs.append(d)
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
err, dids = SyncLogsService.duplicate_and_parse(
|
||||
kb, docs, task["tenant_id"],
|
||||
f"{self.SOURCE_NAME}/{task['connector_id']}",
|
||||
task["auto_parse"]
|
||||
)
|
||||
SyncLogsService.increase_docs(
|
||||
task["id"], min_update, max_update,
|
||||
len(docs), "\n".join(err), len(err)
|
||||
)
|
||||
|
||||
doc_num += len(docs)
|
||||
|
||||
except Exception as batch_ex:
|
||||
msg = str(batch_ex)
|
||||
code = getattr(batch_ex, "args", [None])[0]
|
||||
|
||||
if code == 1267 or "collation" in msg.lower():
|
||||
logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
|
||||
else:
|
||||
logging.error(f"Error processing batch: {msg}")
|
||||
|
||||
failed_docs += len(docs)
|
||||
continue
|
||||
|
||||
prefix = self._get_source_prefix()
|
||||
if failed_docs > 0:
|
||||
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
|
||||
else:
|
||||
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
|
||||
|
||||
SyncLogsService.done(task["id"], task["connector_id"])
|
||||
task["poll_range_start"] = next_update
|
||||
|
||||
async def _generate(self, task: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _get_source_prefix(self):
|
||||
return ""
|
||||
|
||||
|
|
@ -617,23 +643,32 @@ func_factory = {
|
|||
|
||||
|
||||
async def dispatch_tasks():
|
||||
async with trio.open_nursery() as nursery:
|
||||
while True:
|
||||
try:
|
||||
list(SyncLogsService.list_sync_tasks()[0])
|
||||
break
|
||||
except Exception as e:
|
||||
logging.warning(f"DB is not ready yet: {e}")
|
||||
await trio.sleep(3)
|
||||
while True:
|
||||
try:
|
||||
list(SyncLogsService.list_sync_tasks()[0])
|
||||
break
|
||||
except Exception as e:
|
||||
logging.warning(f"DB is not ready yet: {e}")
|
||||
await asyncio.sleep(3)
|
||||
|
||||
for task in SyncLogsService.list_sync_tasks()[0]:
|
||||
if task["poll_range_start"]:
|
||||
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
|
||||
if task["poll_range_end"]:
|
||||
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
|
||||
func = func_factory[task["source"]](task["config"])
|
||||
nursery.start_soon(func, task)
|
||||
await trio.sleep(1)
|
||||
tasks = []
|
||||
for task in SyncLogsService.list_sync_tasks()[0]:
|
||||
if task["poll_range_start"]:
|
||||
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
|
||||
if task["poll_range_end"]:
|
||||
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
|
||||
|
||||
func = func_factory[task["source"]](task["config"])
|
||||
tasks.append(asyncio.create_task(func(task)))
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
|
@ -678,4 +713,4 @@ async def main():
|
|||
if __name__ == "__main__":
|
||||
faulthandler.enable()
|
||||
init_root_logger(CONSUMER_NAME)
|
||||
trio.run(main)
|
||||
asyncio.run(main)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import socket
|
||||
import concurrent
|
||||
# from beartype import BeartypeConf
|
||||
|
|
@ -46,7 +47,6 @@ from functools import partial
|
|||
from multiprocessing.context import TimeoutError
|
||||
from timeit import default_timer as timer
|
||||
import signal
|
||||
import trio
|
||||
import exceptiongroup
|
||||
import faulthandler
|
||||
import numpy as np
|
||||
|
|
@ -114,11 +114,11 @@ CURRENT_TASKS = {}
|
|||
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
||||
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
|
||||
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
|
||||
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO)
|
||||
kg_limiter = trio.CapacityLimiter(2)
|
||||
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||
chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
|
||||
kg_limiter = asyncio.Semaphore(2)
|
||||
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
|
||||
stop_event = threading.Event()
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ async def collect():
|
|||
|
||||
|
||||
async def get_storage_binary(bucket, name):
|
||||
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))
|
||||
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
|
||||
|
||||
|
||||
@timeout(60*80, 1)
|
||||
|
|
@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
|
|||
|
||||
try:
|
||||
async with chunk_limiter:
|
||||
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
||||
cks = await asyncio.to_thread(
|
||||
chunker.chunk,
|
||||
task["name"],
|
||||
binary=binary,
|
||||
from_page=task["from_page"],
|
||||
to_page=task["to_page"],
|
||||
lang=task["language"],
|
||||
callback=progress_callback,
|
||||
kb_id=task["kb_id"],
|
||||
parser_config=task["parser_config"],
|
||||
tenant_id=task["tenant_id"],
|
||||
)
|
||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
|
|
@ -290,9 +299,16 @@ async def build_chunks(task, progress_callback):
|
|||
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
||||
raise
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ck in cks:
|
||||
nursery.start_soon(upload_to_minio, doc, ck)
|
||||
tasks = []
|
||||
for ck in cks:
|
||||
tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
el = timer() - st
|
||||
logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
|
||||
|
|
@ -306,15 +322,27 @@ async def build_chunks(task, progress_callback):
|
|||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
|
||||
cached = await asyncio.to_thread(
|
||||
keyword_extraction,
|
||||
chat_mdl,
|
||||
d["content_with_weight"],
|
||||
topn,
|
||||
)
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
|
||||
if cached:
|
||||
d["important_kwd"] = cached.split(",")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
return
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
|
||||
tasks = []
|
||||
for d in docs:
|
||||
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["parser_config"].get("auto_questions", 0):
|
||||
|
|
@ -326,14 +354,26 @@ async def build_chunks(task, progress_callback):
|
|||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
|
||||
cached = await asyncio.to_thread(
|
||||
question_proposal,
|
||||
chat_mdl,
|
||||
d["content_with_weight"],
|
||||
topn,
|
||||
)
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
|
||||
if cached:
|
||||
d["question_kwd"] = cached.split("\n")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
|
||||
tasks = []
|
||||
for d in docs:
|
||||
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["kb_parser_config"].get("tag_kb_ids", []):
|
||||
|
|
@ -371,15 +411,29 @@ async def build_chunks(task, progress_callback):
|
|||
if not picked_examples:
|
||||
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
|
||||
cached = await asyncio.to_thread(
|
||||
content_tagging,
|
||||
chat_mdl,
|
||||
d["content_with_weight"],
|
||||
all_tags,
|
||||
picked_examples,
|
||||
topn_tags,
|
||||
)
|
||||
if cached:
|
||||
cached = json.dumps(cached)
|
||||
if cached:
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
||||
d[TAG_FLD] = json.loads(cached)
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs_to_tag:
|
||||
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
|
||||
tasks = []
|
||||
for d in docs_to_tag:
|
||||
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
return docs
|
||||
|
|
@ -392,7 +446,7 @@ def build_TOC(task, docs, progress_callback):
|
|||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||
))
|
||||
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
|
||||
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
|
||||
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||
ii = 0
|
||||
while ii < len(toc):
|
||||
|
|
@ -440,7 +494,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
||||
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
|
||||
tts = np.tile(vts[0], (len(cnts), 1))
|
||||
tk_count += c
|
||||
|
||||
|
|
@ -452,7 +506,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
|
||||
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
|
|
@ -535,7 +589,7 @@ async def run_dataflow(task: dict):
|
|||
prog = 0.8
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
|
||||
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
|
|
@ -742,14 +796,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
|||
mothers.append(mom_ck)
|
||||
|
||||
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return False
|
||||
|
||||
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
|
|
@ -766,10 +820,17 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
|||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
|
||||
tasks = []
|
||||
for chunk_id in chunk_ids:
|
||||
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
|
||||
return False
|
||||
return True
|
||||
|
|
@ -994,7 +1055,7 @@ async def handle_task():
|
|||
global DONE_TASKS, FAILED_TASKS
|
||||
redis_msg, task = await collect()
|
||||
if not task:
|
||||
await trio.sleep(5)
|
||||
await asyncio.sleep(5)
|
||||
return
|
||||
|
||||
task_type = task["task_type"]
|
||||
|
|
@ -1091,7 +1152,7 @@ async def report_status():
|
|||
logging.exception("report_status got exception")
|
||||
finally:
|
||||
redis_lock.release()
|
||||
await trio.sleep(30)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
||||
async def task_manager():
|
||||
|
|
@ -1127,14 +1188,22 @@ async def main():
|
|||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(report_status)
|
||||
report_task = asyncio.create_task(report_status())
|
||||
tasks = []
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
await task_limiter.acquire()
|
||||
nursery.start_soon(task_manager)
|
||||
t = asyncio.create_task(task_manager())
|
||||
tasks.append(t)
|
||||
finally:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
report_task.cancel()
|
||||
await asyncio.gather(report_task, return_exceptions=True)
|
||||
logging.error("BUG!!! You should not reach here!!!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
faulthandler.enable()
|
||||
init_root_logger(CONSUMER_NAME)
|
||||
trio.run(main)
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
|
|
@ -24,39 +25,53 @@ from PIL import Image
|
|||
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import trio
|
||||
from rag.svr.task_executor import minio_limiter
|
||||
|
||||
if "image" not in d:
|
||||
return
|
||||
if not d["image"]:
|
||||
del d["image"]
|
||||
return
|
||||
|
||||
with BytesIO() as output_buffer:
|
||||
if isinstance(d["image"], bytes):
|
||||
output_buffer.write(d["image"])
|
||||
output_buffer.seek(0)
|
||||
else:
|
||||
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
|
||||
if d["image"].mode in ("RGBA", "P"):
|
||||
converted_image = d["image"].convert("RGB")
|
||||
d["image"] = converted_image
|
||||
try:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"Saving image exception, ignore: {}".format(str(e)))
|
||||
def encode_image():
|
||||
with BytesIO() as buf:
|
||||
img = d["image"]
|
||||
|
||||
async with minio_limiter:
|
||||
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
|
||||
d["img_id"] = f"{bucket}-{objname}"
|
||||
if not isinstance(d["image"], bytes):
|
||||
d["image"].close()
|
||||
del d["image"] # Remove image reference
|
||||
if isinstance(img, bytes):
|
||||
buf.write(img)
|
||||
buf.seek(0)
|
||||
return buf.getvalue()
|
||||
|
||||
if img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
try:
|
||||
img.save(buf, format="JPEG")
|
||||
except OSError as e:
|
||||
logging.warning(f"Saving image exception: {e}")
|
||||
return None
|
||||
|
||||
buf.seek(0)
|
||||
return buf.getvalue()
|
||||
|
||||
jpeg_binary = await asyncio.to_thread(encode_image)
|
||||
if jpeg_binary is None:
|
||||
del d["image"]
|
||||
return
|
||||
|
||||
async with minio_limiter:
|
||||
await asyncio.to_thread(
|
||||
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
|
||||
)
|
||||
|
||||
d["img_id"] = f"{bucket}-{objname}"
|
||||
|
||||
if not isinstance(d["image"], bytes):
|
||||
d["image"].close()
|
||||
del d["image"]
|
||||
|
||||
|
||||
def id2image(image_id:str|None, storage_get_func: partial):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import uuid
|
||||
|
|
@ -22,7 +23,6 @@ import valkey as redis
|
|||
from common.decorator import singleton
|
||||
from common import settings
|
||||
from valkey.lock import Lock
|
||||
import trio
|
||||
|
||||
REDIS = {}
|
||||
try:
|
||||
|
|
@ -405,7 +405,7 @@ class RedisDistributedLock:
|
|||
while True:
|
||||
if self.lock.acquire(token=self.lock_value):
|
||||
break
|
||||
await trio.sleep(10)
|
||||
await asyncio.sleep(10)
|
||||
|
||||
def release(self):
|
||||
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue