ragflow/rag/cv/table_recognize.py
KevinHuSh 038b36a525
build python version rag-flow (#21)
* clean rust version project

* clean rust version project

* build python version rag-flow
2024-01-15 08:46:22 +08:00

113 lines
4.7 KiB
Python

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers import \
TableTransformerForObjectDetection,\
AutoImageProcessor
from PIL import ImageDraw
from random import randint
class TableTransformer:
    """Thin wrapper around HuggingFace Table Transformer object detectors.

    The instance always owns a structure-recognition model (``rec_mdl``) and
    may additionally own a table-detection model (``det_mdl``).  Calling the
    instance runs structure recognition; ``detect`` runs table detection.
    """

    def __init__(self,
                 rec_mdlnm="microsoft/table-transformer-structure-recognition",
                 det_mdlnm=None):
        """Load the image processor(s) and model(s).

        If downloading from the HuggingFace hub is a problem, pointing the
        ``HF_ENDPOINT`` environment variable at a mirror may help, e.g. on
        Linux: ``export HF_ENDPOINT=https://hf-mirror.com``.

        Args:
            rec_mdlnm: Name or path of the structure-recognition checkpoint.
            det_mdlnm: Optional name or path of a table-detection checkpoint.
                When omitted, no detection model is loaded and ``detect``
                raises ``RuntimeError``.
        """
        self.rec_img_pro = AutoImageProcessor.from_pretrained(rec_mdlnm)
        self.rec_mdl = TableTransformerForObjectDetection.from_pretrained(
            rec_mdlnm)

        # The detection model is optional; ``detect`` checks for it.
        self.det_img_pro = None
        self.det_mdl = None
        if det_mdlnm:
            self.det_img_pro = AutoImageProcessor.from_pretrained(det_mdlnm)
            self.det_mdl = TableTransformerForObjectDetection.from_pretrained(
                det_mdlnm)

        if torch.cuda.is_available():
            self.rec_mdl.cuda()
            if self.det_mdl is not None:
                self.det_mdl.cuda()

        # Number of images pushed through the model per forward pass.
        self.batch_size = 1

    def __friendly(self, batch_res, id2label):
        """Convert post-processed detections into plain Python structures.

        Args:
            batch_res: Iterable of per-image dicts holding ``"scores"``,
                ``"labels"`` and ``"boxes"`` tensors.
            id2label: Mapping from label id to label name.

        Returns:
            One list per image; each entry is a dict with keys ``"type"``
            (label name), ``"score"`` (float) and ``"bbox"`` (box
            coordinates rounded to two decimals).  Detections whose label
            id is 0 are dropped.
        """
        res = []
        for per_img in batch_res:
            feas = []
            for score, label, box in zip(per_img["scores"],
                                         per_img["labels"],
                                         per_img["boxes"]):
                label_id = label.item()
                if label_id == 0:
                    continue
                feas.append({
                    "type": id2label[label_id],
                    "score": score.item(),
                    "bbox": [round(x, 2) for x in box.tolist()]
                })
            res.append(feas)
        return res

    def __draw(self, bres, imgs, id2label):
        """Debug aid: draw detections onto the images and save them as JPEGs.

        Each image is modified in place and written to
        ``./t<index>.<random>.jpg``.  Detections with label id 0 are skipped.
        Nothing in this class calls this method; it is kept for manual
        debugging.
        """
        for i, (img, per_img) in enumerate(zip(imgs, bres)):
            draw = ImageDraw.Draw(img, "RGB")
            for score, label, box in zip(per_img["scores"],
                                         per_img["labels"],
                                         per_img["boxes"]):
                label_id = label.item()
                if label_id == 0:
                    continue
                # A random colour per box keeps overlapping boxes apart.
                colour = (randint(0, 255), randint(0, 255), randint(0, 255))
                # Hand PIL plain floats rather than 0-d tensors.
                x0, y0, x1, y1 = [float(v) for v in box.tolist()]
                draw.rectangle((x0, y0, x1, y1), outline=colour, width=1)
                draw.text((x0, y0),
                          id2label[label_id] + ":{:.2f}".format(score.item()),
                          fill=colour)
            img.save(f"./t{i}.{randint(0, 1000)}.jpg")

    def _infer(self, images, img_pro, mdl, threshold):
        """Run ``mdl`` over ``images`` in batches and collect detections.

        Args:
            images: Sequence of images exposing ``.size`` as
                ``(width, height)`` (as PIL images do).
            img_pro: Image processor paired with ``mdl``.
            mdl: Object-detection model to run.
            threshold: Minimum score passed to the post-processing step.

        Returns:
            One list of detection dicts per input image, in input order
            (see ``__friendly`` for the dict layout).
        """
        res = []
        id2label = mdl.config.id2label
        for start in range(0, len(images), self.batch_size):
            imgs = images[start: start + self.batch_size]
            inputs = img_pro(imgs, return_tensors="pt")
            # Move tensor inputs to the model's device; leave the rest as is.
            inputs = {k: v.to(mdl.device) if isinstance(v, torch.Tensor) else v
                      for k, v in inputs.items()}
            # ``size`` is (width, height); post-processing wants (height, width).
            target_sizes = torch.tensor([img.size[::-1] for img in imgs])
            # Inference only: keep the forward pass out of autograd as well.
            with torch.no_grad():
                outputs = mdl(**inputs)
                bres = img_pro.post_process_object_detection(
                    outputs,
                    threshold=threshold,
                    target_sizes=target_sizes)
            res.extend(self.__friendly(bres, id2label))
        return res

    def __call__(self, images, threshold=0.8):
        """Recognize table structure in ``images``.

        Args:
            images: Sequence of images (see ``_infer``).
            threshold: Minimum detection score to keep.

        Returns:
            One list of detection dicts per input image.
        """
        return self._infer(images, self.rec_img_pro, self.rec_mdl, threshold)

    def detect(self, images):
        """Run the table-detection model on ``images`` with threshold 0.9.

        Returns:
            One list of detection dicts per input image.

        Raises:
            RuntimeError: If no detection model was loaded (``det_mdlnm``
                was not given to the constructor).
        """
        det_img_pro = getattr(self, "det_img_pro", None)
        det_mdl = getattr(self, "det_mdl", None)
        if det_img_pro is None or det_mdl is None:
            raise RuntimeError(
                "No detection model loaded; construct TableTransformer "
                "with det_mdlnm=<checkpoint name> to use detect().")
        # NOTE(review): results go through __friendly, which drops label id 0;
        # confirm that matches the detection checkpoint's label map.
        return self._infer(images, det_img_pro, det_mdl, 0.9)