Merge remote-tracking branch 'upstream/main'

This commit is contained in:
billchen 2024-02-27 19:02:14 +08:00
commit c4e867e277
20 changed files with 642 additions and 118 deletions

View file

@ -20,7 +20,7 @@ from flask_login import login_required, current_user
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, huqie, retrievaler from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode from api.settings import RetCode, retrievaler
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re

View file

@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger, stat_logger from api.settings import access_logger, stat_logger, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder, rmSpace from rag.utils import num_tokens_from_string, encoder, rmSpace
@ -58,7 +56,7 @@ def set_conversation():
conv = { conv = {
"id": get_uuid(), "id": get_uuid(),
"dialog_id": req["dialog_id"], "dialog_id": req["dialog_id"],
"name": "New conversation", "name": req.get("name", "New conversation"),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
} }
ConversationService.save(**conv) ConversationService.save(**conv)
@ -102,7 +100,7 @@ def rm():
def list_convsersation(): def list_convsersation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
convs = ConversationService.query(dialog_id=dialog_id) convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
convs = [d.to_dict() for d in convs] convs = [d.to_dict() for d in convs]
return get_json_result(data=convs) return get_json_result(data=convs)
except Exception as e: except Exception as e:

View file

@ -208,9 +208,9 @@ def user_register(user_id, user):
for llm in LLMService.query(fid=LLM_FACTORY): for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
if not UserService.save(**user):return if not UserService.insert(**user):return
TenantService.save(**tenant) TenantService.insert(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm) TenantLLMService.insert_many(tenant_llm)
return UserService.query(email=user["email"]) return UserService.query(email=user["email"])

View file

@ -16,10 +16,12 @@
import time import time
import uuid import uuid
from api.db import LLMType from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
def init_superuser(): def init_superuser():
@ -32,7 +34,43 @@ def init_superuser():
"creator": "system", "creator": "system",
"status": "1", "status": "1",
} }
UserService.save(**user_info) tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})
if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
print("【INFO】Super user initialized. user name: admin, password: admin. Changing the password after logining is strongly recomanded.")
# Smoke-test the tenant's default chat model with a trivial prompt so that a
# misconfigured model is reported right after the super user is created.
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
    # Both placeholders must be supplied to format(); passing `msg` to print()
    # instead left the second "{}" unfilled and raised IndexError.
    print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"], msg))
# Same check for the default embedding model: the second value returned by
# encode() being 0 is treated as a failure.
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
v, c = embd_mdl.encode(["Hello!"])
if c == 0:
    print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
def init_llm_factory(): def init_llm_factory():
@ -171,10 +209,10 @@ def init_llm_factory():
def init_web_data(): def init_web_data():
start_time = time.time() start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory() if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()
print("init web data success:{}".format(time.time() - start_time)) print("init web data success:{}".format(time.time() - start_time))

View file

@ -18,7 +18,7 @@ from datetime import datetime
import peewee import peewee
from api.db.db_models import DB from api.db.db_models import DB
from api.utils import datetime_format from api.utils import datetime_format, current_timestamp, get_uuid
class CommonService: class CommonService:
@ -66,27 +66,42 @@ class CommonService:
sample_obj = cls.model(**kwargs).save(force_insert=True) sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj return sample_obj
@classmethod
@DB.connection_context()
def insert(cls, **kwargs):
    """Insert a new row built from *kwargs*, stamping id and audit fields.

    An ``id`` is generated with ``get_uuid()`` when the caller does not
    supply one. ``create_time``/``update_time`` and
    ``create_date``/``update_date`` are always overwritten with the current
    time.

    Returns the result of ``cls.model(**kwargs).save(force_insert=True)``
    (presumably the number of rows written, per peewee -- confirm).
    """
    if "id" not in kwargs:
        kwargs["id"] = get_uuid()
    # Read each clock once so the create_* and update_* stamps of a freshly
    # inserted row are guaranteed to agree.
    timestamp = current_timestamp()
    date_value = datetime_format(datetime.now())
    kwargs["create_time"] = timestamp
    kwargs["create_date"] = date_value
    kwargs["update_time"] = timestamp
    kwargs["update_date"] = date_value
    return cls.model(**kwargs).save(force_insert=True)
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def insert_many(cls, data_list, batch_size=100): def insert_many(cls, data_list, batch_size=100):
with DB.atomic(): with DB.atomic():
for d in data_list: d["create_time"] = datetime_format(datetime.now()) for d in data_list:
d["create_time"] = current_timestamp()
d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size): for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i:i + batch_size]).execute() cls.model.insert_many(data_list[i:i + batch_size]).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_many_by_id(cls, data_list): def update_many_by_id(cls, data_list):
cur = datetime_format(datetime.now())
with DB.atomic(): with DB.atomic():
for data in data_list: for data in data_list:
data["update_time"] = cur data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute() cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_by_id(cls, pid, data): def update_by_id(cls, pid, data):
data["update_time"] = datetime_format(datetime.now()) data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute() num = cls.model.update(data).where(cls.model.id == pid).execute()
return num return num

View file

@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
from rag.nlp import search
from rag.utils import ELASTICSEARCH
# Server
API_VERSION = "v1" API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow" RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py" SERVER_MODULE = "rag_flow_server.py"
@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = [] PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
class CustomEnum(Enum): class CustomEnum(Enum):
@classmethod @classmethod
def valid(cls, value): def valid(cls, value):

View file

@ -185,5 +185,11 @@ def thumbnail(filename, blob):
pass pass
def traversal_files(base):
    """Recursively yield the path of every file found under *base*.

    Paths are built with ``os.path.join(root, name)``, so they are relative
    or absolute exactly as *base* is. A missing *base* yields nothing.
    Entries are visited in sorted order so the result is reproducible across
    filesystems.
    """
    for root, dirs, files in os.walk(base):
        # Sorting in place also fixes the order in which os.walk descends
        # into the sub-directories (top-down traversal).
        dirs.sort()
        for name in sorted(files):
            yield os.path.join(root, name)

View file

@ -17,16 +17,16 @@ database:
name: 'rag_flow' name: 'rag_flow'
user: 'root' user: 'root'
passwd: 'infini_rag_flow' passwd: 'infini_rag_flow'
host: '123.60.95.134' host: '127.0.0.1'
port: 5455 port: 5455
max_connections: 100 max_connections: 100
stale_timeout: 30 stale_timeout: 30
minio: minio:
user: 'rag_flow' user: 'rag_flow'
passwd: 'infini_rag_flow' passwd: 'infini_rag_flow'
host: '123.60.95.134:9000' host: '127.0.0.1:9000'
es: es:
hosts: 'http://123.60.95.134:9200' hosts: 'http://127.0.0.1:9200'
user_default_llm: user_default_llm:
factory: '通义千问' factory: '通义千问'
chat_model: 'qwen-plus' chat_model: 'qwen-plus'

View file

@ -11,7 +11,36 @@ English | [简体中文](./README_zh.md)
With a bunch of documents from various domains with various formats and along with diverse retrieval requirements, With a bunch of documents from various domains with various formats and along with diverse retrieval requirements,
an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose. an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose.
There 2 parts in *Deep*Doc so far: vision and parser. There are 2 parts in *Deep*Doc so far: vision and parser.
You can run the following test programs if you're interested in our results of OCR, layout recognition and TSR.
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './layouts_outputs'
--threshold THRESHOLD
A threshold to filter out detections. Default: 0.5
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
```
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!!
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a> <a name="2"></a>
## 2. Vision ## 2. Vision
@ -19,9 +48,14 @@ There 2 parts in *Deep*Doc so far: vision and parser.
We use vision information to resolve problems as human being. We use vision information to resolve problems as human being.
- OCR. Since a lot of documents presented as images or at least be able to transform to image, - OCR. Since a lot of documents presented as images or at least be able to transform to image,
OCR is a very essential and fundamental or even universal solution for text extraction. OCR is a very essential and fundamental or even universal solution for text extraction.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs can be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which contains images that show the positions of the recognized text,
and txt files which contain the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;"> <div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://lh6.googleusercontent.com/2xdiSjaGWkZ71YdORc71Ujf7jCHmO6G-6ONklzGiUYEh3QZpjPo6MQ9eqEFX20am_cdW4Ck0YRraXEetXWnM08kJd99yhik13Cy0_YKUAq2zVGR15LzkovRAmK9iT4o3hcJ8dTpspaJKUwt6R4gN7So" width="300"/> <img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div> </div>
- Layout recognition. Documents from different domain may have various layouts, - Layout recognition. Documents from different domain may have various layouts,
@ -39,11 +73,18 @@ We use vision information to resolve problems as human being.
- Footer - Footer
- Reference - Reference
- Equation - Equation
Try the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs can be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which contains images that demonstrate the detection results, as follows:
<div align="center" style="margin-top:20px;margin-bottom:20px;"> <div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/ppstructure/docs/layout/layout.png?raw=true" width="900"/> <img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div> </div>
- Table Structure Recognition(TSR). Data table is a frequently used structure present data including numbers or text. - Table Structure Recognition(TSR). Data table is a frequently used structure to present data including numbers or text.
And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers. And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM. Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM.
We have five labels for TSR task: We have five labels for TSR task:
@ -52,8 +93,15 @@ We use vision information to resolve problems as human being.
- Column header - Column header
- Projected row header - Projected row header
- Spanning cell - Spanning cell
Try the following command to see the table structure recognition results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
The inputs can be a directory of images or PDFs, or a single image or PDF file.
You can look into the folder 'path_to_store_result', which contains both images and html pages that demonstrate the detection results, as follows:
<div align="center" style="margin-top:20px;margin-bottom:20px;"> <div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://user-images.githubusercontent.com/10793386/139559159-cd23c972-8731-48ed-91df-f3f27e9f4d79.jpg" width="900"/> <img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div> </div>
<a name="3"></a> <a name="3"></a>
@ -71,4 +119,4 @@ The résumé is a very complicated kind of document. A résumé which is compose
with various layouts could be resolved into structured data composed of nearly a hundred of fields. with various layouts could be resolved into structured data composed of nearly a hundred of fields.
We haven't opened the parser yet, as we open the processing method after parsing procedure. We haven't opened the parser yet, as we open the processing method after parsing procedure.

View file

@ -230,7 +230,7 @@ class HuParser:
b["H_right"] = headers[ii]["x1"] b["H_right"] = headers[ii]["x1"]
b["H"] = ii b["H"] = ii
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3) ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None: if ii is not None:
b["C"] = ii b["C"] = ii
b["C_left"] = clmns[ii]["x0"] b["C_left"] = clmns[ii]["x0"]

View file

@ -1,4 +1,49 @@
from .ocr import OCR from .ocr import OCR
from .recognizer import Recognizer from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer from .table_structure_recognizer import TableStructureRecognizer
def init_in_out(args):
    """Load the input images and derive one output path per image.

    ``args.inputs`` is either a directory, which is walked recursively, or a
    single file. Files whose extension is ``pdf`` are rendered page by page;
    every other file is opened with PIL, and files PIL cannot open are
    skipped after printing the traceback. ``args.output_dir`` is created
    when it does not exist.

    Returns ``(images, outputs)``: a list of PIL images and, index-aligned
    with it, the path under ``args.output_dir`` for each image.
    """
    from PIL import Image
    import fitz
    import os
    import traceback
    from api.utils.file_utils import traversal_files

    images = []
    outputs = []
    # makedirs (unlike mkdir) also creates missing parent directories and
    # tolerates the directory appearing between the check and the call.
    os.makedirs(args.output_dir, exist_ok=True)

    def pdf_pages(fnm, zoomin=3):
        # Render every page at `zoomin` times the base resolution.
        pdf = fitz.open(fnm)
        try:
            mat = fitz.Matrix(zoomin, zoomin)
            for i, page in enumerate(pdf):
                pix = page.get_pixmap(matrix=mat)
                img = Image.frombytes("RGB", (pix.width, pix.height),
                                      pix.samples)
                images.append(img)
                outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
        finally:
            # Release the document handle even if a page fails to render.
            pdf.close()

    def images_and_outputs(fnm):
        if fnm.split(".")[-1].lower() == "pdf":
            pdf_pages(fnm)
            return
        try:
            images.append(Image.open(fnm))
            outputs.append(os.path.split(fnm)[-1])
        except Exception:
            # Best effort: a directory may contain non-image files.
            traceback.print_exc()

    if os.path.isdir(args.inputs):
        for fnm in traversal_files(args.inputs):
            images_and_outputs(fnm)
    else:
        images_and_outputs(args.inputs)

    outputs = [os.path.join(args.output_dir, name) for name in outputs]

    return images, outputs

View file

@ -1,17 +1,26 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os import os
import re import re
from collections import Counter from collections import Counter
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from .recognizer import Recognizer from deepdoc.vision import Recognizer
class LayoutRecognizer(Recognizer): class LayoutRecognizer(Recognizer):
def __init__(self, domain): labels = [
self.layout_labels = [
"_background_", "_background_",
"Text", "Text",
"Title", "Title",
@ -24,10 +33,11 @@ class LayoutRecognizer(Recognizer):
"Reference", "Reference",
"Equation", "Equation",
] ]
super().__init__(self.layout_labels, domain, def __init__(self, domain):
super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
@ -37,7 +47,7 @@ class LayoutRecognizer(Recognizer):
return any([re.search(p, b["text"]) for p in patt]) return any([re.search(p, b["text"]) for p in patt])
layouts = super().__call__(image_list, thr, batch_size) layouts = super().__call__(image_list, thr, batch_size)
# save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7) # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
assert len(image_list) == len(ocr_res) assert len(image_list) == len(ocr_res)
# Tag layout type # Tag layout type
boxes = [] boxes = []
@ -117,3 +127,5 @@ class LayoutRecognizer(Recognizer):
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set] ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout return ocr_res, page_layout

View file

@ -2,7 +2,6 @@ import copy
import numpy as np import numpy as np
import cv2 import cv2
import paddle
from shapely.geometry import Polygon from shapely.geometry import Polygon
import pyclipper import pyclipper
@ -215,7 +214,7 @@ class DBPostProcess(object):
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if not isinstance(pred, np.ndarray):
pred = pred.numpy() pred = pred.numpy()
pred = pred[:, 0, :, :] pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list): if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1] preds = preds[-1]
if isinstance(preds, paddle.Tensor): if not isinstance(preds, np.ndarray):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)

View file

@ -17,7 +17,6 @@ from copy import deepcopy
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from . import seeit
from .operators import * from .operators import *
from rag.settings import cron_logger from rag.settings import cron_logger
@ -36,7 +35,7 @@ class Recognizer(object):
""" """
if not model_dir: if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr") model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx") model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
@ -46,6 +45,9 @@ class Recognizer(object):
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
else: else:
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
self.input_names = [node.name for node in self.ort_sess.get_inputs()]
self.output_names = [node.name for node in self.ort_sess.get_outputs()]
self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
self.label_list = label_list self.label_list = label_list
@staticmethod @staticmethod
@ -257,6 +259,18 @@ class Recognizer(object):
return max_overlaped_i return max_overlaped_i
@staticmethod
def find_horizontally_tightest_fit(box, boxes):
if not boxes:
return
min_dis, min_i = 1000000, None
for i,b in enumerate(boxes):
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis:
min_i = i
min_dis = dis
return min_i
@staticmethod @staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3): def find_overlapped_with_threashold(box, boxes, thr=0.3):
if not boxes: if not boxes:
@ -275,23 +289,131 @@ class Recognizer(object):
return max_overlaped_i return max_overlaped_i
def preprocess(self, image_list): def preprocess(self, image_list):
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
inputs = [] inputs = []
for im_path in image_list: if "scale_factor" in self.input_names:
im, im_info = preprocess(im_path, preprocess_ops) preprocess_ops = []
inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')}) for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'),
"scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
else:
hh, ww = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
# Scale input pixel values to 0 to 1
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
return inputs return inputs
def postprocess(self, boxes, inputs, thr):
    """Turn one image's raw model output into a list of detections.

    Two output layouts are handled, selected by whether the model declares a
    ``scale_factor`` input:

    * with ``scale_factor``: every row of *boxes* is read as
      ``[class_id, score, *bbox]`` and only thresholded;
    * otherwise: after squeeze and transpose, each row is read as
      ``[cx, cy, w, h, *class_scores]``. Rows are thresholded on their best
      class score, rescaled with ``inputs["scale_factor"]``, converted to
      corner format and de-duplicated per class with an IoU filter (0.2).
      (Assumes *boxes* is ``(1, 4 + num_classes, num_candidates)`` -- TODO
      confirm against the exported model.)

    Detections whose class id has no entry in ``self.label_list`` are logged
    and dropped in both layouts.

    Returns a list of ``{"type", "bbox", "score"}`` dicts.
    """
    if "scale_factor" in self.input_names:
        bb = []
        for b in boxes:
            clsid, bbox, score = int(b[0]), b[2:], b[1]
            if score < thr:
                continue
            if clsid >= len(self.label_list):
                cron_logger.warning("bad category id")
                continue
            bb.append({
                "type": self.label_list[clsid].lower(),
                "bbox": [float(t) for t in bbox.tolist()],
                "score": float(score)
            })
        return bb

    def xywh2xyxy(x):
        # [cx, cy, w, h] to [x1, y1, x2, y2]
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y

    def compute_iou(box, boxes):
        # Corners of the intersection rectangle of `box` with every row
        xmin = np.maximum(box[0], boxes[:, 0])
        ymin = np.maximum(box[1], boxes[:, 1])
        xmax = np.minimum(box[2], boxes[:, 2])
        ymax = np.minimum(box[3], boxes[:, 3])

        # Intersection area (zero when the rectangles do not overlap)
        intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

        # Union area
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        union_area = box_area + boxes_area - intersection_area

        return intersection_area / union_area

    def iou_filter(boxes, scores, iou_threshold):
        # Candidates ordered by descending score
        sorted_indices = np.argsort(scores)[::-1]

        keep_boxes = []
        while sorted_indices.size > 0:
            # Keep the highest-scoring remaining box
            box_id = sorted_indices[0]
            keep_boxes.append(box_id)

            # IoU of the kept box with the rest
            ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

            # Drop every remaining box that overlaps it too much; +1 because
            # `ious` is aligned with sorted_indices[1:]
            keep_indices = np.where(ious < iou_threshold)[0]
            sorted_indices = sorted_indices[keep_indices + 1]

        return keep_boxes

    boxes = np.squeeze(boxes).T
    # Filter out candidates whose best class score is not above the threshold
    scores = np.max(boxes[:, 4:], axis=1)
    confident = scores > thr
    boxes = boxes[confident, :]
    scores = scores[confident]
    if len(boxes) == 0:
        return []

    # Class with the highest score for every remaining candidate
    class_ids = np.argmax(boxes[:, 4:], axis=1)
    boxes = boxes[:, :4]
    # Rescale x-like columns (cx, w) and y-like columns (cy, h) with the
    # factors recorded for this image in `inputs`
    sx, sy = inputs["scale_factor"][0], inputs["scale_factor"][1]
    input_shape = np.array([sx, sy, sx, sy])
    boxes = np.multiply(boxes, input_shape, dtype=np.float32)
    boxes = xywh2xyxy(boxes)

    indices = []
    for class_id in np.unique(class_ids):
        # Same guard as the scale_factor branch: a class id without a label
        # would otherwise raise IndexError when building the result.
        if class_id >= len(self.label_list):
            cron_logger.warning("bad category id")
            continue
        class_indices = np.where(class_ids == class_id)[0]
        class_boxes = boxes[class_indices, :]
        class_scores = scores[class_indices]
        class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
        indices.extend(class_indices[class_keep_boxes])

    return [{
        "type": self.label_list[class_ids[i]].lower(),
        "bbox": [float(t) for t in boxes[i].tolist()],
        "score": float(scores[i])
    } for i in indices]
def __call__(self, image_list, thr=0.7, batch_size=16): def __call__(self, image_list, thr=0.7, batch_size=16):
res = [] res = []
imgs = [] imgs = []
@ -306,22 +428,14 @@ class Recognizer(object):
end_index = min((i + 1) * batch_size, len(imgs)) end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index] batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list) inputs = self.preprocess(batch_image_list)
print("preprocess")
for ins in inputs: for ins in inputs:
bb = [] bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names})[0], ins, thr)
for b in self.ort_sess.run(None, ins)[0]:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
cron_logger.warning(f"bad category id")
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
res.append(bb) res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr) #seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res return res

47
deepdoc/vision/t_ocr.py Normal file
View file

@ -0,0 +1,47 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import numpy as np
import argparse
from deepdoc.vision import OCR, init_in_out
from deepdoc.vision.seeit import draw_box
def main(args):
    """Run OCR on every input image and save an annotated image plus its text.

    For each image returned by ``init_in_out(args)`` the OCR result is
    converted into box dicts (``text``/``bbox``/``type``/``score``), drawn
    onto the image with ``draw_box`` and saved to the matching output path.
    The recognized text lines are written next to it as ``<output>.txt``.

    Args:
        args: Parsed command-line arguments, forwarded to ``init_in_out``.
    """
    ocr = OCR()
    images, outputs = init_in_out(args)

    for i, img in enumerate(images):
        # Each OCR entry is indexed as (points, (text, ...)); keep only the
        # points and the text.
        detections = [(line[0], line[1][0]) for line in ocr(np.array(img))]
        # Build an axis-aligned box from the first, second and last points and
        # drop degenerate detections whose corners are inverted.
        bxs = [{
            "text": t,
            "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
            "type": "ocr",
            "score": 1} for b, t in detections if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
        annotated = draw_box(images[i], bxs, ["ocr"], 1.)
        annotated.save(outputs[i], quality=95)
        # Explicit UTF-8: the recognized text may contain non-ASCII characters
        # that the platform default codec cannot encode.
        with open(outputs[i] + ".txt", "w", encoding="utf-8") as f:
            f.write("\n".join(o["text"] for o in bxs))
if __name__ == "__main__":
    # Command-line entry point: collect the input location and the output
    # directory, then hand the parsed arguments to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--inputs',
        required=True,
        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
    )
    cli.add_argument(
        '--output_dir',
        default="./ocr_outputs",
        help="Directory where to store the output images. Default: './ocr_outputs'",
    )
    main(cli.parse_args())

View file

@ -0,0 +1,175 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
import re
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import argparse
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from deepdoc.vision.seeit import draw_box
def main(args):
    """Run layout or table-structure recognition and save annotated results.

    In ``layout`` mode a ``Recognizer`` is built with the ``LayoutRecognizer``
    labels; in ``tsr`` mode a ``TableStructureRecognizer`` and an ``OCR``
    instance are built. Every image is then annotated with ``draw_box`` and
    saved to its output path. In ``tsr`` mode an ``<output>.html`` rendering of
    the table is written as well.

    Args:
        args: Parsed command-line arguments; ``mode`` and ``threshold`` are
            read here, and the whole object is forwarded to ``init_in_out``.

    Raises:
        ValueError: If ``args.mode`` is neither ``layout`` nor ``tsr``.
    """
    images, outputs = init_in_out(args)
    # Normalize once instead of re-parsing on every loop iteration.
    mode = args.mode.lower()
    threshold = float(args.threshold)

    ocr = None
    if mode == "layout":
        labels = LayoutRecognizer.labels
        detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
    elif mode == "tsr":
        labels = TableStructureRecognizer.labels
        detr = TableStructureRecognizer()
        ocr = OCR()
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError("Unsupported mode: " + str(args.mode))

    layouts = detr(images, threshold)
    for i, lyt in enumerate(layouts):
        if mode == "tsr":
            html = get_table_html(images[i], lyt, ocr)
            # Explicit UTF-8: cell text comes from OCR and may be non-ASCII.
            with open(outputs[i] + ".html", "w", encoding="utf-8") as f:
                f.write(html)
            # Table components carry label/x0/top/x1/bottom keys; convert them
            # to the type/bbox/score form that is passed to draw_box.
            lyt = [{
                "type": t["label"],
                "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
                "score": t["score"]
            } for t in lyt]
        img = draw_box(images[i], lyt, labels, threshold)
        img.save(outputs[i], quality=95)
        print("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
    """Render the table detected in ``img`` as a standalone HTML page.

    OCR text boxes are tagged with the table row ("R"), header ("H"),
    column ("C") and spanning cell ("SP") they fall into, then passed to
    ``TableStructureRecognizer.construct_table`` to produce the ``<table>``
    markup embedded in the page template.

    Args:
        img: Image of the table; converted with ``np.array`` before OCR.
        tb_cpns: Table components, dicts with at least ``label``, ``x0``,
            ``x1``, ``top`` and ``bottom`` keys.
        ocr: Callable returning ``(points, (text, ...))`` entries for an image.

    Returns:
        The HTML document as a string. The body is empty when no table
        markup could be constructed.
    """
    ocr_results = ocr(np.array(img))
    # A third of the mean text height is handed to sort_Y_firstly as its
    # threshold. Guard the empty case so np.mean never sees an empty list
    # (which yields nan and a RuntimeWarning).
    heights = [b[-1][1] - b[0][1] for b, _ in ocr_results]
    y_threshold = np.mean(heights) / 3 if heights else 0
    boxes = Recognizer.sort_Y_firstly(
        [{"x0": b[0][0], "x1": b[1][0],
          "top": b[0][1], "text": t[0],
          "bottom": b[-1][1],
          "layout_type": "table",
          "page_number": 0} for b, t in ocr_results if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
        y_threshold
    )

    def gather(kwd, fzy=10, ption=0.6):
        # Select components whose label matches ``kwd``, clean them up against
        # the OCR boxes and return them sorted by sort_Y_firstly.
        eles = Recognizer.sort_Y_firstly(
            [r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
        eles = Recognizer.layouts_cleanup(boxes, eles, 5, ption)
        return Recognizer.sort_Y_firstly(eles, 0)

    headers = gather(r".*header$")
    rows = gather(r".* (row|header)")
    spans = gather(r".*spanning")
    clmns = sorted([r for r in tb_cpns if re.match(
        r"table column$", r["label"])], key=lambda x: x["x0"])
    clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)

    # Annotate every text box with the structural element it belongs to.
    for b in boxes:
        ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
        if ii is not None:
            b["R"] = ii
            b["R_top"] = rows[ii]["top"]
            b["R_bott"] = rows[ii]["bottom"]

        ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
        if ii is not None:
            b["H_top"] = headers[ii]["top"]
            b["H_bott"] = headers[ii]["bottom"]
            b["H_left"] = headers[ii]["x0"]
            b["H_right"] = headers[ii]["x1"]
            b["H"] = ii

        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
        if ii is not None:
            b["C"] = ii
            b["C_left"] = clmns[ii]["x0"]
            b["C_right"] = clmns[ii]["x1"]

        # A spanning cell overrides the header extents recorded above.
        ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
        if ii is not None:
            b["H_top"] = spans[ii]["top"]
            b["H_bott"] = spans[ii]["bottom"]
            b["H_left"] = spans[ii]["x0"]
            b["H_right"] = spans[ii]["x1"]
            b["SP"] = ii

    # construct_table returns an empty list when there are no boxes left;
    # render that as an empty body instead of the literal text "[]".
    table_html = TableStructureRecognizer.construct_table(boxes, html=True) or ""

    html = """
<html>
<head>
<style>
._table_1nkzy_11 {
margin: auto;
width: 70%%;
padding: 10px;
}
._table_1nkzy_11 p {
margin-bottom: 50px;
border: 1px solid #e1e1e1;
}
caption {
color: #6ac1ca;
font-size: 20px;
height: 50px;
line-height: 50px;
font-weight: 600;
margin-bottom: 10px;
}
._table_1nkzy_11 table {
width: 100%%;
border-collapse: collapse;
}
th {
color: #fff;
background-color: #6ac1ca;
}
td:hover {
background: #c1e8e8;
}
tr:nth-child(even) {
background-color: #f2f2f2;
}
._table_1nkzy_11 th,
._table_1nkzy_11 td {
text-align: center;
border: 1px solid #ddd;
padding: 8px;
}
</style>
</head>
<body>
%s
</body>
</html>
""" % table_html
    return html
if __name__ == "__main__":
    # Command-line entry point for layout / table-structure recognition.
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs',
                        help="Directory where to store images or PDFs, or a file path to a single image or PDF",
                        required=True)
    parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
                        default="./layouts_outputs")
    # type=float makes argparse reject a malformed threshold with a usage
    # error instead of failing later inside main() with a traceback.
    parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5,
                        type=float)
    parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
                        default="layout")
    args = parser.parse_args()
    main(args)

View file

@ -1,8 +1,19 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging import logging
import os import os
import re import re
from collections import Counter from collections import Counter
from copy import deepcopy
import numpy as np import numpy as np
@ -12,19 +23,20 @@ from .recognizer import Recognizer
class TableStructureRecognizer(Recognizer): class TableStructureRecognizer(Recognizer):
labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
def __init__(self): def __init__(self):
self.labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
super().__init__(self.labels, "tsr", super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, images, thr=0.5): def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)
res = [] res = []
# align left&right for rows, align top&bottom for columns # align left&right for rows, align top&bottom for columns
@ -43,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
"row") > 0 or b["label"].find("header") > 0] "row") > 0 or b["label"].find("header") > 0]
if not left: if not left:
continue continue
left = np.median(left) if len(left) > 4 else np.min(left) left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.median(right) if len(right) > 4 else np.max(right) right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts: for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0: if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left: if b["x0"] > left:
@ -79,7 +91,8 @@ class TableStructureRecognizer(Recognizer):
return True return True
return False return False
def __blockType(self, b): @staticmethod
def blockType(b):
patt = [ patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"), (r"^(20|19)[0-9]{2}年$", "Dt"),
@ -109,11 +122,13 @@ class TableStructureRecognizer(Recognizer):
return "Ot" return "Ot"
def construct_table(self, boxes, is_english=False, html=False): @staticmethod
def construct_table(boxes, is_english=False, html=False):
cap = "" cap = ""
i = 0 i = 0
while i < len(boxes): while i < len(boxes):
if self.is_caption(boxes[i]): if TableStructureRecognizer.is_caption(boxes[i]):
if is_english: cap + " "
cap += boxes[i]["text"] cap += boxes[i]["text"]
boxes.pop(i) boxes.pop(i)
i -= 1 i -= 1
@ -122,14 +137,15 @@ class TableStructureRecognizer(Recognizer):
if not boxes: if not boxes:
return [] return []
for b in boxes: for b in boxes:
b["btype"] = self.__blockType(b) b["btype"] = TableStructureRecognizer.blockType(b)
max_type = Counter([b["btype"] for b in boxes]).items() max_type = Counter([b["btype"] for b in boxes]).items()
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else "" max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
logging.debug("MAXTYPE: " + max_type) logging.debug("MAXTYPE: " + max_type)
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b] rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
rowh = np.min(rowh) if rowh else 0 rowh = np.min(rowh) if rowh else 0
boxes = self.sort_R_firstly(boxes, rowh / 2) boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
#for b in boxes:print(b)
boxes[0]["rn"] = 0 boxes[0]["rn"] = 0
rows = [[boxes[0]]] rows = [[boxes[0]]]
btm = boxes[0]["bottom"] btm = boxes[0]["bottom"]
@ -150,9 +166,9 @@ class TableStructureRecognizer(Recognizer):
colwm = np.min(colwm) if colwm else 0 colwm = np.min(colwm) if colwm else 0
crosspage = len(set([b["page_number"] for b in boxes])) > 1 crosspage = len(set([b["page_number"] for b in boxes])) > 1
if crosspage: if crosspage:
boxes = self.sort_X_firstly(boxes, colwm / 2, False) boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
else: else:
boxes = self.sort_C_firstly(boxes, colwm / 2) boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
boxes[0]["cn"] = 0 boxes[0]["cn"] = 0
cols = [[boxes[0]]] cols = [[boxes[0]]]
right = boxes[0]["x1"] right = boxes[0]["x1"]
@ -313,16 +329,18 @@ class TableStructureRecognizer(Recognizer):
hdset.add(i) hdset.add(i)
if html: if html:
return [self.__html_table(cap, hdset, return TableStructureRecognizer.__html_table(cap, hdset,
self.__cal_spans(boxes, rows, TableStructureRecognizer.__cal_spans(boxes, rows,
cols, tbl, True) cols, tbl, True)
)] )
return self.__desc_table(cap, hdset, return TableStructureRecognizer.__desc_table(cap, hdset,
self.__cal_spans(boxes, rows, cols, tbl, False), TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
is_english) False),
is_english)
def __html_table(self, cap, hdset, tbl): @staticmethod
def __html_table(cap, hdset, tbl):
# construct HTML # construct HTML
html = "<table>" html = "<table>"
if cap: if cap:
@ -339,8 +357,8 @@ class TableStructureRecognizer(Recognizer):
txt = "" txt = ""
if arr: if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
txt = "".join([c["text"] txt = " ".join([c["text"]
for c in self.sort_Y_firstly(arr, h)]) for c in Recognizer.sort_Y_firstly(arr, h)])
txts.append(txt) txts.append(txt)
sp = "" sp = ""
if arr[0].get("colspan"): if arr[0].get("colspan"):
@ -366,7 +384,8 @@ class TableStructureRecognizer(Recognizer):
html += "\n</table>" html += "\n</table>"
return html return html
def __desc_table(self, cap, hdr_rowno, tbl, is_english): @staticmethod
def __desc_table(cap, hdr_rowno, tbl, is_english):
# get text of every column in header row to become header text # get text of every column in header row to become header text
clmno = len(tbl[0]) clmno = len(tbl[0])
rowno = len(tbl) rowno = len(tbl)
@ -379,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
for i in range(clmno): for i in range(clmno):
if not tbl[r][i]: if not tbl[r][i]:
continue continue
txt = "".join([a["text"].strip() for a in tbl[r][i]]) txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt headers[r][i] = txt
hdrset.add(txt) hdrset.add(txt)
if all([not t for t in headers[r]]): if all([not t for t in headers[r]]):
@ -469,7 +488,8 @@ class TableStructureRecognizer(Recognizer):
row_txt = [t + f"\t——{from_}{cap}" for t in row_txt] row_txt = [t + f"\t——{from_}{cap}" for t in row_txt]
return row_txt return row_txt
def __cal_spans(self, boxes, rows, cols, tbl, html=True): @staticmethod
def __cal_spans(boxes, rows, cols, tbl, html=True):
# calculate span # calculate span
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
for cln in cols] for cln in cols]
@ -553,4 +573,3 @@ class TableStructureRecognizer(Recognizer):
tbl[rowspan[0]][colspan[0]] = arr tbl[rowspan[0]][colspan[0]] = arr
return tbl return tbl

View file

@ -15,7 +15,7 @@
# #
from abc import ABC from abc import ABC
from openai import OpenAI from openai import OpenAI
import os import openai
class Base(ABC): class Base(ABC):
@ -33,11 +33,14 @@ class GptTurbo(Base):
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system: history.insert(0, {"role": "system", "content": system})
res = self.client.chat.completions.create( try:
model=self.model_name, res = self.client.chat.completions.create(
messages=history, model=self.model_name,
**gen_conf) messages=history,
return res.choices[0].message.content.strip(), res.usage.completion_tokens **gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens
except openai.APIError as e:
return "ERROR: "+str(e), 0
from dashscope import Generation from dashscope import Generation
@ -58,7 +61,7 @@ class QWenChat(Base):
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0 return "ERROR: " + response.message, 0
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
@ -77,4 +80,4 @@ class ZhipuChat(Base):
) )
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0 return "ERROR: " + response.message, 0

View file

@ -1,7 +1,4 @@
from . import search
from rag.utils import ELASTICSEARCH
retrievaler = search.Dealer(ELASTICSEARCH)
from nltk.stem import PorterStemmer from nltk.stem import PorterStemmer
stemmer = PorterStemmer() stemmer = PorterStemmer()
@ -39,10 +36,12 @@ BULLET_PATTERN = [[
] ]
] ]
def random_choices(arr, k): def random_choices(arr, k):
k = min(len(arr), k) k = min(len(arr), k)
return random.choices(arr, k=k) return random.choices(arr, k=k)
def bullets_category(sections): def bullets_category(sections):
global BULLET_PATTERN global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN) hits = [0] * len(BULLET_PATTERN)

View file

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import re import re
from elasticsearch_dsl import Q, Search, A from elasticsearch_dsl import Q, Search
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass
@ -183,6 +183,7 @@ class Dealer:
def insert_citations(self, answer, chunks, chunk_v, def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7): embd_mdl, tkweight=0.3, vtweight=0.7):
assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer) pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)): for i in range(1, len(pieces)):
if re.match(r"[a-z][.?;!][ \n]", pieces[i]): if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@ -216,7 +217,7 @@ class Dealer:
if mx < 0.55: if mx < 0.55:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
res = "" res = ""
for i, p in enumerate(pieces): for i, p in enumerate(pieces):
@ -225,6 +226,7 @@ class Dealer:
continue continue
if i not in cites: if i not in cites:
continue continue
for c in cites[i]: assert int(c) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i]) res += "##%s$$" % "$".join(cites[i])
return res return res