ragflow/python/cv/table_recognize.py
2023-12-14 19:19:03 +08:00

107 lines
4.4 KiB
Python

import torch
from transformers import \
TableTransformerForObjectDetection,\
AutoImageProcessor
from PIL import ImageDraw
from random import randint
class TableTransformer:
    """Run a HuggingFace Table Transformer object-detection model on images.

    Calling an instance runs the recognition model (``rec_mdlnm``) over a
    list of images and returns one list of detections per image. Each
    detection is a dict with the keys ``top``, ``bottom``, ``x0``, ``x1``,
    ``score`` and ``label``.

    If downloading HuggingFace models is a problem, pointing ``HF_ENDPOINT``
    at a mirror might help, e.g. on Linux::

        export HF_ENDPOINT=https://hf-mirror.com
    """

    def __init__(self,
                 rec_mdlnm="microsoft/table-transformer-structure-recognition",
                 det_mdlnm=None):
        """Load the image processor(s) and model(s).

        Args:
            rec_mdlnm: Model name or path used by ``__call__``.
            det_mdlnm: Optional model name or path used by ``detect``. When
                omitted no detector is loaded and ``detect`` cannot be used
                on non-empty input.
        """
        self.rec_img_pro = AutoImageProcessor.from_pretrained(rec_mdlnm)
        self.rec_mdl = TableTransformerForObjectDetection.from_pretrained(
            rec_mdlnm)

        # Always define the detector attributes so that detect() can report
        # a missing detector explicitly instead of tripping over an
        # undefined attribute.
        self.det_img_pro = None
        self.det_mdl = None
        if det_mdlnm:
            self.det_img_pro = AutoImageProcessor.from_pretrained(det_mdlnm)
            self.det_mdl = TableTransformerForObjectDetection.from_pretrained(
                det_mdlnm)

        if torch.cuda.is_available():
            self.rec_mdl.cuda()
            if self.det_mdl is not None:
                self.det_mdl.cuda()
        self.batch_size = 1  # number of images pushed through the model at once

    def __friendly(self, batch_res, id2label):
        """Convert post-processed detections into plain dictionaries.

        Args:
            batch_res: One mapping per image with ``scores``, ``labels`` and
                ``boxes`` entries that are iterated in lockstep.
            id2label: Mapping from label id to label name.

        Returns:
            A list with one list of detection dicts per image.
        """
        res = []
        for r in batch_res:
            feas = []
            for score, label, box in zip(r["scores"], r["labels"], r["boxes"]):
                label_id = label.item()
                # Label id 0 is dropped -- presumably the whole-table class;
                # confirm against the model's id2label.
                if label_id == 0:
                    continue
                box = [round(x, 2) for x in box.tolist()]
                feas.append({
                    "top": box[1], "bottom": box[-1],
                    "x0": box[0], "x1": box[2],
                    "score": score.item(),
                    "label": id2label[label_id]
                })

            # Rows narrower than half of the widest row get their right edge
            # pushed out by that half width. A substring test is used so a
            # label that *starts* with "row" is recognised as a row as well.
            rows = [f for f in feas if "row" in f["label"]]
            if rows:
                mw = max(f["x1"] - f["x0"] for f in rows) / 2
                for f in rows:
                    if f["x1"] - f["x0"] < mw:
                        f["x1"] += mw
            res.append(feas)
        return res

    def __draw(self, bres, imgs, id2label):
        """Draw the detections onto the images and save each one as a JPEG.

        The images are modified in place and written to the current working
        directory as ``./t<index>.<random number>.jpg``.
        """
        for i, (img, dets) in enumerate(zip(imgs, bres)):
            draw = ImageDraw.Draw(img, "RGB")
            for score, label, box in zip(
                    dets["scores"], dets["labels"], dets["boxes"]):
                label_id = label.item()
                if label_id == 0:
                    continue
                color = (randint(0, 255), randint(0, 255), randint(0, 255))
                # Hand Pillow plain Python numbers rather than tensor scalars.
                coords = box.tolist()
                x0, y0, x1, y1 = coords[0], coords[1], coords[2], coords[-1]
                draw.rectangle((x0, y0, x1, y1), outline=color, width=1)
                draw.text((x0, y0),
                          id2label[label_id] + ":{:.2f}".format(score.item()),
                          fill=color)
            img.save(f"./t{i}.%d.jpg" % randint(0, 1000))

    def __predict(self, processor, model, imgs, threshold):
        """Run ``model`` on one batch and return post-processed detections."""
        inputs = processor(imgs, return_tensors="pt")
        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}
        # Image .size is reversed before being handed to post-processing.
        target_sizes = torch.tensor([img.size[::-1] for img in imgs])
        # The forward pass itself must run without gradient tracking, not
        # only the post-processing step.
        with torch.no_grad():
            outputs = model(**inputs)
            # one {scores, labels, boxes} mapping per image
            return processor.post_process_object_detection(
                outputs, threshold=threshold, target_sizes=target_sizes)

    def __call__(self, images):
        """Run the recognition model over ``images``.

        Args:
            images: Sequence of images, processed ``self.batch_size`` at a
                time.

        Returns:
            One list of detection dicts per input image.
        """
        res = []
        id2label = None
        for i in range(0, len(images), self.batch_size):
            imgs = images[i: i + self.batch_size]
            if id2label is None:
                id2label = self.rec_mdl.config.id2label
            bres = self.__predict(self.rec_img_pro, self.rec_mdl, imgs, 0.80)
            # NOTE(review): this draws on the caller's images and writes
            # JPEG files on every call -- looks like a debugging aid; confirm
            # it is wanted on the regular inference path.
            self.__draw(bres, imgs, id2label)
            res.extend(self.__friendly(bres, id2label))
        return res

    def detect(self, images):
        """Run the optional detector model over ``images``.

        Args:
            images: Sequence of images, processed ``self.batch_size`` at a
                time.

        Returns:
            One list of detection dicts per input image.

        Raises:
            AttributeError: If ``images`` is non-empty and no detector was
                configured (``det_mdlnm`` was not given to the constructor).
        """
        det_img_pro = getattr(self, "det_img_pro", None)
        det_mdl = getattr(self, "det_mdl", None)
        if len(images) and (det_img_pro is None or det_mdl is None):
            raise AttributeError(
                "no detector model is loaded; pass det_mdlnm to "
                "TableTransformer(...) before calling detect()")

        # NOTE(review): __friendly drops label id 0 -- confirm that this is
        # appropriate for the detector's label map.
        res = []
        for i in range(0, len(images), self.batch_size):
            imgs = images[i: i + self.batch_size]
            bres = self.__predict(det_img_pro, det_mdl, imgs, 0.9)
            res.extend(self.__friendly(bres, det_mdl.config.id2label))
        return res