mirror of
https://github.com/docling-project/docling-ibm-models.git
synced 2026-05-17 13:10:52 +00:00
+60
@@ -0,0 +1,60 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Tmp files and directories
|
||||
stderr.*
|
||||
stdout.*
|
||||
*.tar
|
||||
test.sh
|
||||
OutputDecoder
|
||||
jobs.txt
|
||||
_std*.*
|
||||
tests/tmp/*
|
||||
runs/*
|
||||
*.onnx
|
||||
.DS_Store
|
||||
viz/
|
||||
|
||||
# VIM
|
||||
*.swp
|
||||
*.swo
|
||||
*.bak
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
_venv/
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
venv
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# checkpoint file for testing
|
||||
tests/test_data/model_artifacts/*.check
|
||||
tests/test_data/model_artifacts/*.json
|
||||
tests/test_data/model_artifacts/*.pt
|
||||
|
||||
# test results
|
||||
tests/test_data/viz/
|
||||
@@ -0,0 +1,43 @@
|
||||
fail_fast: true
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: system
|
||||
name: Black
|
||||
entry: poetry run black docling_ibm_models
|
||||
pass_filenames: false
|
||||
language: system
|
||||
files: '\.py$'
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: system
|
||||
name: isort
|
||||
entry: poetry run isort docling_ibm_models
|
||||
pass_filenames: false
|
||||
language: system
|
||||
files: '\.py$'
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: system
|
||||
name: Poetry check
|
||||
entry: poetry lock --check
|
||||
pass_filenames: false
|
||||
language: system
|
||||
|
||||
# Ready to be enabled soon
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: system
|
||||
# name: flake8
|
||||
# entry: poetry run flake8 docling_ibm_models
|
||||
# pass_filenames: false
|
||||
# language: system
|
||||
# files: '\.py$'
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: system
|
||||
# name: MyPy
|
||||
# entry: poetry run mypy docling_ibm_models
|
||||
# pass_filenames: false
|
||||
# language: system
|
||||
# files: '\.py$'
|
||||
@@ -0,0 +1,129 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement using
|
||||
[deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com).
|
||||
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html).
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
Homepage: [https://www.contributor-covenant.org](https://www.contributor-covenant.org)
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
[https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at
|
||||
[https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations).
|
||||
+163
@@ -0,0 +1,163 @@
|
||||
## Contributing In General
|
||||
Our project welcomes external contributions. If you have an itch, please feel
|
||||
free to scratch it.
|
||||
|
||||
To contribute code or documentation, please submit a [pull request](https://github.com/DS4SD/docling/pulls).
|
||||
|
||||
A good way to familiarize yourself with the codebase and contribution process is
|
||||
to look for and tackle low-hanging fruit in the [issue tracker](https://github.com/DS4SD/docling/issues).
|
||||
Before embarking on a more ambitious contribution, please quickly [get in touch](#communication) with us.
|
||||
|
||||
For general questions or support requests, please refer to the [discussion section](https://github.com/DS4SD/docling/discussions).
|
||||
|
||||
**Note: We appreciate your effort, and want to avoid a situation where a contribution
|
||||
requires extensive rework (by you or by us), sits in backlog for a long time, or
|
||||
cannot be accepted at all!**
|
||||
|
||||
### Proposing new features
|
||||
|
||||
If you would like to implement a new feature, please [raise an issue](https://github.com/DS4SD/docling/issues)
|
||||
before sending a pull request so the feature can be discussed. This is to avoid
|
||||
you wasting your valuable time working on a feature that the project developers
|
||||
are not interested in accepting into the code base.
|
||||
|
||||
### Fixing bugs
|
||||
|
||||
If you would like to fix a bug, please [raise an issue](https://github.com/DS4SD/docling/issues) before sending a
|
||||
pull request so it can be tracked.
|
||||
|
||||
### Merge approval
|
||||
|
||||
The project maintainers use LGTM (Looks Good To Me) in comments on the code
|
||||
review to indicate acceptance. A change requires LGTMs from two of the
|
||||
maintainers of each component affected.
|
||||
|
||||
For a list of the maintainers, see the [MAINTAINERS.md](MAINTAINERS.md) page.
|
||||
|
||||
|
||||
## Legal
|
||||
|
||||
Each source file must include a license header for the MIT
|
||||
Software. Using the SPDX format is the simplest approach.
|
||||
e.g.
|
||||
|
||||
```
|
||||
/*
|
||||
Copyright IBM Inc. All rights reserved.
|
||||
|
||||
SPDX-License-Identifier: MIT
|
||||
*/
|
||||
```
|
||||
|
||||
We have tried to make it as easy as possible to make contributions. This
|
||||
applies to how we handle the legal aspects of contribution. We use the
|
||||
same approach - the [Developer's Certificate of Origin 1.1 (DCO)](https://github.com/hyperledger/fabric/blob/master/docs/source/DCO1.1.txt) - that the Linux® Kernel [community](https://elinux.org/Developer_Certificate_Of_Origin)
|
||||
uses to manage code contributions.
|
||||
|
||||
We simply ask that when submitting a patch for review, the developer
|
||||
must include a sign-off statement in the commit message.
|
||||
|
||||
Here is an example Signed-off-by line, which indicates that the
|
||||
submitter accepts the DCO:
|
||||
|
||||
```
|
||||
Signed-off-by: John Doe <john.doe@example.com>
|
||||
```
|
||||
|
||||
You can include this automatically when you commit a change to your
|
||||
local git repository using the following command:
|
||||
|
||||
```
|
||||
git commit -s
|
||||
```
|
||||
|
||||
|
||||
## Communication
|
||||
|
||||
Please feel free to connect with us using the [discussion section](https://github.com/DS4SD/docling/discussions).
|
||||
|
||||
|
||||
|
||||
## Developing
|
||||
|
||||
### Usage of Poetry
|
||||
|
||||
We use Poetry to manage dependencies.
|
||||
|
||||
|
||||
#### Install
|
||||
|
||||
To install, see the documentation here: https://python-poetry.org/docs/master/#installing-with-the-official-installer
|
||||
|
||||
1. Install the Poetry globally in your machine
|
||||
```bash
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
```
|
||||
The installation script will print the installation bin folder `POETRY_BIN` which you need in the next steps.
|
||||
|
||||
2. Make sure Poetry is in your `$PATH`
|
||||
- for `zsh`
|
||||
```sh
|
||||
echo 'export PATH="POETRY_BIN:$PATH"' >> ~/.zshrc
|
||||
```
|
||||
- for `bash`
|
||||
```sh
|
||||
echo 'export PATH="POETRY_BIN:$PATH"' >> ~/.bashrc
|
||||
```
|
||||
|
||||
3. The official guidelines linked above include useful details on the configuration of autocomplete for most shell environments, e.g. Bash and Zsh.
|
||||
|
||||
|
||||
#### Create a Virtual Environment and Install Dependencies
|
||||
|
||||
To activate the Virtual Environment, run:
|
||||
|
||||
```bash
|
||||
poetry shell
|
||||
```
|
||||
|
||||
To spawn a shell with the Virtual Environment activated. If the Virtual Environment doesn't exist, Poetry will create one for you. Then, to install dependencies, run:
|
||||
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
**(Advanced) Use a Specific Python Version**
|
||||
|
||||
If for whatever reason you need to work in a specific (older) version of Python, run:
|
||||
|
||||
```bash
|
||||
poetry env use $(which python3.11)
|
||||
```
|
||||
|
||||
This creates a Virtual Environment with Python 3.11. For other versions, replace `$(which python3.11)` by the path to the interpreter (e.g., `/usr/bin/python3.11`) or use `$(which pythonX.Y)`.
|
||||
|
||||
|
||||
#### Add a new dependency
|
||||
|
||||
```bash
|
||||
poetry add NAME
|
||||
```
|
||||
|
||||
## Coding style guidelines
|
||||
|
||||
We use the following tools to enforce code style:
|
||||
|
||||
- iSort, to sort imports
|
||||
- Black, to format code
|
||||
|
||||
|
||||
We run a series of checks on the code base on every commit, using `pre-commit`. To install the hooks, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
To run the checks on-demand, run:
|
||||
|
||||
```
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Note: Checks like `Black` and `isort` will "fail" if they modify files. This is because `pre-commit` doesn't like to see files modified by their Hooks. In these cases, `git add` the modified files and `git commit` again.
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) [year] [fullname]
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,10 @@
|
||||
# MAINTAINERS
|
||||
|
||||
- Christoph Auer - [@cau-git](https://github.com/cau-git)
|
||||
- Michele Dolfi - [@dolfim-ibm](https://github.com/dolfim-ibm)
|
||||
- Maxim Lysak - [@maxmnemonic](https://github.com/maxmnemonic)
|
||||
- Nikos Livathinos - [@nikos-livathinos](https://github.com/nikos-livathinos)
|
||||
- Ahmed Nassar - [@nassarofficial](https://github.com/nassarofficial)
|
||||
- Peter Staar - [@PeterStaar-IBM](https://github.com/PeterStaar-IBM)
|
||||
|
||||
Maintainers can be contacted at [deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com).
|
||||
@@ -0,0 +1,127 @@
|
||||
# Docling-models
|
||||
|
||||
AI modules to support the Dockling PDF document conversion project.
|
||||
|
||||
- TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
|
||||
- Layout model is an AI model that provides among other things ability to detect tables on the page. This package contains inference code for Layout model.
|
||||
|
||||
|
||||
## Installation Instructions
|
||||
|
||||
### MacOS / Linux
|
||||
|
||||
To install `poetry` locally, use either `pip` or `homebrew`.
|
||||
|
||||
To install `poetry` on a docker container, do the following:
|
||||
```
|
||||
ENV POETRY_NO_INTERACTION=1 \
|
||||
POETRY_VIRTUALENVS_CREATE=false
|
||||
|
||||
# Install poetry
|
||||
RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
|
||||
&& python install-poetry.py \
|
||||
&& poetry --version \
|
||||
&& rm install-poetry.py
|
||||
```
|
||||
|
||||
To install and run the package, simply set up a poetry environment
|
||||
|
||||
```
|
||||
poetry env use $(which python3.11)
|
||||
poetry shell
|
||||
```
|
||||
|
||||
and install all the dependencies,
|
||||
|
||||
```
|
||||
poetry install # this will only install the deps from the poetry.lock
|
||||
|
||||
poetry install --no-dev # this will skip installing dev dependencies
|
||||
```
|
||||
|
||||
To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`
|
||||
```
|
||||
poetry update
|
||||
```
|
||||
|
||||
|
||||
## Pipeline Overview
|
||||

|
||||
|
||||
## Datasets
|
||||
Below we list datasets used with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the version of the original format to work with the dataloader out of the box, and to augment the dataset when necassary to add missing groundtruth (bounding boxes for empty cells).
|
||||
|
||||
|
||||
| Name | Description | URL |
|
||||
| ------------- |:-------------:|----|
|
||||
| PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables in the PubMed Central Open Access Subset | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
|
||||
| FinTabNet| A dataset for Financial Report Tables with corresponding ground truth location and structure. 112k+ tables included.| [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
|
||||
| TableBank| TableBank is a new image-based table detection and recognition dataset built with novel weak supervision from Word and Latex documents on the internet, contains 417K high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
|
||||
|
||||
## Models
|
||||
|
||||
### TableModel04:
|
||||

|
||||
**TableModel04rs (OTSL)** is our SOTA method that using transformers in order to predict table structure and bounding box.
|
||||
|
||||
|
||||
## Configuration file
|
||||
|
||||
Example configuration can be seen inside test `tests/test_tf_predictor.py`
|
||||
These are the main sections of the configuration file:
|
||||
|
||||
- `dataset`: The directory for prepared data and the parameters used during the data loading.
|
||||
- `model`: The type, name and hyperparameters of the model. Also the directory to save/load the
|
||||
trained checkpoint files.
|
||||
- `train`: Parameters for the training of the model.
|
||||
- `predict`: Parameters for the evaluation of the model.
|
||||
- `dataset_wordmap`: Very important part that contains token maps.
|
||||
|
||||
|
||||
## Model weights
|
||||
|
||||
You can download the model weights and config files from the links:
|
||||
|
||||
- [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
|
||||
- [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
|
||||
|
||||
Place the downloaded files into `tests/test_data/model_artifacts/` directory.
|
||||
|
||||
|
||||
## Inference Tests
|
||||
|
||||
This contains unit tests for Docling models.
|
||||
|
||||
First download the model weights (see above), then run:
|
||||
```
|
||||
./devtools/check_code.sh
|
||||
```
|
||||
|
||||
This will also generate prediction and matching visualizations that can be found here:
|
||||
`tests\test_data\viz\`
|
||||
|
||||
Visualization outlines:
|
||||
- `Light Pink`: border of recognized table
|
||||
- `Grey`: OCR cells
|
||||
- `Green`: prediction bboxes
|
||||
- `Red`: OCR cells matched with prediction
|
||||
- `Blue`: Post processed, match
|
||||
- `Bold Blue`: column header
|
||||
- `Bold Magenta`: row header
|
||||
- `Bold Brown`: section row (if table have one)
|
||||
|
||||
|
||||
## Demo
|
||||
|
||||
A demo application allows to apply the `LayoutPredictor` on a directory `<input_dir>` that contains
|
||||
`png` images and visualize the predictions inside another directory `<viz_dir>`.
|
||||
|
||||
First download the model weights (see above), then run:
|
||||
```
|
||||
python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
|
||||
```
|
||||
|
||||
e.g.
|
||||
```
|
||||
python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
|
||||
```
|
||||
@@ -0,0 +1,126 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
|
||||
ARTIFACT_PATH = "tests/test_data/model_artifacts"
|
||||
|
||||
|
||||
def demo(
|
||||
logger: logging.Logger,
|
||||
artifact_path: str,
|
||||
num_threads: int,
|
||||
img_dir: str,
|
||||
viz_dir: str,
|
||||
):
|
||||
r"""
|
||||
Apply LayoutPredictor on the input image directory
|
||||
|
||||
If you want to load from PDF:
|
||||
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
|
||||
"""
|
||||
# Create the layout predictor
|
||||
lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
|
||||
logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
|
||||
|
||||
# Predict all test png images
|
||||
for img_fn in Path(img_dir).rglob("*.png"):
|
||||
logger.info("Predicting '%s'...", img_fn)
|
||||
start_t = time.time()
|
||||
|
||||
with Image.open(img_fn) as image:
|
||||
# Predict layout
|
||||
preds = list(lpredictor.predict(image))
|
||||
dt_ms = 1000 * (time.time() - start_t)
|
||||
logger.debug("Time elapsed for prediction(ms): %s", dt_ms)
|
||||
|
||||
# Draw predictions
|
||||
out_img = image.copy()
|
||||
draw = ImageDraw.Draw(out_img)
|
||||
|
||||
for i, pred in enumerate(preds):
|
||||
scr = pred["confidence"]
|
||||
lab = pred["label"]
|
||||
box = [
|
||||
round(pred["l"]),
|
||||
round(pred["t"]),
|
||||
round(pred["r"]),
|
||||
round(pred["b"]),
|
||||
]
|
||||
|
||||
if lab == "Table":
|
||||
draw.rectangle(
|
||||
box,
|
||||
outline="red",
|
||||
)
|
||||
draw.text(
|
||||
(box[0], box[1]),
|
||||
text=str(lab),
|
||||
fill="blue",
|
||||
)
|
||||
logger.info("Table %s: bbox=%s", i, box)
|
||||
|
||||
save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
|
||||
out_img.save(save_fn)
|
||||
logger.info("Saving prediction visualization in: '%s'", save_fn)
|
||||
|
||||
|
||||
def main(args):
|
||||
r""" """
|
||||
num_threads = int(args.num_threads) if args.num_threads is not None else None
|
||||
img_dir = args.img_dir
|
||||
viz_dir = args.viz_dir
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger("LayoutPredictor")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
if not logger.hasHandlers():
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Ensure the viz dir
|
||||
Path(viz_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Test the LayoutPredictor
|
||||
demo(logger, ARTIFACT_PATH, num_threads, img_dir, viz_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r"""
|
||||
python -m demo.demo_layout_predictor -i <images_dir>
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
|
||||
parser.add_argument(
|
||||
"-n", "--num_threads", required=False, default=None, help="Number of threads"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--img_dir",
|
||||
required=True,
|
||||
help="PNG images input directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--viz_dir",
|
||||
required=False,
|
||||
default="viz/",
|
||||
help="Directory to save prediction visualizations",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Executable
+96
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Disabled pylint messages
|
||||
# C0114, C0116: Missing module docstring (missing-module-docstring)
|
||||
# C0209: Formatting a regular string which could be a f-string (consider-using-f-string)
|
||||
# C0103: Variable name doesn't conform to snake_case naming style (invalid-name)
|
||||
# R0801: Similar lines in %s files %s
|
||||
# W0621: Redefining name from outer scope
|
||||
# W1514: Using open without explicitly specifying an encoding (unspecified-encoding)
|
||||
# R0912: Too many branches (too-many-branches)
|
||||
# R0913: Too many arguments
|
||||
# R0914: Too many local variables (too-many-locals)
|
||||
# R0915: Too many statements (too-many-statements)
|
||||
# R1702: Too many nested blocks (too-many-nested-blocks)
|
||||
# R0902: Too many instance attributes
|
||||
# R0903: Too few public methods
|
||||
# W0221: Arguments differ
|
||||
# C0415: Import outside toplevel
|
||||
# C0302: Too many lines in module
|
||||
# W0718: Catching too general exception Exception
|
||||
# R0902: Too many instance attributes
|
||||
# R1702: Too many nested blocks
|
||||
PYLINT_DISABLED="C0114,C0116,C0209,C0103,R0801,W0621,W1514,R0912,R0913,R0914,R0915,R1702"
|
||||
PYLINT_DISABLED+=",R0902,R0903,W0221,C0415,C0302,R0401,W0718,R0902,R1702"
|
||||
|
||||
readonly MAX_LINE_LENGTH=100
|
||||
readonly INDENT_SPACES=4
|
||||
|
||||
##########################################################################################
|
||||
# Functions
|
||||
#
|
||||
|
||||
Usage() {
|
||||
echo "Check codebase with "
|
||||
echo "Usage:"
|
||||
echo "$0 [-c] [-h]"
|
||||
echo
|
||||
echo "-c: Clear cache before invoking PyTest"
|
||||
echo "-h: Print this help message"
|
||||
echo
|
||||
echo "$0"
|
||||
}
|
||||
|
||||
|
||||
##########################################################################################
|
||||
# Main
|
||||
#
|
||||
clear_cache=0
|
||||
while getopts ":hc" option; do
|
||||
case "${option}" in
|
||||
c ) clear_cache=1;;
|
||||
h ) Usage; exit;;
|
||||
\? ) Usage; exit;;
|
||||
: ) # Missing required argument
|
||||
Usage; exit;;
|
||||
esac
|
||||
done
|
||||
|
||||
# PEP8
|
||||
echo "Flake8 check:"
|
||||
flake8 \
|
||||
--max-line-length=${MAX_LINE_LENGTH} \
|
||||
--indent-size=${INDENT_SPACES} \
|
||||
--ignore=E121,E123,E126,E226,E24,E704,W503,W504,W605,E203 \
|
||||
--extend-exclude '_*' \
|
||||
docling_ibm_models/ tests/
|
||||
echo "Flake8 - OK"
|
||||
echo
|
||||
|
||||
# # Pylint
|
||||
# echo "Pylint check:"
|
||||
# indent_string=$(printf '%*s' ${INDENT_SPACES} "" | tr ' ' 'n' | tr 'n' ' ')
|
||||
# # echo "indent_string: '${indent_string}'"
|
||||
# pylint \
|
||||
# --max-line-length ${MAX_LINE_LENGTH} \
|
||||
# --indent-string "${indent_string}" \
|
||||
# --disable ${PYLINT_DISABLED} \
|
||||
# --extension-pkg-whitelist='pydantic' \
|
||||
# --ignore-patterns '[!_]' \
|
||||
# docling_ibm_models/ tests/
|
||||
#
|
||||
# echo "Pylint check - OK"
|
||||
# echo
|
||||
|
||||
# Unit tests with PyTest
|
||||
echo "PyTest:"
|
||||
if [ ${clear_cache} -eq 1 ]; then
|
||||
echo "Clear pytest cache first"
|
||||
echo
|
||||
python -m pytest -n auto --cache-clear --ignore=docling_ibm_models/ tests/
|
||||
else
|
||||
python -m pytest -n auto --ignore=docling_ibm_models/ tests/
|
||||
fi
|
||||
echo "PyTest check - OK"
|
||||
@@ -0,0 +1,171 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from PIL import Image
|
||||
|
||||
MODEL_CHECKPOINT_FN = "model.pt"
|
||||
DEFAULT_NUM_THREADS = 4
|
||||
|
||||
|
||||
# Classes:
|
||||
CLASSES_MAP = {
|
||||
0: "background",
|
||||
1: "Caption",
|
||||
2: "Footnote",
|
||||
3: "Formula",
|
||||
4: "List-item",
|
||||
5: "Page-footer",
|
||||
6: "Page-header",
|
||||
7: "Picture",
|
||||
8: "Section-header",
|
||||
9: "Table",
|
||||
10: "Text",
|
||||
11: "Title",
|
||||
12: "Document Index",
|
||||
13: "Code",
|
||||
14: "Checkbox-Selected",
|
||||
15: "Checkbox-Unselected",
|
||||
16: "Form",
|
||||
17: "Key-Value Region",
|
||||
}
|
||||
|
||||
|
||||
class LayoutPredictor:
|
||||
r"""
|
||||
Document layout prediction using ONNX
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
|
||||
):
|
||||
r"""
|
||||
Provide the artifact path that contains the LayoutModel ONNX file
|
||||
|
||||
The number of threads is decided, in the following order, by:
|
||||
1. The init method parameter `num_threads`, if it is set.
|
||||
2. The envvar "OMP_NUM_THREADS", if it is set.
|
||||
3. The default value DEFAULT_NUM_THREADS.
|
||||
|
||||
The execution provided is decided, in the following order:
|
||||
1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
|
||||
it uses the "CPUExecutionProvider".
|
||||
3. Otherwise if the "CUDAExecutionProvider" is present, use:
|
||||
["CUDAExecutionProvider", "CPUExecutionProvider"]:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
artifact_path: Path for the model ONNX file.
|
||||
num_threads: (Optional) Number of threads to run the inference.
|
||||
use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError when the model's ONNX file is missing
|
||||
"""
|
||||
# Set basic params
|
||||
self._threshold = 0.6 # Score threshold
|
||||
self._image_size = 640
|
||||
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
|
||||
|
||||
# Get env vars
|
||||
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
|
||||
if num_threads is None:
|
||||
num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
|
||||
self._num_threads = num_threads
|
||||
|
||||
# Decide the execution providers
|
||||
if (
|
||||
not self._use_cpu_only
|
||||
and "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
):
|
||||
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
self._providers = providers
|
||||
|
||||
# Model ONNX file
|
||||
self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
|
||||
if not os.path.isfile(self._onnx_fn):
|
||||
raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))
|
||||
|
||||
# ONNX options
|
||||
self._options = ort.SessionOptions()
|
||||
self._options.intra_op_num_threads = self._num_threads
|
||||
self.sess = ort.InferenceSession(
|
||||
self._onnx_fn,
|
||||
sess_options=self._options,
|
||||
providers=self._providers,
|
||||
)
|
||||
|
||||
def info(self) -> dict:
|
||||
r"""
|
||||
Get information about the configuration of LayoutPredictor
|
||||
"""
|
||||
info = {
|
||||
"onnx_file": self._onnx_fn,
|
||||
"intra_op_num_threads": self._num_threads,
|
||||
"use_cpu_only": self._use_cpu_only,
|
||||
"providers": self._providers,
|
||||
"image_size": self._image_size,
|
||||
"threshold": self._threshold,
|
||||
}
|
||||
return info
|
||||
|
||||
def predict(self, orig_img: Union[Image, np.array]) -> Iterable[dict]:
|
||||
r"""
|
||||
Predict bounding boxes for a given image.
|
||||
The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
|
||||
[left, top, right, bottom]
|
||||
|
||||
Parameter
|
||||
---------
|
||||
origin_img: Image to be predicted as a PIL Image object or numpy array.
|
||||
|
||||
Yield
|
||||
-----
|
||||
Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError when the input image is not supported
|
||||
"""
|
||||
# Convert image format
|
||||
if isinstance(orig_img, Image.Image):
|
||||
page_img = orig_img.convert("RGB")
|
||||
elif isinstance(orig_img, np.ndarray):
|
||||
page_img = Image.fromarray(orig_img).convert("RGB")
|
||||
else:
|
||||
raise TypeError("Not supported input image format")
|
||||
|
||||
w, h = page_img.size
|
||||
page_img = page_img.resize((self._image_size, self._image_size))
|
||||
page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
|
||||
page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
|
||||
|
||||
# Predict
|
||||
labels, boxes, scores = self.sess.run(
|
||||
output_names=None,
|
||||
input_feed={
|
||||
"images": page_data,
|
||||
"orig_target_sizes": self._size,
|
||||
},
|
||||
)
|
||||
|
||||
# Yield output
|
||||
for label, box, score in zip(labels[0], boxes[0], scores[0]):
|
||||
if score > self._threshold:
|
||||
yield {
|
||||
"l": box[0] / self._image_size * w,
|
||||
"t": box[1] / self._image_size * h,
|
||||
"r": box[2] / self._image_size * w,
|
||||
"b": box[3] / self._image_size * h,
|
||||
"label": CLASSES_MAP[label],
|
||||
"confidence": score,
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
|
||||
|
||||
LOG_LEVEL = logging.DEBUG
|
||||
logger = s.get_custom_logger("common", LOG_LEVEL)
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
r"""
|
||||
Validate the provided configuration file.
|
||||
A ValueError exception will be thrown in case the config file is invalid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dictionary
|
||||
Configuration for the tablemodel
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool : True on success
|
||||
"""
|
||||
if "model" not in config:
|
||||
return True
|
||||
if "preparation" not in config:
|
||||
return True
|
||||
assert (
|
||||
"max_tag_len" in config["preparation"]
|
||||
), "Config error: 'preparation.max_tag_len' parameter is missing"
|
||||
if "seq_len" in config["model"]:
|
||||
assert (
|
||||
config["model"]["seq_len"] > 0
|
||||
), "Config error: 'model.seq_len' should be positive"
|
||||
assert config["model"]["seq_len"] <= (
|
||||
config["preparation"]["max_tag_len"] + 2
|
||||
), "Config error: 'model.seq_len' should be up to 'preparation.max_tag_len' + 2"
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
r"""
|
||||
Parse the input arguments
|
||||
A ValueError exception will be thrown in case the config file is invalid
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Train the TableModel")
|
||||
parser.add_argument(
|
||||
"-c", "--config", required=True, default=None, help="configuration file (JSON)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config_filename = args.config
|
||||
|
||||
assert os.path.isfile(config_filename), "FAILURE: Config file not found."
|
||||
return read_config(config_filename)
|
||||
|
||||
|
||||
def read_config(config_filename):
|
||||
with open(config_filename, "r") as fd:
|
||||
config = json.load(fd)
|
||||
|
||||
# Validate the config file
|
||||
validate_config(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def safe_get_parameter(input_dict, index_path, default=None, required=False):
|
||||
r"""
|
||||
Safe get parameter from a nested dictionary.
|
||||
|
||||
Provide a nested dictionary (dictionary of dictionaries) and a list of indices:
|
||||
- If the whole index path exists the value pointed by it is returned
|
||||
- Otherwise the default value is returned.
|
||||
|
||||
Input:
|
||||
input_dict: Data structure of nested dictionaries.
|
||||
index_path: List with the indices path to follow inside the input_dict.
|
||||
default: Default value to return if the indices path is broken.
|
||||
required: If true a ValueError exception will be raised in case the parameter does not exist
|
||||
Output:
|
||||
The value pointed by the index path or "default".
|
||||
"""
|
||||
if input_dict is None or index_path is None:
|
||||
return default
|
||||
|
||||
d = input_dict
|
||||
for i in index_path[:-1]:
|
||||
if i not in d:
|
||||
if required:
|
||||
raise ValueError("Missing parameter: {}".format(i))
|
||||
return default
|
||||
d = d[i]
|
||||
|
||||
last_index = index_path[-1]
|
||||
if last_index not in d:
|
||||
if required:
|
||||
raise ValueError("Missing parameter: {}".format(last_index))
|
||||
return default
|
||||
|
||||
return d[last_index]
|
||||
|
||||
|
||||
def get_prepared_data_filename(prepared_data_part, dataset_name):
|
||||
r"""
|
||||
Build the full filename of the prepared data part
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prepared_data_part : string
|
||||
Part of the prepared data
|
||||
dataset_name : string
|
||||
Name of the dataset
|
||||
|
||||
Returns
|
||||
-------
|
||||
string
|
||||
The full filename for the prepared file
|
||||
"""
|
||||
template = s.PREPARED_DATA_PARTS[prepared_data_part]
|
||||
if "<POSTFIX>" in template:
|
||||
template = template.replace("<POSTFIX>", dataset_name)
|
||||
return template
|
||||
|
||||
|
||||
def create_dataset_and_model(config, purpose, fixed_padding=False):
|
||||
r"""
|
||||
Gets a model from configuration
|
||||
|
||||
Parameters
|
||||
---------
|
||||
config : Dictionary
|
||||
The configuration of the model
|
||||
purpose : string
|
||||
One of "train", "eval", "predict"
|
||||
fixed_padding : bool
|
||||
Parameter passed to the constructor of the DataLoader
|
||||
|
||||
Returns
|
||||
-------
|
||||
In case a Model cannot be initialized return None, None, None. Otherwise:
|
||||
|
||||
device : selected device
|
||||
dataset : Instance of the DataLoader
|
||||
model : Instance of the model
|
||||
"""
|
||||
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
|
||||
|
||||
model_type = config["model"]["type"]
|
||||
model = None
|
||||
|
||||
# Get env vars:
|
||||
use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
|
||||
use_cuda_only = not use_cpu_only
|
||||
|
||||
# Use the cpu for the evaluation
|
||||
device = "cpu" # Default, run on CPU
|
||||
num_gpus = torch.cuda.device_count() # Check if GPU is available
|
||||
if use_cuda_only:
|
||||
device = "cuda:0" if num_gpus > 0 else "cpu" # Run on first available GPU
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# Create the DataLoader
|
||||
# loader = DataLoader(config, purpose, fixed_padding=fixed_padding)
|
||||
dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
|
||||
dataset.set_device(device)
|
||||
dataset_val = None
|
||||
if config["train"]["validation"] and purpose == "train":
|
||||
dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
|
||||
dataset_val.set_device(device)
|
||||
if model_type == "TableModel04_rs":
|
||||
from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import ( # noqa: F401
|
||||
TableModel04_rs,
|
||||
)
|
||||
# Find the model class and create an instance of it
|
||||
for candidate in BaseModel.__subclasses__():
|
||||
if candidate.__name__ == model_type:
|
||||
init_data = dataset.get_init_data()
|
||||
model = candidate(config, init_data, purpose, device)
|
||||
|
||||
if model is None:
|
||||
logger.warn("Not found model: " + str(model_type))
|
||||
return None, None, None
|
||||
|
||||
logger.info("Found model: " + str(model_type))
|
||||
|
||||
if purpose == s.PREDICT_PURPOSE:
|
||||
return device, dataset, model
|
||||
else:
|
||||
return device, dataset, dataset_val, model
|
||||
@@ -0,0 +1,504 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
import docling_ibm_models.tableformer.data_management.transforms as T
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
class DataTransformer:
|
||||
r"""
|
||||
Data transformations for the images and bboxes
|
||||
|
||||
Check the "help" fields inside the config file for an explanation of each parameter
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self._config = config
|
||||
|
||||
print("DataTransformer Init!")
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def append_id(self, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
return "{name}_{uid}{ext}".format(name=name, uid="resized", ext=ext)
|
||||
|
||||
def load_image(self, img_fn):
|
||||
r"""
|
||||
Load an image from the disk
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img_fn: The filename of the image
|
||||
|
||||
Returns
|
||||
-------
|
||||
PIL image object
|
||||
"""
|
||||
|
||||
img = Image.open(img_fn)
|
||||
return img
|
||||
|
||||
def load_image_cv2(self, img_fn):
|
||||
r"""Load an image from the disk
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img_fn: The filename of the image
|
||||
|
||||
Returns
|
||||
-------
|
||||
CV2 image object
|
||||
"""
|
||||
img = cv2.imread(img_fn)
|
||||
return img
|
||||
|
||||
def save_image(self, img, img_fn):
|
||||
img.save(self.append_id(img_fn))
|
||||
|
||||
def renderbboxes(self, img, bboxes):
|
||||
draw_img = ImageDraw.Draw(img)
|
||||
for i in range(len(bboxes)):
|
||||
draw_img.rectangle(bboxes[i], fill=None, outline=(255, 0, 0))
|
||||
return img
|
||||
|
||||
def get_dataset_settings(self):
|
||||
dataset = {}
|
||||
debug = {"save_debug_images": False}
|
||||
|
||||
if "dataset" in self._config:
|
||||
dataset = self._config["dataset"]
|
||||
if "debug" in self._config:
|
||||
debug = self._config["debug"]
|
||||
|
||||
return dataset, debug
|
||||
|
||||
def _prepare_image_from_file(self, image_fn, bboxes, convert_box=True):
|
||||
r"""
|
||||
Load the image from file and prepare it
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_fn : string
|
||||
Filename to load the image
|
||||
bboxes : dict
|
||||
Bounding boxes of the image
|
||||
convert_box : bool
|
||||
If true the bboxes are converted to xcycwh format
|
||||
|
||||
Returns
|
||||
-------
|
||||
PIL image
|
||||
A PIL image object with the image prepared according to the settings in the config file
|
||||
bboxes : dict
|
||||
Converted bboxes of the image
|
||||
"""
|
||||
im = self.load_image(image_fn)
|
||||
return self._prepare_image(im, bboxes, convert_box, image_fn)
|
||||
|
||||
def _prepare_image(self, im, bboxes, convert_box=True, image_fn=None):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
im : PIL image object
|
||||
bboxes : dict
|
||||
Bounding boxes of the image
|
||||
convert_box : bool
|
||||
If true the bboxes are converted to xcycwh format
|
||||
image_fn : string
|
||||
Filename of the original image or None. It is used to save augmented image for debugging
|
||||
|
||||
Returns
|
||||
-------
|
||||
PIL image
|
||||
A PIL image object with the image prepared according to the settings in the config file
|
||||
bboxes : dict
|
||||
Converted bboxes of the image
|
||||
"""
|
||||
debug_settings = False
|
||||
settings, debug_settings = self.get_dataset_settings()
|
||||
|
||||
desired_size = settings["resized_image"]
|
||||
old_size = im.size
|
||||
|
||||
# Calculate Aspect Ratio if needed
|
||||
if settings["keep_AR"]:
|
||||
ratio = float(desired_size) / max(old_size)
|
||||
else:
|
||||
ratio = 1 # Square image
|
||||
|
||||
new_size = old_size
|
||||
if max(old_size) < desired_size:
|
||||
# Image is smaller than desired
|
||||
# Upscale?
|
||||
if settings["up_scaling_enabled"]:
|
||||
# Calculate new image size, taking into account aspect ratio
|
||||
new_size = tuple([int(x * ratio) for x in old_size])
|
||||
else:
|
||||
new_size = old_size
|
||||
else:
|
||||
if settings["keep_AR"]:
|
||||
new_size = tuple([int(x * ratio) for x in old_size])
|
||||
else:
|
||||
new_size = [desired_size, desired_size]
|
||||
|
||||
if not settings["keep_AR"]:
|
||||
if settings["up_scaling_enabled"]:
|
||||
new_size = [desired_size, desired_size]
|
||||
|
||||
######################################################################################
|
||||
# Use OpenCV to resize the image
|
||||
#
|
||||
# im = im.resize(new_size, Image.ANTIALIAS)
|
||||
|
||||
import cv2
|
||||
|
||||
np_im = np.array(im)
|
||||
np_resized = cv2.resize(np_im, new_size, interpolation=cv2.INTER_LANCZOS4)
|
||||
im = Image.fromarray(np_resized)
|
||||
######################################################################################
|
||||
|
||||
new_bboxes = copy.deepcopy(bboxes)
|
||||
|
||||
# Resize bboxes (in pixels)
|
||||
x_scale = new_size[0] / old_size[0]
|
||||
y_scale = new_size[1] / old_size[1]
|
||||
# loop over bboxes
|
||||
for i in range(len(new_bboxes)):
|
||||
new_bboxes[i][0] = x_scale * bboxes[i][0]
|
||||
new_bboxes[i][1] = y_scale * bboxes[i][1]
|
||||
new_bboxes[i][2] = x_scale * bboxes[i][2]
|
||||
new_bboxes[i][3] = y_scale * bboxes[i][3]
|
||||
|
||||
# Set background color for padding
|
||||
br = settings["padding_color"][0]
|
||||
bg = settings["padding_color"][1]
|
||||
bb = settings["padding_color"][2]
|
||||
bcolor = (br, bg, bb)
|
||||
# Create empty canvas of background color and desired size
|
||||
new_im = Image.new(mode="RGB", size=(desired_size, desired_size), color=bcolor)
|
||||
|
||||
if "grayscale" in settings:
|
||||
if settings["grayscale"]:
|
||||
im = im.convert("LA")
|
||||
|
||||
if settings["padding_mode"] == "frame":
|
||||
# If paddinds are around image, paste resized image in the center
|
||||
x_pad = (desired_size - new_size[0]) // 2
|
||||
y_pad = (desired_size - new_size[1]) // 2
|
||||
# Paste rescaled image
|
||||
new_im.paste(im, (x_pad, y_pad))
|
||||
# Offset (pad) bboxes
|
||||
# loop over bboxes
|
||||
for i in range(len(new_bboxes)):
|
||||
new_bboxes[i][0] += x_pad
|
||||
new_bboxes[i][1] += y_pad
|
||||
new_bboxes[i][2] += x_pad
|
||||
new_bboxes[i][3] += y_pad
|
||||
else:
|
||||
# Otherwise paste in the 0,0 coordinates
|
||||
new_im.paste(im, (0, 0))
|
||||
|
||||
if debug_settings:
|
||||
if debug_settings["save_debug_images"]:
|
||||
aug_im = self.renderbboxes(new_im, bboxes)
|
||||
if "grayscale" in settings:
|
||||
if settings["grayscale"]:
|
||||
aug_im = aug_im.convert("LA")
|
||||
self.save_image(aug_im, image_fn)
|
||||
if convert_box:
|
||||
bboxes = self.xyxy_to_xcycwh(new_bboxes, desired_size)
|
||||
return new_im, bboxes
|
||||
|
||||
# convert bboxes from [x1, y1, x2, y2] format to [xc, yc, w, h] format
|
||||
def xyxy_to_xcycwh(self, bboxes, size):
|
||||
# Use the "dataset.bbox_format" parameter to decide which bbox format to use
|
||||
bbox_format = self._config["dataset"].get("bbox_format", "4plet")
|
||||
|
||||
conv_bboxes = []
|
||||
for i in range(len(bboxes)):
|
||||
x1 = bboxes[i][0] / size # X1
|
||||
y1 = bboxes[i][1] / size # Y1
|
||||
x2 = bboxes[i][2] / size # X2
|
||||
y2 = bboxes[i][3] / size # Y2
|
||||
xc = (x1 + x2) / 2
|
||||
yc = (y1 + y2) / 2
|
||||
bw = abs(x2 - x1)
|
||||
bh = abs(y2 - y1)
|
||||
|
||||
if bbox_format == "5plet":
|
||||
cls = bboxes[i][4]
|
||||
conv_bboxes.append([xc, yc, bw, bh, cls])
|
||||
else:
|
||||
conv_bboxes.append([xc, yc, bw, bh])
|
||||
|
||||
# conv_bboxes = bboxes
|
||||
return conv_bboxes
|
||||
|
||||
def rescale_in_memory(self, image, normalization):
|
||||
r"""
|
||||
Receive image and escale it in memory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : PIL image
|
||||
The image data to rescale
|
||||
normalization : dictionary
|
||||
The normalization information with the format:
|
||||
"state": "true or false if image normalization is to be enabled",
|
||||
"mean": "mean values to use if state is true",
|
||||
"std": "std values to use if state is true"
|
||||
Returns
|
||||
-------
|
||||
npimgc : FloatTensor
|
||||
The loaded and properly transformed image data
|
||||
"""
|
||||
settings, debug_settings = self.get_dataset_settings()
|
||||
new_image, _ = self._prepare_image(image, {}, convert_box=False, image_fn=None)
|
||||
# Convert to nparray
|
||||
npimg = np.asarray(new_image) # (width, height, channels)
|
||||
|
||||
# Convert to float?
|
||||
npimgc = npimg.copy()
|
||||
|
||||
# Transpose numpy array (image)
|
||||
npimgc = npimgc.transpose(2, 0, 1) # (channels, width, height)
|
||||
npimgc = torch.FloatTensor(npimgc / 255.0)
|
||||
|
||||
if normalization:
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=self._config["dataset"]["image_normalization"]["mean"],
|
||||
std=self._config["dataset"]["image_normalization"]["std"],
|
||||
)
|
||||
]
|
||||
)
|
||||
npimgc = transform(npimgc)
|
||||
|
||||
return npimgc
|
||||
|
||||
def _rescale(self, image_fn, bboxes, normalization):
|
||||
r"""
|
||||
Rescale, resize, pad the given image and its associated bboxes according to the settings
|
||||
from the config
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_fn: full image file name
|
||||
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
|
||||
bottom-left points
|
||||
statistics: Dictionary with statistics over the whole image dataset.
|
||||
The keys are: "mean", "variance", "std" and each value is a list with the
|
||||
coresponding statistical value for each channel. Normally there are 3
|
||||
channels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
npimgc : FloatTensor
|
||||
The loaded and properly transformed image data
|
||||
bboxes: List with bboxes in the format [xc, yc, w, h] where xc, yc are the coords of the
|
||||
center, w, h the width and height of the bbox and all are normalized to the
|
||||
scaled size of the image
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
In case the configuration and the image dimensions make it impossible to rescale the
|
||||
image throw a ValueError exception
|
||||
"""
|
||||
settings, debug_settings = self.get_dataset_settings()
|
||||
# new_image is a PIL object
|
||||
new_image, new_bboxes = self._prepare_image_from_file(image_fn, bboxes)
|
||||
# Convert to nparray
|
||||
npimg = np.asarray(new_image) # (width, height, channels)
|
||||
# Convert to float?
|
||||
npimgc = npimg.copy()
|
||||
# Transpose numpy array (image)
|
||||
npimgc = npimgc.transpose(2, 0, 1) # (channels, width, height)
|
||||
npimgc = torch.FloatTensor(npimgc / 255.0)
|
||||
|
||||
if normalization:
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=self._config["dataset"]["image_normalization"]["mean"],
|
||||
std=self._config["dataset"]["image_normalization"]["std"],
|
||||
)
|
||||
]
|
||||
)
|
||||
npimgc = transform(npimgc)
|
||||
return npimgc, new_bboxes
|
||||
|
||||
def rescale_old(self, image, bboxes, statistics=None):
|
||||
r"""
|
||||
Rescale, resize, pad the given image and its associated bboxes according to the settings
|
||||
from the config
|
||||
|
||||
Input:
|
||||
image: np array (channels, width, height)
|
||||
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
|
||||
bottom-left points
|
||||
statistics: Dictionary with statistics over the whole image dataset.
|
||||
The keys are: "mean", "variance", "std" and each value is a list with the
|
||||
coresponding statistical value for each channel. Normally there are 3
|
||||
channels.
|
||||
|
||||
Output:
|
||||
image: np array (channels, resized_image, resized_image)
|
||||
bboxes: List with bboxes in the format (xc, yc, w, h) where xc, yc are the coords of the
|
||||
center, w, h the width and height of the bbox and all are normalized to the
|
||||
scaled size of the image
|
||||
|
||||
Exceptions:
|
||||
In case the configuration and the image dimensions make it impossible to rescale the
|
||||
image throw a ValueError exception
|
||||
"""
|
||||
image_size = 448
|
||||
# Convert the image to (width, height, channels)
|
||||
image = image.transpose(1, 2, 0)
|
||||
|
||||
# Convert to PIL format and resize
|
||||
image = Image.fromarray(image)
|
||||
image = image.resize((image_size, image_size), Image.ANTIALIAS)
|
||||
return image, bboxes
|
||||
|
||||
def rescale_batch(self, images, bboxes, statistics=None):
|
||||
r"""
|
||||
Rescale, resize, pad the given batch of images and its associated bboxes according to the
|
||||
settings from the config.
|
||||
|
||||
Input:
|
||||
images: np array (batch_size, channels, width, height)
|
||||
bboxes:
|
||||
statistics: Dictionary with statistics over the whole image dataset.
|
||||
The keys are: "mean", "variance", "std" and each value is a list with the
|
||||
coresponding statistical value for each channel. Normally there are 3
|
||||
channels.
|
||||
|
||||
Output:
|
||||
image batch: np array (batch_size, channels, resized_image, resized_image)
|
||||
bboxes:
|
||||
|
||||
Exceptions:
|
||||
In case the configuration and the image dimensions make it impossible to rescale the
|
||||
image throw a ValueError exception
|
||||
"""
|
||||
pass
|
||||
|
||||
def sample_preprocessor(self, image_fn, bboxes, purpose, table_bboxes=None):
|
||||
r"""
|
||||
Rescale, resize, pad the given image and its associated bboxes according to the settings
|
||||
from the config
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_fn: full image file name
|
||||
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
|
||||
bottom-left points
|
||||
statistics: Dictionary with statistics over the whole image dataset.
|
||||
The keys are: "mean", "variance", "std" and each value is a list with the
|
||||
coresponding statistical value for each channel. Normally there are 3
|
||||
channels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
npimgc : FloatTensor
|
||||
The loaded and properly transformed image data
|
||||
bboxes: List with bboxes in the format [xc, yc, w, h] where xc, yc are the coords of the
|
||||
center, w, h the width and height of the bbox and all are normalized to the
|
||||
scaled size of the image
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
In case the configuration and the image dimensions make it impossible to rescale the
|
||||
image throw a ValueError exception
|
||||
"""
|
||||
settings, debug_settings = self.get_dataset_settings()
|
||||
img = self.load_image_cv2(image_fn)
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
target = {
|
||||
"size": [img.shape[1], img.shape[2]],
|
||||
"boxes": (
|
||||
torch.from_numpy(np.array(bboxes)[:, :4])
|
||||
if purpose != s.PREDICT_PURPOSE
|
||||
else None
|
||||
),
|
||||
"classes": (
|
||||
np.array(bboxes)[:, -1] if purpose != s.PREDICT_PURPOSE else None
|
||||
),
|
||||
"area": img.shape[1] * img.shape[2],
|
||||
}
|
||||
|
||||
optional_transforms = [T.NoTransformation()]
|
||||
|
||||
# Necessary preprocessing ends here, experimental options begin below.
|
||||
# DETR format, might be necessary to keep this structure to share other functions used by
|
||||
# the community
|
||||
|
||||
if purpose == s.TRAIN_PURPOSE:
|
||||
if self._config["dataset"]["color_jitter"]:
|
||||
jitter = T.ColorJitter(
|
||||
brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
|
||||
)
|
||||
optional_transforms.append(jitter)
|
||||
|
||||
if self._config["dataset"]["rand_pad"]:
|
||||
pad_val = random.randint(0, 50)
|
||||
rand_pad = T.RandomPad(pad_val)
|
||||
optional_transforms.append(rand_pad)
|
||||
|
||||
if table_bboxes is not None:
|
||||
if self._config["dataset"]["rand_crop"]:
|
||||
w_, h_, _ = img.shape[0], img.shape[1], img.shape[2]
|
||||
w_c, h_c = table_bboxes[0], table_bboxes[1]
|
||||
f_w, f_h = random.randint(0, w_c), random.randint(0, h_c)
|
||||
rand_crop = T.RandomCrop((w_, h_), (f_w, f_h))
|
||||
optional_transforms.append(rand_crop)
|
||||
|
||||
# transform_opt = random.choice(optional_transforms)
|
||||
normalize = T.Normalize(
|
||||
mean=self._config["dataset"]["image_normalization"]["mean"],
|
||||
std=self._config["dataset"]["image_normalization"]["std"],
|
||||
)
|
||||
resized_size = self._config["dataset"]["resized_image"]
|
||||
resize = T.Resize([resized_size, resized_size])
|
||||
|
||||
transformations = T.RandomChoice(optional_transforms)
|
||||
|
||||
img, target = transformations(img, target)
|
||||
img, target = normalize(img, target)
|
||||
img, target = resize(img, target)
|
||||
|
||||
img = img.transpose(2, 1, 0) # (channels, width, height)
|
||||
img = torch.FloatTensor(img / 255.0)
|
||||
bboxes_ = target["boxes"]
|
||||
classes_ = target["classes"]
|
||||
desired_size = img.shape[1]
|
||||
|
||||
if purpose != s.PREDICT_PURPOSE:
|
||||
bboxes_ = np.concatenate(
|
||||
(bboxes_, np.expand_dims(classes_, axis=1)), axis=1
|
||||
)
|
||||
bboxes_ = self.xyxy_to_xcycwh(bboxes_, desired_size)
|
||||
|
||||
return img, bboxes_
|
||||
@@ -0,0 +1,574 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import numbers
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.transforms import functional
|
||||
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
INTER_MODE = {
|
||||
"NEAREST": cv2.INTER_NEAREST,
|
||||
"BILINEAR": cv2.INTER_LINEAR,
|
||||
"BICUBIC": cv2.INTER_CUBIC,
|
||||
}
|
||||
|
||||
PAD_MOD = {
|
||||
"constant": cv2.BORDER_CONSTANT,
|
||||
"edge": cv2.BORDER_REPLICATE,
|
||||
"reflect": cv2.BORDER_DEFAULT,
|
||||
"symmetric": cv2.BORDER_REFLECT,
|
||||
}
|
||||
|
||||
|
||||
def _is_tensor_image(img):
|
||||
return torch.is_tensor(img) and img.ndimension() == 3
|
||||
|
||||
|
||||
def _is_numpy_image(img):
|
||||
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
||||
|
||||
|
||||
def to_tensor(pic):
|
||||
"""Converts a numpy.ndarray (H x W x C) in the range
|
||||
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
|
||||
Args:
|
||||
pic (np.ndarray, torch.Tensor): Image to be converted to tensor, (H x W x C[RGB]).
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
|
||||
if _is_numpy_image(pic):
|
||||
if len(pic.shape) == 2:
|
||||
pic = cv2.cvtColor(pic, cv2.COLOR_GRAY2RGB)
|
||||
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(img, torch.ByteTensor) or img.max() > 1:
|
||||
return img.float().div(255)
|
||||
else:
|
||||
return img
|
||||
elif _is_tensor_image(pic):
|
||||
return pic
|
||||
|
||||
else:
|
||||
try:
|
||||
return to_tensor(np.array(pic))
|
||||
except Exception:
|
||||
raise TypeError("pic should be ndarray. Got {}".format(type(pic)))
|
||||
|
||||
|
||||
def to_cv_image(pic, mode=None):
|
||||
"""Convert a tensor to an ndarray.
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
mode (str): color space and pixel depth of input data (optional)
|
||||
for example: cv2.COLOR_RGB2BGR.
|
||||
Returns:
|
||||
np.array: Image converted to PIL Image.
|
||||
"""
|
||||
if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
|
||||
raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
|
||||
|
||||
npimg = pic
|
||||
if isinstance(pic, torch.FloatTensor):
|
||||
pic = pic.mul(255).byte()
|
||||
if torch.is_tensor(pic):
|
||||
npimg = np.squeeze(np.transpose(pic.numpy(), (1, 2, 0)))
|
||||
|
||||
if not isinstance(npimg, np.ndarray):
|
||||
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray")
|
||||
if mode is None:
|
||||
return npimg
|
||||
|
||||
else:
|
||||
return cv2.cvtColor(npimg, mode)
|
||||
|
||||
|
||||
def normalize(tensor, mean, std):
|
||||
"""Normalize a tensor image with mean and standard deviation.
|
||||
See ``Normalize`` for more details.
|
||||
Args:
|
||||
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
||||
mean (sequence): Sequence of means for each channel.
|
||||
std (sequence): Sequence of standard deviations for each channely.
|
||||
Returns:
|
||||
Tensor: Normalized Tensor image.
|
||||
"""
|
||||
if _is_tensor_image(tensor):
|
||||
for t, m, s in zip(tensor, mean, std, strict=False):
|
||||
t.sub_(m).div_(s)
|
||||
return tensor
|
||||
elif _is_numpy_image(tensor):
|
||||
return (tensor.astype(np.float32) - 255.0 * np.array(mean)) / np.array(std)
|
||||
else:
|
||||
raise RuntimeError("Undefined type")
|
||||
|
||||
|
||||
def resize(img, size, interpolation="BILINEAR"):
|
||||
"""Resize the input CV Image to the given size.
|
||||
Args:
|
||||
img (np.ndarray): Image to be resized.
|
||||
size (tuple or int): Desired output size. If size is a sequence like
|
||||
(h, w), the output size will be matched to this. If size is an int,
|
||||
the smaller edge of the image will be matched to this number maintaing
|
||||
the aspect ratio. i.e, if height > width, then image will be rescaled to
|
||||
(size * height / width, size)
|
||||
interpolation (str, optional): Desired interpolation. Default is ``BILINEAR``
|
||||
Returns:
|
||||
cv Image: Resized image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
|
||||
raise TypeError("Got inappropriate size arg: {}".format(size))
|
||||
|
||||
if isinstance(size, int):
|
||||
h, w, c = img.shape
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return img
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
return cv2.resize(
|
||||
img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation]
|
||||
)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
return cv2.resize(
|
||||
img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation]
|
||||
)
|
||||
else:
|
||||
oh, ow = size
|
||||
return cv2.resize(
|
||||
img, dsize=(int(ow), int(oh)), interpolation=INTER_MODE[interpolation]
|
||||
)
|
||||
|
||||
|
||||
def to_rgb_bgr(pic):
|
||||
"""Converts a color image stored in BGR sequence to RGB (BGR to RGB)
|
||||
or stored in RGB sequence to BGR (RGB to BGR).
|
||||
Args:
|
||||
pic (np.ndarray, torch.Tensor): Image to be converted, (H x W x 3).
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
|
||||
if _is_numpy_image(pic) or _is_tensor_image(pic):
|
||||
img = pic[:, :, [2, 1, 0]]
|
||||
return img
|
||||
else:
|
||||
try:
|
||||
return to_rgb_bgr(np.array(pic))
|
||||
except Exception:
|
||||
raise TypeError("pic should be numpy.ndarray or torch.Tensor.")
|
||||
|
||||
|
||||
def pad(img, padding, fill=(0, 0, 0), padding_mode="constant"):
|
||||
"""Pad the given CV Image on all sides with speficified padding mode and fill value.
|
||||
Args:
|
||||
img (np.ndarray): Image to be padded.
|
||||
padding (int or tuple): Padding on each border. If a single int is provided this
|
||||
is used to pad all borders. If tuple of length 2 is provided this is the padding
|
||||
on left/right and top/bottom respectively. If a tuple of length 4 is provided
|
||||
this is the padding for the left, top, right and bottom borders
|
||||
respectively.
|
||||
fill (int, tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
|
||||
length 3, it is used to fill R, G, B channels respectively.
|
||||
This value is only used when the padding_mode is constant
|
||||
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric.
|
||||
Default is constant.
|
||||
constant: pads with a constant value, this value is specified with fill
|
||||
edge: pads with the last value on the edge of the image
|
||||
reflect: pads with reflection of image (without repeating the last value on the edge)
|
||||
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
||||
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
||||
symmetric: pads with reflection of image (repeating the last value on the edge)
|
||||
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
||||
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
||||
Returns:
|
||||
CV Image: Padded image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
if not isinstance(padding, (numbers.Number, tuple)):
|
||||
raise TypeError("Got inappropriate padding arg")
|
||||
if not isinstance(fill, (numbers.Number, str, tuple)):
|
||||
raise TypeError("Got inappropriate fill arg")
|
||||
if not isinstance(padding_mode, str):
|
||||
raise TypeError("Got inappropriate padding_mode arg")
|
||||
|
||||
if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
|
||||
raise ValueError("Padding must be an int or a 2, or 4 element tuple")
|
||||
|
||||
assert padding_mode in [
|
||||
"constant",
|
||||
"edge",
|
||||
"reflect",
|
||||
"symmetric",
|
||||
], "Padding mode should be either constant, edge, reflect or symmetric"
|
||||
|
||||
if isinstance(padding, int):
|
||||
pad_left = pad_right = pad_top = pad_bottom = padding
|
||||
if isinstance(padding, Sequence) and len(padding) == 2:
|
||||
pad_left = pad_right = padding[0]
|
||||
pad_top = pad_bottom = padding[1]
|
||||
if isinstance(padding, Sequence) and len(padding) == 4:
|
||||
pad_left, pad_top, pad_right, pad_bottom = padding
|
||||
|
||||
if isinstance(fill, numbers.Number):
|
||||
fill = (fill,) * (2 * len(img.shape) - 3)
|
||||
|
||||
if padding_mode == "constant":
|
||||
assert (len(fill) == 3 and len(img.shape) == 3) or (
|
||||
len(fill) == 1 and len(img.shape) == 2
|
||||
), "channel of image is {} but length of fill is {}".format(
|
||||
img.shape[-1], len(fill)
|
||||
)
|
||||
|
||||
img = cv2.copyMakeBorder(
|
||||
src=img,
|
||||
top=pad_top,
|
||||
bottom=pad_bottom,
|
||||
left=pad_left,
|
||||
right=pad_right,
|
||||
borderType=PAD_MOD[padding_mode],
|
||||
value=fill,
|
||||
)
|
||||
return img
|
||||
|
||||
|
||||
def crop(img, x, y, h, w):
|
||||
"""Crop the given CV Image.
|
||||
Args:
|
||||
img (np.ndarray): Image to be cropped.
|
||||
x: Upper pixel coordinate.
|
||||
y: Left pixel coordinate.
|
||||
h: Height of the cropped image.
|
||||
w: Width of the cropped image.
|
||||
Returns:
|
||||
CV Image: Cropped image.
|
||||
"""
|
||||
assert _is_numpy_image(img), "img should be CV Image. Got {}".format(type(img))
|
||||
assert h > 0 and w > 0, "h={} and w={} should greater than 0".format(h, w)
|
||||
|
||||
x1, y1, x2, y2 = round(x), round(y), round(x + h), round(y + w)
|
||||
|
||||
# try:
|
||||
# check_point1 = img[x1, y1, ...]
|
||||
# check_point2 = img[x2-1, y2-1, ...]
|
||||
# except IndexError:
|
||||
# img = cv2.copyMakeBorder(img, - min(0, x1), max(x2 - img.shape[0], 0),
|
||||
# -min(0, y1), max(y2 - img.shape[1], 0),
|
||||
# cv2.BORDER_CONSTANT, value=[0, 0, 0])
|
||||
# y2 += -min(0, y1)
|
||||
# y1 += -min(0, y1)
|
||||
# x2 += -min(0, x1)
|
||||
# x1 += -min(0, x1)
|
||||
#
|
||||
# finally:
|
||||
# return img[x1:x2, y1:y2, ...].copy()
|
||||
return img[x1:x2, y1:y2, ...].copy()
|
||||
|
||||
|
||||
def center_crop(img, output_size):
|
||||
if isinstance(output_size, numbers.Number):
|
||||
output_size = (int(output_size), int(output_size))
|
||||
h, w, _ = img.shape
|
||||
th, tw = output_size
|
||||
i = int(round((h - th) * 0.5))
|
||||
j = int(round((w - tw) * 0.5))
|
||||
return crop(img, i, j, th, tw)
|
||||
|
||||
|
||||
def resized_crop(img, i, j, h, w, size, interpolation="BILINEAR"):
|
||||
"""Crop the given CV Image and resize it to desired size. Notably used in RandomResizedCrop.
|
||||
Args:
|
||||
img (np.ndarray): Image to be cropped.
|
||||
i: Upper pixel coordinate.
|
||||
j: Left pixel coordinate.
|
||||
h: Height of the cropped image.
|
||||
w: Width of the cropped image.
|
||||
size (sequence or int): Desired output size. Same semantics as ``scale``.
|
||||
interpolation (str, optional): Desired interpolation. Default is
|
||||
``BILINEAR``.
|
||||
Returns:
|
||||
np.ndarray: Cropped image.
|
||||
"""
|
||||
assert _is_numpy_image(img), "img should be CV Image"
|
||||
img = crop(img, i, j, h, w)
|
||||
img = resize(img, size, interpolation)
|
||||
return img
|
||||
|
||||
|
||||
def hflip(img):
|
||||
"""Horizontally flip the given PIL Image.
|
||||
Args:
|
||||
img (np.ndarray): Image to be flipped.
|
||||
Returns:
|
||||
np.ndarray: Horizontall flipped image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
return cv2.flip(img, 1)
|
||||
|
||||
|
||||
def vflip(img):
|
||||
"""Vertically flip the given PIL Image.
|
||||
Args:
|
||||
img (CV Image): Image to be flipped.
|
||||
Returns:
|
||||
PIL Image: Vertically flipped image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
|
||||
|
||||
return cv2.flip(img, 0)
|
||||
|
||||
|
||||
def five_crop(img, size):
|
||||
"""Crop the given CV Image into four corners and the central crop.
|
||||
.. Note::
|
||||
This transform returns a tuple of images and there may be a
|
||||
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||
Args:
|
||||
size (sequence or int): Desired output size of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
Returns:
|
||||
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
|
||||
top right, bottom left, bottom right and center crop.
|
||||
"""
|
||||
if isinstance(size, numbers.Number):
|
||||
size = (int(size), int(size))
|
||||
else:
|
||||
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||
|
||||
h, w, _ = img.shape
|
||||
crop_h, crop_w = size
|
||||
if crop_w > w or crop_h > h:
|
||||
raise ValueError(
|
||||
"Requested crop size {} is bigger than input size {}".format(size, (h, w))
|
||||
)
|
||||
tl = crop(img, 0, 0, crop_h, crop_w)
|
||||
tr = crop(img, 0, w - crop_w, crop_h, crop_w)
|
||||
bl = crop(img, h - crop_h, 0, crop_h, crop_w)
|
||||
br = crop(img, h - crop_h, w - crop_w, crop_h, crop_w)
|
||||
center = center_crop(img, (crop_h, crop_w))
|
||||
return (tl, tr, bl, br, center)
|
||||
|
||||
|
||||
def ten_crop(img, size, vertical_flip=False):
|
||||
"""Crop the given CV Image into four corners and the central crop plus the
|
||||
flipped version of these (horizontal flipping is used by default).
|
||||
.. Note::
|
||||
This transform returns a tuple of images and there may be a
|
||||
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||
Args:
|
||||
size (sequence or int): Desired output size of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
vertical_flip (bool): Use vertical flipping instead of horizontal
|
||||
Returns:
|
||||
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
|
||||
br_flip, center_flip) corresponding top left, top right,
|
||||
bottom left, bottom right and center crop and same for the
|
||||
flipped image.
|
||||
"""
|
||||
if isinstance(size, numbers.Number):
|
||||
size = (int(size), int(size))
|
||||
else:
|
||||
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||
|
||||
first_five = five_crop(img, size)
|
||||
|
||||
if vertical_flip:
|
||||
img = vflip(img)
|
||||
else:
|
||||
img = hflip(img)
|
||||
|
||||
second_five = five_crop(img, size)
|
||||
return first_five + second_five
|
||||
|
||||
|
||||
def adjust_brightness(img, brightness_factor):
|
||||
"""Adjust brightness of an Image.
|
||||
Args:
|
||||
img (np.ndarray): CV Image to be adjusted.
|
||||
brightness_factor (float): How much to adjust the brightness. Can be
|
||||
any non negative number. 0 gives a black image, 1 gives the
|
||||
original image while 2 increases the brightness by a factor of 2.
|
||||
Returns:
|
||||
np.ndarray: Brightness adjusted image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
im = img.astype(np.float32) * brightness_factor
|
||||
im = im.clip(min=0, max=255)
|
||||
return im.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_contrast(img, contrast_factor):
|
||||
"""Adjust contrast of an Image.
|
||||
Args:
|
||||
img (np.ndarray): CV Image to be adjusted.
|
||||
contrast_factor (float): How much to adjust the contrast. Can be any
|
||||
non negative number. 0 gives a solid gray image, 1 gives the
|
||||
original image while 2 increases the contrast by a factor of 2.
|
||||
Returns:
|
||||
np.ndarray: Contrast adjusted image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
im = img.astype(np.float32)
|
||||
mean = round(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY).mean())
|
||||
im = (1 - contrast_factor) * mean + contrast_factor * im
|
||||
im = im.clip(min=0, max=255)
|
||||
return im.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_saturation(img, saturation_factor):
|
||||
"""Adjust color saturation of an image.
|
||||
Args:
|
||||
img (np.ndarray): CV Image to be adjusted.
|
||||
saturation_factor (float): How much to adjust the saturation. 0 will
|
||||
give a gray image, 1 will give the original image while
|
||||
2 will enhance the saturation by a factor of 2.
|
||||
Returns:
|
||||
np.ndarray: Saturation adjusted image.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
|
||||
|
||||
im = img.astype(np.float32)
|
||||
degenerate = cv2.cvtColor(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
|
||||
im = (1 - saturation_factor) * degenerate + saturation_factor * im
|
||||
im = im.clip(min=0, max=255)
|
||||
return im.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_hue(img, hue_factor):
|
||||
"""Adjust hue of an image.
|
||||
The image hue is adjusted by converting the image to HSV and
|
||||
cyclically shifting the intensities in the hue channel (H).
|
||||
The image is then converted back to original image mode.
|
||||
`hue_factor` is the amount of shift in H channel and must be in the
|
||||
interval `[-0.5, 0.5]`.
|
||||
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
|
||||
Args:
|
||||
img (np.ndarray): CV Image to be adjusted.
|
||||
hue_factor (float): How much to shift the hue channel. Should be in
|
||||
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
|
||||
HSV space in positive and negative direction respectively.
|
||||
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
|
||||
with complementary colors while 0 gives the original image.
|
||||
Returns:
|
||||
np.ndarray: Hue adjusted image.
|
||||
"""
|
||||
if not (-0.5 <= hue_factor <= 0.5):
|
||||
raise ValueError("hue_factor is not in [-0.5, 0.5].")
|
||||
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
im = img.astype(np.uint8)
|
||||
hsv = cv2.cvtColor(im, cv2.COLOR_RGB2HSV_FULL)
|
||||
hsv[..., 0] += np.uint8(hue_factor * 255)
|
||||
|
||||
im = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB_FULL)
|
||||
return im.astype(img.dtype)
|
||||
|
||||
|
||||
def adjust_gamma(img, gamma, gain=1):
|
||||
"""Perform gamma correction on an image.
|
||||
Also known as Power Law Transform. Intensities in RGB mode are adjusted
|
||||
based on the following equation:
|
||||
I_out = 255 * gain * ((I_in / 255) ** gamma)
|
||||
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
|
||||
Args:
|
||||
img (np.ndarray): CV Image to be adjusted.
|
||||
gamma (float): Non negative real number. gamma larger than 1 make the
|
||||
shadows darker, while gamma smaller than 1 make dark regions
|
||||
lighter.
|
||||
gain (float): The constant multiplier.
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
if gamma < 0:
|
||||
raise ValueError("Gamma should be a non-negative real number")
|
||||
|
||||
im = img.astype(np.float32)
|
||||
im = 255.0 * gain * np.power(im / 255.0, gamma)
|
||||
im = im.clip(min=0.0, max=255.0)
|
||||
return im.astype(img.dtype)
|
||||
|
||||
|
||||
def to_grayscale(img, num_output_channels=1):
|
||||
"""Convert image to grayscale version of image.
|
||||
Args:
|
||||
img (np.ndarray): Image to be converted to grayscale.
|
||||
Returns:
|
||||
CV Image: Grayscale version of the image.
|
||||
if num_output_channels == 1 : returned image is single channel
|
||||
if num_output_channels == 3 : returned image is 3 channel with r == g == b
|
||||
"""
|
||||
if not _is_numpy_image(img):
|
||||
raise TypeError("img should be CV Image. Got {}".format(type(img)))
|
||||
|
||||
if num_output_channels == 1:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||
elif num_output_channels == 3:
|
||||
img = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
|
||||
else:
|
||||
raise ValueError("num_output_channels should be either 1 or 3")
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def gaussian_noise(img: np.ndarray, mean, std):
|
||||
imgtype = img.dtype
|
||||
gauss = np.random.normal(mean, std, img.shape).astype(np.float32)
|
||||
noisy = np.clip((1 + gauss) * img.astype(np.float32), 0, 255)
|
||||
return noisy.astype(imgtype)
|
||||
|
||||
|
||||
def poisson_noise(img):
|
||||
imgtype = img.dtype
|
||||
img = img.astype(np.float32) / 255.0
|
||||
vals = len(np.unique(img))
|
||||
vals = 2 ** np.ceil(np.log2(vals))
|
||||
noisy = 255 * np.clip(
|
||||
np.random.poisson(img.astype(np.float32) * vals) / float(vals), 0, 1
|
||||
)
|
||||
return noisy.astype(imgtype)
|
||||
|
||||
|
||||
def salt_and_pepper(img, prob=0.01):
|
||||
"""Adds "Salt & Pepper" noise to an image.
|
||||
prob: probability (threshold) that controls level of noise
|
||||
"""
|
||||
imgtype = img.dtype
|
||||
rnd = np.random.rand(img.shape[0], img.shape[1])
|
||||
noisy = img.copy()
|
||||
noisy[rnd < prob / 2] = 0.0
|
||||
noisy[rnd > 1 - prob / 2] = 255.0
|
||||
return noisy.astype(imgtype)
|
||||
|
||||
|
||||
def cv_transform(img):
|
||||
img = salt_and_pepper(img)
|
||||
return to_tensor(img)
|
||||
|
||||
|
||||
def pil_transform(img):
|
||||
return functional.to_tensor(img)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,596 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
import docling_ibm_models.tableformer.otsl as otsl
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
# LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
LOG_LEVEL = logging.WARN
|
||||
|
||||
# Cell labels
|
||||
BODY = "body"
|
||||
COL_HEADER = "col_header"
|
||||
MULTI_COL_HEADER = "multi_col_header"
|
||||
MULTI_ROW_HEADER = "multi_row_header"
|
||||
MULTI_ROW = "multi_row"
|
||||
MULTI_COL = "multi_col"
|
||||
|
||||
|
||||
def validate_bboxes_page(bboxes):
|
||||
r"""
|
||||
Useful function for Debugging
|
||||
|
||||
Validate that the bboxes have a positive area in the page coordinate system
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bboxes : list of 4
|
||||
Each element of the list is expected to be a bbox in the page coordinates system
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of invalid bboxes.
|
||||
"""
|
||||
invalid_counter = 0
|
||||
for i, bbox in enumerate(bboxes):
|
||||
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
|
||||
if area < 0:
|
||||
print("Wrong bbox: {} - {}".format(i, bbox))
|
||||
invalid_counter += 1
|
||||
|
||||
if invalid_counter > 0:
|
||||
print("Invalid bboxes in total: {}".format(invalid_counter))
|
||||
return invalid_counter
|
||||
|
||||
|
||||
def find_intersection(b1, b2):
|
||||
r"""
|
||||
Compute the intersection between 2 bboxes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
b1 : list of 4
|
||||
The page x1y1x2y2 coordinates of the bbox
|
||||
b2 : list of 4
|
||||
The page x1y1x2y2 coordinates of the bbox
|
||||
|
||||
Returns
|
||||
-------
|
||||
The bbox of the intersection or None if there is no intersection
|
||||
"""
|
||||
# Check when the bboxes do NOT intersect
|
||||
if b1[2] < b2[0] or b2[2] < b1[0] or b1[1] > b2[3] or b2[1] > b2[3]:
|
||||
return None
|
||||
|
||||
i_bbox = [
|
||||
max(b1[0], b2[0]),
|
||||
max(b1[1], b2[1]),
|
||||
min(b1[2], b2[2]),
|
||||
min(b1[3], b2[3]),
|
||||
]
|
||||
return i_bbox
|
||||
|
||||
|
||||
class CellMatcher:
|
||||
r"""
|
||||
Match the table cells to the pdf page cells.
|
||||
|
||||
NOTICE: PDF page coordinate system vs table coordinate system.
|
||||
In both systems the bboxes are described in as (x1, y1, x2, y2) with the following meaning:
|
||||
|
||||
Page coordinate system:
|
||||
- Origin (0, 0) at the lower-left corner
|
||||
- (x1, y1) the lower left corner of the box
|
||||
- (x2, y2) the upper right corner of the box
|
||||
|
||||
Table coordinate system:
|
||||
- Origin (0, 0) at the upper-left corner
|
||||
- (x1, y1) the upper left corner of the box
|
||||
- (x2, y2) the lower right corner of the box
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self._config = config
|
||||
self._iou_thres = config["predict"]["pdf_cell_iou_thres"]
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def match_cells(self, iocr_page, table_bbox, prediction):
|
||||
r"""
|
||||
Convert the tablemodel prediction into the Docling format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
iocr_page : dict
|
||||
The original Docling provided table data
|
||||
prediction : dict
|
||||
The dictionary has the keys:
|
||||
"tag_seq": The sequence in indices from the WORDMAP
|
||||
"html_seq": The sequence as html tags
|
||||
"bboxes": The bounding boxes
|
||||
|
||||
Returns
|
||||
-------
|
||||
matching_details : dict
|
||||
Dictionary with all details about the mathings between the table and pdf cells
|
||||
"""
|
||||
pdf_cells = copy.deepcopy(iocr_page["tokens"])
|
||||
for word in pdf_cells:
|
||||
word["bbox"] = [
|
||||
word["bbox"]["l"],
|
||||
word["bbox"]["t"],
|
||||
word["bbox"]["r"],
|
||||
word["bbox"]["b"],
|
||||
]
|
||||
table_bboxes = prediction["bboxes"]
|
||||
table_classes = prediction["classes"]
|
||||
# BBOXES transformed...
|
||||
table_bboxes_page = self._translate_bboxes(table_bbox, table_bboxes)
|
||||
|
||||
# Combine the table tags and bboxes into TableCells
|
||||
html_seq = prediction["html_seq"]
|
||||
otsl_seq = prediction["rs_seq"]
|
||||
table_cells = self._build_table_cells(
|
||||
html_seq, otsl_seq, table_bboxes_page, table_classes
|
||||
)
|
||||
matches, matches_counter = self._intersection_over_pdf_match(
|
||||
table_cells, pdf_cells
|
||||
)
|
||||
|
||||
self._log().debug("matches_counter: {}".format(matches_counter))
|
||||
|
||||
# Build output
|
||||
matching_details = {
|
||||
"iou_threshold": self._iou_thres,
|
||||
"table_bbox": table_bbox,
|
||||
"prediction_bboxes_page": table_bboxes_page, # Make easier the comparison with c++
|
||||
"prediction": prediction,
|
||||
"pdf_cells": pdf_cells,
|
||||
"page_height": iocr_page["height"],
|
||||
"page_width": iocr_page["width"],
|
||||
"table_cells": table_cells,
|
||||
"pdf_cells": pdf_cells,
|
||||
"matches": matches,
|
||||
}
|
||||
return matching_details
|
||||
|
||||
def match_cells_dummy(self, iocr_page, table_bbox, prediction):
|
||||
r"""
|
||||
Convert the tablemodel prediction into the Docling format
|
||||
DUMMY version doesn't do matching with text cells, but propagates predicted bboxes,
|
||||
respecting the rest of the format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
iocr_page : dict
|
||||
The original Docling provided table data
|
||||
prediction : dict
|
||||
The dictionary has the keys:
|
||||
"tag_seq": The sequence in indices from the WORDMAP
|
||||
"html_seq": The sequence as html tags
|
||||
"bboxes": The bounding boxes
|
||||
|
||||
Returns
|
||||
-------
|
||||
matching_details : dict
|
||||
Dictionary with all details about the mathings between the table and pdf cells
|
||||
"""
|
||||
pdf_cells = copy.deepcopy(iocr_page["tokens"])
|
||||
for word in pdf_cells:
|
||||
word["bbox"] = [
|
||||
word["bbox"]["l"],
|
||||
word["bbox"]["t"],
|
||||
word["bbox"]["r"],
|
||||
word["bbox"]["b"],
|
||||
]
|
||||
|
||||
table_bboxes = prediction["bboxes"]
|
||||
table_classes = prediction["classes"]
|
||||
# BBOXES transformed...
|
||||
table_bboxes_page = self._translate_bboxes(table_bbox, table_bboxes)
|
||||
|
||||
# Combine the table tags and bboxes into TableCells
|
||||
html_seq = prediction["html_seq"]
|
||||
otsl_seq = prediction["rs_seq"]
|
||||
|
||||
table_cells = self._build_table_cells(
|
||||
html_seq, otsl_seq, table_bboxes_page, table_classes
|
||||
)
|
||||
|
||||
# Build output
|
||||
matching_details = {
|
||||
"iou_threshold": self._iou_thres,
|
||||
"table_bbox": table_bbox,
|
||||
"prediction_bboxes_page": table_bboxes_page,
|
||||
"prediction": prediction,
|
||||
"pdf_cells": pdf_cells,
|
||||
"page_height": iocr_page["height"],
|
||||
"page_width": iocr_page["width"],
|
||||
"table_cells": table_cells,
|
||||
"pdf_cells": pdf_cells,
|
||||
"matches": {},
|
||||
}
|
||||
return matching_details
|
||||
|
||||
def _build_table_cells(self, html_seq, otsl_seq, bboxes, table_classes):
|
||||
r"""
|
||||
Combine the tags and bboxes of the table into unified TableCell objects.
|
||||
Each TableCell takes a row_id, column_id index based on the html structure provided by
|
||||
html_seq.
|
||||
It is assumed that the bboxes are in sync with the appearence of the closing </td>
|
||||
|
||||
Parameters
|
||||
----------
|
||||
html_seq : list
|
||||
List of html tags
|
||||
bboxes : list of lists of 4
|
||||
Bboxes for the table cells at the page origin
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of dict
|
||||
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
|
||||
"""
|
||||
table_html_structure = {
|
||||
"html": {"structure": {"tokens": html_seq}},
|
||||
"split": "predict",
|
||||
"filename": "memory",
|
||||
}
|
||||
|
||||
otsl_spans = {}
|
||||
|
||||
# r, o = otsl.html_to_otsl(table, writer, true, extra_debug, include_html)
|
||||
r, o = otsl.html_to_otsl(table_html_structure, None, False, False, True, False)
|
||||
if not r:
|
||||
ermsg = "ERR#: COULD NOT CONVERT TO RS THIS TABLE TO COMPUTE SPANS"
|
||||
print(ermsg)
|
||||
else:
|
||||
otsl_spans = o["otsl_spans"]
|
||||
|
||||
table_cells = []
|
||||
|
||||
# It is assumed that the bboxes appear in sync (at the same order) as the TDs
|
||||
cell_id = 0
|
||||
|
||||
row_id = -1
|
||||
column_id = -1
|
||||
in_header = False
|
||||
in_body = False
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
|
||||
mode = "OTSL"
|
||||
if mode == "HTML":
|
||||
for tag in html_seq:
|
||||
label = None
|
||||
if tag == "<thead>":
|
||||
in_header = True
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
elif tag == "</thead>":
|
||||
in_header = False
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
elif tag == "<tbody>":
|
||||
in_body = True
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
elif tag == "</tbody>":
|
||||
in_body = False
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
elif tag == "<td>" or tag == "<td":
|
||||
column_id += 1
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
if tag == "<td":
|
||||
multicol_tag = tag
|
||||
elif tag == "<tr>":
|
||||
row_id += 1
|
||||
column_id = -1
|
||||
multicol_tag = ""
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
elif "colspan" in tag:
|
||||
label = MULTI_COL
|
||||
multicol_tag += tag
|
||||
colspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
|
||||
elif "rowspan" in tag:
|
||||
label = MULTI_ROW
|
||||
multicol_tag += tag
|
||||
rowspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
|
||||
elif tag == "</td>": # Create a TableCell on each closing td
|
||||
if len(multicol_tag) > 0:
|
||||
multicol_tag += tag
|
||||
if in_header:
|
||||
if label is None:
|
||||
label = COL_HEADER
|
||||
elif label == MULTI_COL:
|
||||
label = MULTI_COL_HEADER
|
||||
elif label == MULTI_ROW:
|
||||
label = MULTI_ROW_HEADER
|
||||
if label is None and in_body:
|
||||
label = BODY
|
||||
|
||||
err_mismatch = "Mismatching bboxes with closing TDs {} < {}".format(
|
||||
cell_id, len(bboxes)
|
||||
)
|
||||
assert cell_id < len(bboxes), err_mismatch
|
||||
bbox = bboxes[cell_id]
|
||||
cell_class = table_classes[cell_id]
|
||||
|
||||
table_cell = {}
|
||||
table_cell["cell_id"] = cell_id
|
||||
table_cell["row_id"] = row_id
|
||||
table_cell["column_id"] = column_id
|
||||
table_cell["bbox"] = bbox
|
||||
table_cell["cell_class"] = cell_class
|
||||
table_cell["label"] = label
|
||||
table_cell["multicol_tag"] = multicol_tag
|
||||
if colspan_val > 0:
|
||||
table_cell["colspan_val"] = colspan_val
|
||||
column_id += (
|
||||
colspan_val - 1
|
||||
) # Shift column index to account for span
|
||||
if rowspan_val > 0:
|
||||
table_cell["rowspan_val"] = rowspan_val
|
||||
|
||||
table_cells.append(table_cell)
|
||||
cell_id += 1
|
||||
|
||||
if mode == "OTSL":
|
||||
row_id = 0
|
||||
column_id = 0
|
||||
multicol_tag = ""
|
||||
otsl_line = []
|
||||
cell_id_line = []
|
||||
|
||||
for tag in otsl_seq:
|
||||
otsl_line.append(tag)
|
||||
if tag == "nl":
|
||||
row_id += 1
|
||||
column_id = 0
|
||||
otsl_line = []
|
||||
cell_id_line = []
|
||||
if tag in ["fcel", "ecel", "xcel", "ched", "rhed", "srow"]:
|
||||
cell_id_line.append(cell_id)
|
||||
bbox = [0.0, 0.0, 0.0, 0.0]
|
||||
if cell_id < len(bboxes):
|
||||
bbox = bboxes[cell_id]
|
||||
|
||||
cell_class = 2
|
||||
if cell_id < len(table_classes):
|
||||
cell_class = table_classes[cell_id]
|
||||
label = tag
|
||||
|
||||
table_cell = {}
|
||||
table_cell["cell_id"] = cell_id
|
||||
table_cell["row_id"] = row_id
|
||||
table_cell["column_id"] = column_id
|
||||
table_cell["bbox"] = bbox
|
||||
table_cell["cell_class"] = cell_class
|
||||
table_cell["label"] = label
|
||||
table_cell["multicol_tag"] = multicol_tag
|
||||
|
||||
colspan_val = 0
|
||||
rowspan_val = 0
|
||||
|
||||
if cell_id in otsl_spans:
|
||||
colspan_val = otsl_spans[cell_id][0]
|
||||
rowspan_val = otsl_spans[cell_id][1]
|
||||
if colspan_val > 0:
|
||||
table_cell["colspan_val"] = colspan_val
|
||||
if rowspan_val > 0:
|
||||
table_cell["rowspan_val"] = rowspan_val
|
||||
|
||||
table_cells.append(table_cell)
|
||||
cell_id += 1
|
||||
if tag != "nl":
|
||||
column_id += 1
|
||||
|
||||
return table_cells
|
||||
|
||||
def _translate_bboxes(self, table_bbox, cell_bboxes):
|
||||
r"""
|
||||
Translate table cell bboxes to the lower-left corner of the page.
|
||||
|
||||
The cells of the table are given:
|
||||
- Origin at the top left corner
|
||||
- Point A: Top left corner
|
||||
- Point B: Low right corner
|
||||
- Coordinate values are normalized to the table width/height
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table_bbox : list of 4
|
||||
The whole table bbox page coordinates
|
||||
cell_bboxes : list of lists of 4
|
||||
The bboxes of the table cells
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of 4
|
||||
The translated bboxes of the table cells
|
||||
"""
|
||||
W = table_bbox[2] - table_bbox[0]
|
||||
H = table_bbox[3] - table_bbox[1]
|
||||
b = np.asarray(cell_bboxes)
|
||||
t_mask = np.asarray(
|
||||
[table_bbox[0], table_bbox[3], table_bbox[0], table_bbox[3]]
|
||||
)
|
||||
m = np.asarray([W, -H, W, -H])
|
||||
page_bboxes_y_flipped = t_mask + m * b
|
||||
page_bboxes = page_bboxes_y_flipped[:, [0, 3, 2, 1]] # Flip y1' with y2'
|
||||
page_bboxes_list = page_bboxes.tolist()
|
||||
|
||||
t_height = table_bbox[3]
|
||||
page_bboxes_list1 = []
|
||||
for page_bbox in page_bboxes_list:
|
||||
page_bbox1 = [
|
||||
page_bbox[0],
|
||||
t_height - page_bbox[3] + table_bbox[1],
|
||||
page_bbox[2],
|
||||
t_height - page_bbox[1] + table_bbox[1],
|
||||
]
|
||||
page_bboxes_list1.append(page_bbox1)
|
||||
return page_bboxes_list1
|
||||
|
||||
def _intersection_over_pdf_match(self, table_cells, pdf_cells):
|
||||
r"""
|
||||
Compute Intersection between table cells and pdf cells,
|
||||
match 1 pdf cell with highest intersection with only 1 table cell.
|
||||
|
||||
First compute and cache the areas for all involved bboxes.
|
||||
Then compute the pairwise intersections
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table_cells : list of dict
|
||||
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
|
||||
|
||||
pdf_cells : list of dict
|
||||
Each element of the list is a dictionary which should have the keys: "id", "bbox"
|
||||
Returns
|
||||
-------
|
||||
dictionary of lists of table_cells
|
||||
Return a dictionary which is indexed by the pdf_cell_id as key and the value is a list
|
||||
of the table_cells that fall inside that pdf cell
|
||||
int
|
||||
Number of total matches
|
||||
"""
|
||||
pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])
|
||||
pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
|
||||
pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
|
||||
)
|
||||
|
||||
# key: pdf_cell_id, value: list of TableCell that fall inside that pdf_cell
|
||||
matches = {}
|
||||
matches_counter = 0
|
||||
|
||||
# Compute Intersections and build matches
|
||||
for i, table_cell in enumerate(table_cells):
|
||||
table_cell_id = table_cell["cell_id"]
|
||||
t_bbox = table_cell["bbox"]
|
||||
|
||||
for j, pdf_cell in enumerate(pdf_cells):
|
||||
pdf_cell_id = pdf_cell["id"]
|
||||
p_bbox = pdf_cell["bbox"]
|
||||
|
||||
# Compute intersection
|
||||
i_bbox = find_intersection(t_bbox, p_bbox)
|
||||
if i_bbox is None:
|
||||
continue
|
||||
|
||||
# Compute IOU and filter on threshold
|
||||
i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
|
||||
iopdf = 0
|
||||
if float(pdf_bboxes_areas[j]) > 0:
|
||||
iopdf = i_bbox_area / float(pdf_bboxes_areas[j])
|
||||
|
||||
if iopdf > 0:
|
||||
match = {"table_cell_id": table_cell_id, "iopdf": iopdf}
|
||||
if pdf_cell_id not in matches:
|
||||
matches[pdf_cell_id] = [match]
|
||||
matches_counter += 1
|
||||
else:
|
||||
# Check if the same match was not already counted
|
||||
if match not in matches[pdf_cell_id]:
|
||||
matches[pdf_cell_id].append(match)
|
||||
matches_counter += 1
|
||||
return matches, matches_counter
|
||||
|
||||
def _iou_match(self, table_cells, pdf_cells):
|
||||
r"""
|
||||
Use Intersection over Union to decide the matching between table cells and pdf cells
|
||||
|
||||
First compute and cache the areas for all involved bboxes.
|
||||
Then compute the pairwise intersections and IOUs and keep those pairs that exceed the IOU
|
||||
threshold
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table_cells : list of dict
|
||||
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
|
||||
|
||||
pdf_cells : list of dict
|
||||
Each element of the list is a dictionary which should have the keys: "id", "bbox"
|
||||
Returns
|
||||
-------
|
||||
dictionary of lists of table_cells
|
||||
Return a dictionary which is indexed by the pdf_cell_id as key and the value is a list
|
||||
of the table_cells that fall inside that pdf cell
|
||||
int
|
||||
Number of total matches
|
||||
"""
|
||||
table_bboxes = np.asarray([t["bbox"] for t in table_cells])
|
||||
pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])
|
||||
|
||||
# Cache the areas for table bboxes and pdf bboxes
|
||||
table_bboxes_areas = (table_bboxes[:, 2] - table_bboxes[:, 0]) * (
|
||||
table_bboxes[:, 3] - table_bboxes[:, 1]
|
||||
)
|
||||
|
||||
pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
|
||||
pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
|
||||
)
|
||||
|
||||
# key: pdf_cell_id, value: list of TableCell that fall inside that pdf_cell
|
||||
matches = {}
|
||||
matches_counter = 0
|
||||
|
||||
# Compute IOUs and build matches
|
||||
for i, table_cell in enumerate(table_cells):
|
||||
table_cell_id = table_cell["cell_id"]
|
||||
t_bbox = table_cell["bbox"]
|
||||
|
||||
for j, pdf_cell in enumerate(pdf_cells):
|
||||
pdf_cell_id = pdf_cell["id"]
|
||||
pdf_cell_text = pdf_cell["text"]
|
||||
p_bbox = pdf_cell["bbox"]
|
||||
|
||||
# Compute intersection
|
||||
i_bbox = find_intersection(t_bbox, p_bbox)
|
||||
if i_bbox is None:
|
||||
continue
|
||||
|
||||
# Compute IOU and filter on threshold
|
||||
i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
|
||||
iou = 0
|
||||
div_area = float(
|
||||
table_bboxes_areas[i] + pdf_bboxes_areas[j] - i_bbox_area
|
||||
)
|
||||
if div_area > 0:
|
||||
iou = i_bbox_area / div_area
|
||||
if iou < self._iou_thres:
|
||||
continue
|
||||
|
||||
if pdf_cell_id not in matches:
|
||||
matches[pdf_cell_id] = []
|
||||
|
||||
match = {
|
||||
"table_cell_id": table_cell_id,
|
||||
"iou": iou,
|
||||
"text": pdf_cell_text,
|
||||
}
|
||||
matches[pdf_cell_id].append(match)
|
||||
matches_counter += 1
|
||||
|
||||
return matches, matches_counter
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,396 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
from __future__ import division
|
||||
|
||||
import collections
|
||||
import numbers
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from docling_ibm_models.tableformer.data_management import functional as F
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
class Lambda(object):
|
||||
"""Apply a user-defined lambda as a transform.
|
||||
Attention: The multiprocessing used in dataloader of pytorch
|
||||
is not friendly with lambda function in Windows
|
||||
Args:
|
||||
lambd (function): Lambda/function to be used for transform.
|
||||
"""
|
||||
|
||||
def __init__(self, lambd):
|
||||
# assert isinstance(lambd, types.LambdaType)
|
||||
self.lambd = lambd
|
||||
# if 'Windows' in platform.system():
|
||||
# raise RuntimeError("Can't pickle lambda funciton in windows system")
|
||||
|
||||
def __call__(self, img):
|
||||
return self.lambd(img)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
class RandomTransforms(object):
|
||||
"""Base class for a list of transformations with randomness
|
||||
Args:
|
||||
transforms (list or tuple): list of transformations
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, (list, tuple))
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += " {0}".format(t)
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
|
||||
|
||||
class RandomChoice(RandomTransforms):
|
||||
"""Apply single transformation randomly picked from a list"""
|
||||
|
||||
def __call__(self, img, target):
|
||||
t = random.choice(self.transforms)
|
||||
return t(img, target)
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
def __init__(self, size, margin_crop):
|
||||
self.size = list(size)
|
||||
self.margin_crop = list(margin_crop)
|
||||
# margin_crop: w, h
|
||||
|
||||
def __call__(self, img, target):
|
||||
# img (w,h,ch)
|
||||
image_height, image_width = img.shape[0], img.shape[1]
|
||||
"""
|
||||
img (np.ndarray): Image to be cropped.
|
||||
x: Upper pixel coordinate.
|
||||
y: Left pixel coordinate.
|
||||
h: Height of the cropped image.
|
||||
w: Width of the cropped image.
|
||||
"""
|
||||
if image_width > 0 and image_height > 0:
|
||||
cropped_image = F.crop(
|
||||
img,
|
||||
self.margin_crop[1],
|
||||
self.margin_crop[0],
|
||||
image_height - (self.margin_crop[1] * 2),
|
||||
image_width - (self.margin_crop[0] * 2),
|
||||
)
|
||||
|
||||
target_ = target.copy()
|
||||
target_["boxes"][:, 0] = target_["boxes"][:, 0] - self.margin_crop[0]
|
||||
target_["boxes"][:, 1] = target_["boxes"][:, 1] - self.margin_crop[1]
|
||||
target_["boxes"][:, 2] = target_["boxes"][:, 2] - self.margin_crop[0]
|
||||
target_["boxes"][:, 3] = target_["boxes"][:, 3] - self.margin_crop[1]
|
||||
else:
|
||||
cropped_image = img
|
||||
return cropped_image, target_
|
||||
|
||||
|
||||
class RandomPad(object):
|
||||
def __init__(self, max_pad):
|
||||
self.max_pad = max_pad
|
||||
|
||||
def __call__(self, img, target):
|
||||
pad_x = random.randint(0, self.max_pad)
|
||||
pad_y = random.randint(0, self.max_pad)
|
||||
pad_x1 = random.randint(0, self.max_pad)
|
||||
pad_y1 = random.randint(0, self.max_pad)
|
||||
img = img.copy()
|
||||
padded_image = F.pad(img, (pad_x, pad_y, pad_x1, pad_y1), fill=(255, 255, 255))
|
||||
target_ = target.copy()
|
||||
if target["boxes"] is not None:
|
||||
target_["boxes"][:, 0] = target_["boxes"][:, 0] + pad_x
|
||||
target_["boxes"][:, 1] = target_["boxes"][:, 1] + pad_y
|
||||
target_["boxes"][:, 2] = target_["boxes"][:, 2] + pad_x
|
||||
target_["boxes"][:, 3] = target_["boxes"][:, 3] + pad_y
|
||||
return padded_image, target_
|
||||
|
||||
|
||||
class ColorJitter(object):
|
||||
"""Randomly change the brightness, contrast and saturation of an image.
|
||||
Args:
|
||||
brightness (float): How much to jitter brightness. brightness_factor
|
||||
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
||||
contrast (float): How much to jitter contrast. contrast_factor
|
||||
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
||||
saturation (float): How much to jitter saturation. saturation_factor
|
||||
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
||||
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
||||
[-hue, hue]. Should be >=0 and <= 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
||||
|
||||
assert isinstance(brightness, float) or (
|
||||
isinstance(brightness, collections.Iterable) and len(brightness) == 2
|
||||
)
|
||||
assert isinstance(contrast, float) or (
|
||||
isinstance(contrast, collections.Iterable) and len(contrast) == 2
|
||||
)
|
||||
assert isinstance(saturation, float) or (
|
||||
isinstance(saturation, collections.Iterable) and len(saturation) == 2
|
||||
)
|
||||
assert isinstance(hue, float) or (
|
||||
isinstance(hue, collections.Iterable) and len(hue) == 2
|
||||
)
|
||||
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
|
||||
@staticmethod
|
||||
def get_params(brightness, contrast, saturation, hue):
|
||||
"""Get a randomized transform to be applied on image.
|
||||
Arguments are same as that of __init__.
|
||||
Returns:
|
||||
Transform which randomly adjusts brightness, contrast and
|
||||
saturation in a random order.
|
||||
"""
|
||||
transforms = []
|
||||
|
||||
if isinstance(brightness, numbers.Number):
|
||||
|
||||
if brightness > 0:
|
||||
brightness_factor = random.uniform(
|
||||
max(0, 1 - brightness), 1 + brightness
|
||||
)
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
|
||||
)
|
||||
|
||||
if contrast > 0:
|
||||
contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
|
||||
)
|
||||
|
||||
if saturation > 0:
|
||||
saturation_factor = random.uniform(
|
||||
max(0, 1 - saturation), 1 + saturation
|
||||
)
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
|
||||
)
|
||||
|
||||
if hue > 0:
|
||||
hue_factor = random.uniform(-hue, hue)
|
||||
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
|
||||
|
||||
else:
|
||||
|
||||
if brightness[0] > 0 and brightness[1] > 0:
|
||||
|
||||
brightness_factor = random.uniform(brightness[0], brightness[1])
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
|
||||
)
|
||||
|
||||
if contrast[0] > 0 and contrast[1] > 0:
|
||||
|
||||
contrast_factor = random.uniform(contrast[0], contrast[1])
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
|
||||
)
|
||||
|
||||
if saturation[0] > 0 and saturation[1] > 0:
|
||||
|
||||
saturation_factor = random.uniform(saturation[0], saturation[1])
|
||||
transforms.append(
|
||||
Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
|
||||
)
|
||||
|
||||
if hue[0] > 0 and hue[1] > 0:
|
||||
hue_factor = random.uniform(hue[0], hue[1])
|
||||
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
|
||||
|
||||
random.shuffle(transforms)
|
||||
transform = ComposeSingle(transforms)
|
||||
|
||||
return transform
|
||||
|
||||
def __call__(self, img, target):
|
||||
"""
|
||||
Args:
|
||||
img (np.ndarray): Input image.
|
||||
Returns:
|
||||
np.ndarray: Color jittered image.
|
||||
"""
|
||||
transform = self.get_params(
|
||||
self.brightness, self.contrast, self.saturation, self.hue
|
||||
)
|
||||
return transform(img), target
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
format_string += "brightness={0}".format(self.brightness)
|
||||
format_string += ", contrast={0}".format(self.contrast)
|
||||
format_string += ", saturation={0}".format(self.saturation)
|
||||
format_string += ", hue={0})".format(self.hue)
|
||||
return format_string
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
"""Normalize a tensor image with mean and standard deviation.
|
||||
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
|
||||
will normalize each channel of the input ``torch.*Tensor`` i.e.
|
||||
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
|
||||
Args:
|
||||
mean (sequence): Sequence of means for each channel.
|
||||
std (sequence): Sequence of standard deviations for each channel.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, tensor, target=None):
|
||||
"""
|
||||
Args:
|
||||
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
||||
Returns:
|
||||
Tensor: Normalized Tensor image.
|
||||
"""
|
||||
return F.normalize(tensor, self.mean, self.std), target
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "(mean={0}, std={1})".format(
|
||||
self.mean, self.std
|
||||
)
|
||||
|
||||
|
||||
class NoTransformation(object):
|
||||
"""Do Nothing"""
|
||||
|
||||
def __call__(self, img, target):
|
||||
return img, target
|
||||
|
||||
|
||||
class Compose(object):
|
||||
"""Composes several transforms together.
|
||||
Args:
|
||||
transforms (list of ``Transform`` objects): list of transforms to compose.
|
||||
Example:
|
||||
>>> transforms.Compose([
|
||||
>>> transforms.CenterCrop(10),
|
||||
>>> transforms.ToTensor(),
|
||||
>>> ])
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img, target):
|
||||
for t in self.transforms:
|
||||
img, target = t(img, target)
|
||||
return img, target
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += " {0}".format(t)
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
|
||||
|
||||
class ComposeSingle(object):
|
||||
"""Composes several transforms together.
|
||||
Args:
|
||||
transforms (list of ``Transform`` objects): list of transforms to compose.
|
||||
Example:
|
||||
>>> transforms.Compose([
|
||||
>>> transforms.CenterCrop(10),
|
||||
>>> transforms.ToTensor(),
|
||||
>>> ])
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img):
|
||||
for t in self.transforms:
|
||||
img = t(img)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += " {0}".format(t)
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize the input PIL Image to the given size.
|
||||
Args:
|
||||
size (sequence or int): Desired output size. If size is a sequence like
|
||||
(h, w), output size will be matched to this. If size is an int,
|
||||
smaller edge of the image will be matched to this number.
|
||||
i.e, if height > width, then image will be rescaled to
|
||||
(size * height / width, size)
|
||||
interpolation (int, optional): Desired interpolation. Default is
|
||||
``BILINEAR``
|
||||
"""
|
||||
|
||||
def __init__(self, size, interpolation="BILINEAR"):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img, target=None):
|
||||
"""
|
||||
Args:
|
||||
img (np.ndarray): Image to be scaled.
|
||||
Returns:
|
||||
np.ndarray: Rescaled image.
|
||||
"""
|
||||
# Resize bboxes (in pixels)
|
||||
x_scale = 0
|
||||
y_scale = 0
|
||||
|
||||
if img.shape[1] > 0:
|
||||
x_scale = self.size[0] / img.shape[1]
|
||||
if img.shape[0] > 0:
|
||||
y_scale = self.size[1] / img.shape[0]
|
||||
|
||||
# loop over bboxes
|
||||
if target is not None:
|
||||
if target["boxes"] is not None:
|
||||
target_ = target.copy()
|
||||
target_["boxes"][:, 0] = x_scale * target_["boxes"][:, 0]
|
||||
target_["boxes"][:, 1] = y_scale * target_["boxes"][:, 1]
|
||||
target_["boxes"][:, 2] = x_scale * target_["boxes"][:, 2]
|
||||
target_["boxes"][:, 3] = y_scale * target_["boxes"][:, 3]
|
||||
return F.resize(img, self.size, self.interpolation), target
|
||||
|
||||
def __repr__(self):
|
||||
interpolate_str = self.interpolation
|
||||
return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
|
||||
self.size, interpolate_str
|
||||
)
|
||||
@@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
r"""
|
||||
BaseModel provides some common functionality for all models:
|
||||
- Saves checkpoint files for each epoch
|
||||
- Loads the model from the best available checkpoint
|
||||
- Save repository branch and commit
|
||||
"""
|
||||
|
||||
def __init__(self, config, init_data, device):
|
||||
r"""
|
||||
Inputs:
|
||||
config: The configuration file
|
||||
init_data: Dictionary with initialization data. This dictionary can be used to pass any
|
||||
kind of initialization data for the models
|
||||
device: The device used to move the tensors of the model
|
||||
"""
|
||||
super(BaseModel, self).__init__()
|
||||
|
||||
# Set config and device
|
||||
self._config = config
|
||||
self._init_data = init_data
|
||||
|
||||
self._device = device
|
||||
|
||||
self._save_dir = config["model"]["save_dir"]
|
||||
self._load_checkpoint = None
|
||||
if "load_checkpoint" in config["model"]:
|
||||
self._load_checkpoint = config["model"]["load_checkpoint"]
|
||||
|
||||
self._branch_name = "dev/next"
|
||||
self._commit_sha = "1"
|
||||
|
||||
# Keep a dictionary with the starting times per epoch.
|
||||
# NOTICE: Epochs start from 0
|
||||
self._epoch_start_ts = {0: time.time()}
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, img, max_steps, beam_size, return_attention):
|
||||
pass
|
||||
|
||||
def count_parameters(self):
|
||||
r"""Counts the number of trainable parameters of this model
|
||||
|
||||
Output:
|
||||
num_parameters: number of trainable parameters
|
||||
"""
|
||||
num_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
return num_parameters
|
||||
|
||||
def get_code_version(self):
|
||||
r"""Gets the source control version of this model code
|
||||
|
||||
Returns
|
||||
-------
|
||||
branch_name : str
|
||||
The name of the Git branch of this model code
|
||||
commit_sha : str
|
||||
The unique identifier of the Git commit of this model code
|
||||
"""
|
||||
|
||||
return self._branch_name, self._commit_sha
|
||||
|
||||
def get_save_directory(self):
|
||||
r"""
|
||||
Return the save directory
|
||||
"""
|
||||
return self._save_dir
|
||||
|
||||
def is_saved(self):
|
||||
r"""
|
||||
This method returns True if both conditions are met:
|
||||
1. There is a checkpoint file for the model.
|
||||
2. The checkpoint file corresponds to the last training epoch set in the configuration file.
|
||||
"""
|
||||
# Get the saved_model
|
||||
saved_model, _ = self._load_best_checkpoint()
|
||||
|
||||
if saved_model is None:
|
||||
return False
|
||||
|
||||
epochs = self._config["train"]["epochs"]
|
||||
self._log().debug(
|
||||
"Best epoch in saved model: {}; Number of epochs in config: {}".format(
|
||||
saved_model["epoch"], epochs
|
||||
)
|
||||
)
|
||||
if epochs == saved_model["epoch"] + 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def save(self, epoch=None, optimizers=None, losses=None, model_parameters=None):
|
||||
r"""
|
||||
Save the model data to the disk as a pickle file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epoch: Training epoch
|
||||
optimizers: Dictionary with the optimizers. The key specifies what the optimizer is
|
||||
used for. The 'state_dict' of each optimizer will be saved in the
|
||||
checkpoint file.
|
||||
losses: Dictionary with the losses. The key specifies what the loss is used for. Each
|
||||
value is a list
|
||||
model_parameters: Dictionary with model specific parameters that we need to save in the
|
||||
checkpoint file.
|
||||
Returns
|
||||
-------
|
||||
True if success, False otherwise
|
||||
"""
|
||||
# Get the checkpoint_filename
|
||||
c_filename = self._build_checkpoint_filename(epoch)
|
||||
self._log().debug("Trying to save checkpoint file: {}".format(c_filename))
|
||||
|
||||
# Prepare a dictionary with all data we want to save
|
||||
optimizers_state_dict = None
|
||||
if optimizers is not None:
|
||||
optimizers_state_dict = {k: v.state_dict() for k, v in optimizers.items()}
|
||||
|
||||
model_data = {
|
||||
"model_state_dict": self.state_dict(),
|
||||
"epoch": epoch,
|
||||
"optimizers": optimizers_state_dict,
|
||||
"losses": losses,
|
||||
"model_parameters": model_parameters,
|
||||
}
|
||||
|
||||
# Add the processing time per epoch
|
||||
now = time.time()
|
||||
self._epoch_start_ts[epoch + 1] = now
|
||||
if epoch in self._epoch_start_ts:
|
||||
dt = now - self._epoch_start_ts[epoch]
|
||||
model_data["epoch_start_ts"] = self._epoch_start_ts[epoch]
|
||||
model_data["epoch_dt"] = dt
|
||||
|
||||
# Create the save directory
|
||||
Path(self._save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save the model
|
||||
torch.save(model_data, c_filename)
|
||||
|
||||
# Return true if file is present, otherwise false
|
||||
if not os.path.isfile(c_filename):
|
||||
self._log().error("Cannot find the file to save: " + c_filename)
|
||||
return False
|
||||
|
||||
# store code branch name and commit
|
||||
version_file = os.path.join(self._save_dir, "_version")
|
||||
with open(version_file, "w") as text_file:
|
||||
print("Model is using code [commit:branch]", file=text_file)
|
||||
print("{}:{}".format(self._commit_sha, self._branch_name), file=text_file)
|
||||
|
||||
return True
|
||||
|
||||
def load(self, optimizers=None):
|
||||
r"""
|
||||
Load the model data from the disk.
|
||||
The method will iterate over all *.check files and try to load the one from the highest
|
||||
epoch.
|
||||
|
||||
Input:
|
||||
-optimizers: Dictionary with optimizers. If it is not null the keys will be used to
|
||||
associate the corresponding state_dicts from the checkpoint file and update
|
||||
the internal states of the provided optimizers.
|
||||
|
||||
Output:
|
||||
- Success: True/ False
|
||||
- epoch: Loaded epoch or -1 if there are no checkpoint files
|
||||
- optimizers: Dictionary with loaded optimizers or empty dictionary of there is no
|
||||
checkpoint file
|
||||
- losses: Dictionary with loaded losses or empty dictionary of there is no checkpoint
|
||||
file
|
||||
- model_parameters: Dictionary with the model parameters or empty dictionary if there
|
||||
are no checkpoint files
|
||||
"""
|
||||
# Get the saved_model
|
||||
saved_model, _ = self._load_best_checkpoint()
|
||||
|
||||
# Restore the model
|
||||
if saved_model is None:
|
||||
self._log().debug("No saved model checkpoint found")
|
||||
return False, -1, optimizers, {}, {}
|
||||
|
||||
self._log().debug("Loading model from checkpoint file")
|
||||
self.load_state_dict(saved_model["model_state_dict"])
|
||||
|
||||
epoch = 0
|
||||
if "epoch" in saved_model:
|
||||
epoch = saved_model["epoch"]
|
||||
losses = {}
|
||||
if "losses" in saved_model:
|
||||
losses = saved_model["losses"]
|
||||
model_parameters = saved_model["model_parameters"]
|
||||
|
||||
if optimizers is not None:
|
||||
for key, optimizer_state_dict in saved_model["optimizers"].items():
|
||||
optimizers[key].load_state_dict(optimizer_state_dict)
|
||||
|
||||
# Reset the start_ts of the next epoch
|
||||
self._epoch_start_ts[epoch + 1] = time.time()
|
||||
|
||||
return True, epoch, optimizers, losses, model_parameters
|
||||
|
||||
def _load_best_checkpoint(self):
|
||||
r"""
|
||||
If a "load_checkpoint" file has been provided, load this one.
|
||||
Otherwise use the "save_dir" and load the one with the most advanced epoch
|
||||
|
||||
Returns
|
||||
-------
|
||||
saved_model : dictionary
|
||||
Checkpoint file contents generated by torch.load, or None
|
||||
checkpoint_file : string
|
||||
Filename of the loaded checkpoint, or None
|
||||
"""
|
||||
checkpoint_files = []
|
||||
# If a "load_checkpoint" file is provided, try to load it
|
||||
if self._load_checkpoint is not None:
|
||||
if not os.path.exists(self._load_checkpoint):
|
||||
self._log().error(
|
||||
"Cannot load the checkpoint: {}".format(self._load_checkpoint)
|
||||
)
|
||||
return None, None
|
||||
checkpoint_files.append(self._load_checkpoint)
|
||||
else:
|
||||
# Iterate over all check files from the directory by reverse alphabetical order
|
||||
# This will get the biggest epoch first
|
||||
checkpoint_files = glob.glob(os.path.join(self._save_dir, "*.check"))
|
||||
checkpoint_files.sort(reverse=True)
|
||||
|
||||
for checkpoint_file in checkpoint_files:
|
||||
try:
|
||||
# Try to load the file
|
||||
self._log().info(
|
||||
"Loading model checkpoint file: {}".format(checkpoint_file)
|
||||
)
|
||||
saved_model = torch.load(checkpoint_file, map_location=self._device)
|
||||
return saved_model, checkpoint_file
|
||||
except RuntimeError:
|
||||
self._log().error("Cannot load file: {}".format(checkpoint_file))
|
||||
|
||||
return None, None
|
||||
|
||||
def _build_checkpoint_filename(self, epoch):
|
||||
r"""
|
||||
Construct the full path for the filename of this checkpoint
|
||||
"""
|
||||
dataset_name = self._config["dataset"]["name"]
|
||||
model_type = self._config["model"]["type"]
|
||||
model_name = self._config["model"]["name"]
|
||||
filename = "{}_{}_{}_{:03}.check".format(
|
||||
model_type, model_name, dataset_name, epoch
|
||||
)
|
||||
c_filename = os.path.join(self._save_dir, filename)
|
||||
|
||||
return c_filename
|
||||
@@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
import docling_ibm_models.tableformer.utils.utils as u
|
||||
|
||||
# from scipy.optimize import linear_sum_assignment
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
|
||||
|
||||
class CellAttention(nn.Module):
|
||||
"""
|
||||
Attention Network.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim):
|
||||
"""
|
||||
:param encoder_dim: feature size of encoded images
|
||||
:param tag_decoder_dim: size of tag decoder's RNN
|
||||
:param language_dim: size of language model's RNN
|
||||
:param attention_dim: size of the attention network
|
||||
"""
|
||||
super(CellAttention, self).__init__()
|
||||
# linear layer to transform encoded image
|
||||
self._encoder_att = nn.Linear(encoder_dim, attention_dim)
|
||||
# linear layer to transform tag decoder output
|
||||
self._tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim)
|
||||
# linear layer to transform language models output
|
||||
self._language_att = nn.Linear(language_dim, attention_dim)
|
||||
# linear layer to calculate values to be softmax-ed
|
||||
self._full_att = nn.Linear(attention_dim, 1)
|
||||
self._relu = nn.ReLU()
|
||||
self._softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def forward(self, encoder_out, decoder_hidden, language_out):
|
||||
"""
|
||||
Forward propagation.
|
||||
:param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim)
|
||||
:param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells,
|
||||
tag_decoder_dim)]
|
||||
:param language_out: language model output, a tensor of dimension (num_cells,
|
||||
language_dim)
|
||||
:return: attention weighted encoding, weights
|
||||
"""
|
||||
att1 = self._encoder_att(encoder_out) # (1, num_pixels, attention_dim)
|
||||
att2 = self._tag_decoder_att(decoder_hidden) # (num_cells, tag_decoder_dim)
|
||||
att3 = self._language_att(language_out) # (num_cells, attention_dim)
|
||||
att = self._full_att(
|
||||
self._relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1))
|
||||
).squeeze(2)
|
||||
alpha = self._softmax(att) # (num_cells, num_pixels)
|
||||
# (num_cells, encoder_dim)
|
||||
attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
|
||||
return attention_weighted_encoding, alpha
|
||||
|
||||
|
||||
class BBoxDecoder(nn.Module):
|
||||
"""
|
||||
CellDecoder generates cell content
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
attention_dim,
|
||||
embed_dim,
|
||||
tag_decoder_dim,
|
||||
decoder_dim,
|
||||
num_classes,
|
||||
encoder_dim=512,
|
||||
dropout=0.5,
|
||||
cnn_layer_stride=1,
|
||||
):
|
||||
"""
|
||||
:param attention_dim: size of attention network
|
||||
:param embed_dim: embedding size
|
||||
:param tag_decoder_dim: size of tag decoder's RNN
|
||||
:param decoder_dim: size of decoder's RNN
|
||||
:param vocab_size: size of vocabulary
|
||||
:param encoder_dim: feature size of encoded images
|
||||
:param dropout: dropout
|
||||
:param mini_batch_size: batch size of cells to reduce GPU memory usage
|
||||
"""
|
||||
super(BBoxDecoder, self).__init__()
|
||||
self._device = device
|
||||
self._encoder_dim = encoder_dim
|
||||
self._attention_dim = attention_dim
|
||||
self._embed_dim = embed_dim
|
||||
self._decoder_dim = decoder_dim
|
||||
self._dropout = dropout
|
||||
self._num_classes = num_classes
|
||||
|
||||
if cnn_layer_stride is not None:
|
||||
self._input_filter = u.resnet_block(stride=cnn_layer_stride)
|
||||
# attention network
|
||||
self._attention = CellAttention(
|
||||
encoder_dim, tag_decoder_dim, decoder_dim, attention_dim
|
||||
)
|
||||
# decoder LSTMCell
|
||||
self._init_h = nn.Linear(encoder_dim, decoder_dim)
|
||||
|
||||
# linear layer to create a sigmoid-activated gate
|
||||
self._f_beta = nn.Linear(decoder_dim, encoder_dim)
|
||||
self._sigmoid = nn.Sigmoid()
|
||||
self._dropout = nn.Dropout(p=self._dropout)
|
||||
self._class_embed = nn.Linear(512, self._num_classes + 1)
|
||||
self._bbox_embed = u.MLP(512, 256, 4, 3)
|
||||
|
||||
def _init_hidden_state(self, encoder_out, batch_size):
|
||||
mean_encoder_out = encoder_out.mean(dim=1)
|
||||
h = self._init_h(mean_encoder_out).expand(batch_size, -1)
|
||||
return h
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def inference(self, encoder_out, tag_H):
|
||||
"""
|
||||
Inference on test images with beam search
|
||||
"""
|
||||
if hasattr(self, "_input_filter"):
|
||||
encoder_out = self._input_filter(encoder_out.permute(0, 3, 1, 2)).permute(
|
||||
0, 2, 3, 1
|
||||
)
|
||||
|
||||
encoder_dim = encoder_out.size(3)
|
||||
|
||||
# Flatten encoding (1, num_pixels, encoder_dim)
|
||||
encoder_out = encoder_out.view(1, -1, encoder_dim)
|
||||
|
||||
num_cells = len(tag_H)
|
||||
predictions_bboxes = []
|
||||
predictions_classes = []
|
||||
|
||||
for c_id in range(num_cells):
|
||||
# Start decoding
|
||||
h = self._init_hidden_state(encoder_out, 1)
|
||||
cell_tag_H = tag_H[c_id]
|
||||
awe, _ = self._attention(encoder_out, cell_tag_H, h)
|
||||
gate = self._sigmoid(self._f_beta(h))
|
||||
awe = gate * awe
|
||||
h = awe * h
|
||||
|
||||
predictions_bboxes.append(self._bbox_embed(h).sigmoid())
|
||||
predictions_classes.append(self._class_embed(h))
|
||||
if len(predictions_bboxes) > 0:
|
||||
predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes])
|
||||
if len(predictions_classes) > 0:
|
||||
predictions_classes = torch.stack([x[0] for x in predictions_classes])
|
||||
|
||||
return predictions_classes, predictions_bboxes
|
||||
@@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
class Encoder04(nn.Module):
|
||||
"""
|
||||
Encoder based on resnet-18
|
||||
"""
|
||||
|
||||
def __init__(self, enc_image_size, enc_dim=512):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
enc_image_size : int
|
||||
Assuming that the encoded image is a square, this is the length of the image side
|
||||
"""
|
||||
|
||||
super(Encoder04, self).__init__()
|
||||
self.enc_image_size = enc_image_size
|
||||
self._encoder_dim = enc_dim
|
||||
|
||||
resnet = torchvision.models.resnet18(pretrained=False)
|
||||
modules = list(resnet.children())[:-3]
|
||||
|
||||
self._resnet = nn.Sequential(*modules)
|
||||
self._adaptive_pool = nn.AdaptiveAvgPool2d(
|
||||
(self.enc_image_size, self.enc_image_size)
|
||||
)
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def get_encoder_dim(self):
|
||||
return self._encoder_dim
|
||||
|
||||
def forward(self, images):
|
||||
"""
|
||||
Forward propagation
|
||||
The encoder_dim 512 is decided by the structure of the image network (modified resnet-19)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : tensor (batch_size, image_channels, resized_image, resized_image)
|
||||
images input
|
||||
|
||||
Returns
|
||||
-------
|
||||
tensor : (batch_size, enc_image_size, enc_image_size, 256)
|
||||
encoded images
|
||||
"""
|
||||
out = self._resnet(images) # (batch_size, 256, 28, 28)
|
||||
self._log().debug("forward: resnet out: {}".format(out.size()))
|
||||
out = self._adaptive_pool(out)
|
||||
out = out.permute(
|
||||
0, 2, 3, 1
|
||||
) # (batch_size, enc_image_size, enc_image_size, 256)
|
||||
|
||||
self._log().debug("enc forward: final out: {}".format(out.size()))
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,324 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
|
||||
from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import BBoxDecoder
|
||||
from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04
|
||||
from docling_ibm_models.tableformer.models.table04_rs.transformer_rs import (
|
||||
Tag_Transformer,
|
||||
)
|
||||
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
|
||||
|
||||
LOG_LEVEL = logging.WARN
|
||||
# LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
class TableModel04_rs(BaseModel, nn.Module):
|
||||
r"""
|
||||
TableNet04Model encoder, dual-decoder model with OTSL+ support
|
||||
"""
|
||||
|
||||
def __init__(self, config, init_data, purpose, device):
|
||||
super(TableModel04_rs, self).__init__(config, init_data, device)
|
||||
|
||||
self._prof = config["predict"].get("profiling", False)
|
||||
self._device = device
|
||||
# Extract the word_map from the init_data
|
||||
word_map = init_data["word_map"]
|
||||
|
||||
# Encoder
|
||||
self._enc_image_size = config["model"]["enc_image_size"]
|
||||
self._encoder_dim = config["model"]["hidden_dim"]
|
||||
self._encoder = Encoder04(self._enc_image_size, self._encoder_dim).to(device)
|
||||
|
||||
tag_vocab_size = len(word_map["word_map_tag"])
|
||||
|
||||
td_encode = []
|
||||
for t in ["ecel", "fcel", "ched", "rhed", "srow"]:
|
||||
if t in word_map["word_map_tag"]:
|
||||
td_encode.append(word_map["word_map_tag"][t])
|
||||
self._log().debug("td_encode length: {}".format(len(td_encode)))
|
||||
self._log().debug("td_encode: {}".format(td_encode))
|
||||
|
||||
self._tag_attention_dim = config["model"]["tag_attention_dim"]
|
||||
self._tag_embed_dim = config["model"]["tag_embed_dim"]
|
||||
self._tag_decoder_dim = config["model"]["tag_decoder_dim"]
|
||||
self._decoder_dim = config["model"]["hidden_dim"]
|
||||
self._dropout = config["model"]["dropout"]
|
||||
|
||||
self._bbox = config["train"]["bbox"]
|
||||
self._bbox_attention_dim = config["model"]["bbox_attention_dim"]
|
||||
self._bbox_embed_dim = config["model"]["bbox_embed_dim"]
|
||||
self._bbox_decoder_dim = config["model"]["hidden_dim"]
|
||||
|
||||
self._enc_layers = config["model"]["enc_layers"]
|
||||
self._dec_layers = config["model"]["dec_layers"]
|
||||
self._n_heads = config["model"]["nheads"]
|
||||
|
||||
self._num_classes = config["model"]["bbox_classes"]
|
||||
self._enc_image_size = config["model"]["enc_image_size"]
|
||||
|
||||
self._max_pred_len = config["predict"]["max_steps"]
|
||||
|
||||
self._tag_transformer = Tag_Transformer(
|
||||
device,
|
||||
tag_vocab_size,
|
||||
td_encode,
|
||||
self._decoder_dim,
|
||||
self._enc_layers,
|
||||
self._dec_layers,
|
||||
self._enc_image_size,
|
||||
n_heads=self._n_heads,
|
||||
).to(device)
|
||||
|
||||
self._bbox_decoder = BBoxDecoder(
|
||||
device,
|
||||
self._bbox_attention_dim,
|
||||
self._bbox_embed_dim,
|
||||
self._tag_decoder_dim,
|
||||
self._bbox_decoder_dim,
|
||||
self._num_classes,
|
||||
self._encoder_dim,
|
||||
self._dropout,
|
||||
).to(device)
|
||||
|
||||
def _log(self):
|
||||
# Setup a custom logger
|
||||
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
|
||||
|
||||
def mergebboxes(self, bbox1, bbox2):
|
||||
new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
|
||||
new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)
|
||||
|
||||
new_left = bbox1[0] - bbox1[2] / 2
|
||||
new_top = min((bbox2[1] - bbox2[3] / 2), (bbox1[1] - bbox1[3] / 2))
|
||||
|
||||
new_cx = new_left + new_w / 2
|
||||
new_cy = new_top + new_h / 2
|
||||
|
||||
bboxm = torch.tensor([new_cx, new_cy, new_w, new_h])
|
||||
return bboxm
|
||||
|
||||
def predict(self, imgs, max_steps, k, return_attention=False):
|
||||
r"""
|
||||
Inference.
|
||||
The input image must be preprocessed and transformed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img : tensor FloatTensor - torch.Size([1, 3, 448, 448])
|
||||
Input image for the inference
|
||||
|
||||
Returns
|
||||
-------
|
||||
seq : list
|
||||
Predictions for the tags as indices over the word_map
|
||||
outputs_class : tensor(x, 3)
|
||||
Classes of predicted bboxes. x is the number of bboxes. There are 3 bbox classes
|
||||
|
||||
outputs_coord : tensor(x, 4)
|
||||
Coords of predicted bboxes. x is the number of bboxes. Each bbox is in [cxcywh] format
|
||||
"""
|
||||
AggProfiler().begin("predict_total", self._prof)
|
||||
|
||||
# Invoke encoder
|
||||
self._tag_transformer.eval()
|
||||
enc_out = self._encoder(imgs)
|
||||
AggProfiler().end("model_encoder", self._prof)
|
||||
|
||||
word_map = self._init_data["word_map"]["word_map_tag"]
|
||||
n_heads = self._tag_transformer._n_heads
|
||||
# [1, 28, 28, 512]
|
||||
encoder_out = self._tag_transformer._input_filter(
|
||||
enc_out.permute(0, 3, 1, 2)
|
||||
).permute(0, 2, 3, 1)
|
||||
|
||||
batch_size = encoder_out.size(0)
|
||||
encoder_dim = encoder_out.size(-1)
|
||||
enc_inputs = encoder_out.view(batch_size, -1, encoder_dim).to(self._device)
|
||||
enc_inputs = enc_inputs.permute(1, 0, 2)
|
||||
positions = enc_inputs.shape[0]
|
||||
|
||||
encoder_mask = torch.zeros(
|
||||
(batch_size * n_heads, positions, positions), device=self._device
|
||||
) == torch.ones(
|
||||
(batch_size * n_heads, positions, positions), device=self._device
|
||||
)
|
||||
|
||||
# Invoking tag transformer encoder before the loop to save time
|
||||
AggProfiler().begin("model_tag_transformer_encoder", self._prof)
|
||||
encoder_out = self._tag_transformer._encoder(enc_inputs, mask=encoder_mask)
|
||||
AggProfiler().end("model_tag_transformer_encoder", self._prof)
|
||||
|
||||
decoded_tags = (
|
||||
torch.LongTensor([word_map["<start>"]]).to(self._device).unsqueeze(1)
|
||||
)
|
||||
output_tags = []
|
||||
cache = None
|
||||
tag_H_buf = []
|
||||
|
||||
skip_next_tag = True
|
||||
prev_tag_ucel = False
|
||||
line_num = 0
|
||||
|
||||
# Populate bboxes_to_merge, indexes of first lcel, and last cell in a span
|
||||
first_lcel = True
|
||||
bboxes_to_merge = {}
|
||||
cur_bbox_ind = -1
|
||||
bbox_ind = 0
|
||||
|
||||
# i = 0
|
||||
while len(output_tags) < self._max_pred_len:
|
||||
decoded_embedding = self._tag_transformer._embedding(decoded_tags)
|
||||
decoded_embedding = self._tag_transformer._positional_encoding(
|
||||
decoded_embedding
|
||||
)
|
||||
AggProfiler().begin("model_tag_transformer_decoder", self._prof)
|
||||
decoded, cache = self._tag_transformer._decoder(
|
||||
decoded_embedding,
|
||||
encoder_out,
|
||||
cache,
|
||||
memory_key_padding_mask=encoder_mask,
|
||||
)
|
||||
AggProfiler().end("model_tag_transformer_decoder", self._prof)
|
||||
# Grab last feature to produce token
|
||||
AggProfiler().begin("model_tag_transformer_fc", self._prof)
|
||||
logits = self._tag_transformer._fc(decoded[-1, :, :]) # 1, vocab_size
|
||||
AggProfiler().end("model_tag_transformer_fc", self._prof)
|
||||
new_tag = logits.argmax(1).item()
|
||||
|
||||
# STRUCTURE ERROR CORRECTION
|
||||
# Correction for first line xcel...
|
||||
if line_num == 0:
|
||||
if new_tag == word_map["xcel"]:
|
||||
new_tag = word_map["lcel"]
|
||||
|
||||
# Correction for ucel, lcel sequence...
|
||||
if prev_tag_ucel:
|
||||
if new_tag == word_map["lcel"]:
|
||||
new_tag = word_map["fcel"]
|
||||
|
||||
# End of generation
|
||||
if new_tag == word_map["<end>"]:
|
||||
output_tags.append(new_tag)
|
||||
decoded_tags = torch.cat(
|
||||
[
|
||||
decoded_tags,
|
||||
torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
|
||||
],
|
||||
dim=0,
|
||||
) # current_output_len, 1
|
||||
break
|
||||
output_tags.append(new_tag)
|
||||
|
||||
# BBOX PREDICTION
|
||||
|
||||
# MAKE SURE TO SYNC NUMBER OF CELLS WITH NUMBER OF BBOXes
|
||||
if not skip_next_tag:
|
||||
if new_tag in [
|
||||
word_map["fcel"],
|
||||
word_map["ecel"],
|
||||
word_map["ched"],
|
||||
word_map["rhed"],
|
||||
word_map["srow"],
|
||||
word_map["nl"],
|
||||
word_map["ucel"],
|
||||
]:
|
||||
# GENERATE BBOX HERE TOO (All other cases)...
|
||||
tag_H_buf.append(decoded[-1, :, :])
|
||||
if first_lcel is not True:
|
||||
# Mark end index for horizontal cell bbox merge
|
||||
bboxes_to_merge[cur_bbox_ind] = bbox_ind
|
||||
bbox_ind += 1
|
||||
|
||||
# Treat horisontal span bboxes...
|
||||
if new_tag != word_map["lcel"]:
|
||||
first_lcel = True
|
||||
else:
|
||||
if first_lcel:
|
||||
# GENERATE BBOX HERE (Beginning of horisontal span)...
|
||||
tag_H_buf.append(decoded[-1, :, :])
|
||||
first_lcel = False
|
||||
# Mark start index for cell bbox merge
|
||||
cur_bbox_ind = bbox_ind
|
||||
bboxes_to_merge[cur_bbox_ind] = -1
|
||||
bbox_ind += 1
|
||||
|
||||
if new_tag in [word_map["nl"], word_map["ucel"], word_map["xcel"]]:
|
||||
skip_next_tag = True
|
||||
else:
|
||||
skip_next_tag = False
|
||||
|
||||
# Register ucel in sequence...
|
||||
if new_tag == word_map["ucel"]:
|
||||
prev_tag_ucel = True
|
||||
else:
|
||||
prev_tag_ucel = False
|
||||
|
||||
decoded_tags = torch.cat(
|
||||
[
|
||||
decoded_tags,
|
||||
torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
|
||||
],
|
||||
dim=0,
|
||||
) # current_output_len, 1
|
||||
seq = decoded_tags.squeeze().tolist()
|
||||
|
||||
if self._bbox:
|
||||
AggProfiler().begin("model_bbox_decoder", self._prof)
|
||||
outputs_class, outputs_coord = self._bbox_decoder.inference(
|
||||
enc_out, tag_H_buf
|
||||
)
|
||||
AggProfiler().end("model_bbox_decoder", self._prof)
|
||||
else:
|
||||
outputs_class, outputs_coord = None, None
|
||||
|
||||
outputs_class.to(self._device)
|
||||
outputs_coord.to(self._device)
|
||||
|
||||
########################################################################################
|
||||
# Merge First and Last predicted BBOX for each span, according to bboxes_to_merge
|
||||
########################################################################################
|
||||
|
||||
outputs_class1 = []
|
||||
outputs_coord1 = []
|
||||
boxes_to_skip = []
|
||||
|
||||
for box_ind in range(len(outputs_coord)):
|
||||
box1 = outputs_coord[box_ind].to(self._device)
|
||||
cls1 = outputs_class[box_ind].to(self._device)
|
||||
if box_ind in bboxes_to_merge:
|
||||
box2 = outputs_coord[bboxes_to_merge[box_ind]].to(self._device)
|
||||
boxes_to_skip.append(bboxes_to_merge[box_ind])
|
||||
boxm = self.mergebboxes(box1, box2).to(self._device)
|
||||
outputs_coord1.append(boxm)
|
||||
outputs_class1.append(cls1)
|
||||
else:
|
||||
if box_ind not in boxes_to_skip:
|
||||
outputs_coord1.append(box1)
|
||||
outputs_class1.append(cls1)
|
||||
|
||||
if len(outputs_coord1) > 0:
|
||||
outputs_coord1 = torch.stack(outputs_coord1)
|
||||
if len(outputs_class1) > 0:
|
||||
outputs_class1 = torch.stack(outputs_class1)
|
||||
|
||||
outputs_class = outputs_class1
|
||||
outputs_coord = outputs_coord1
|
||||
|
||||
# Do the rest of the steps...
|
||||
AggProfiler().end("predict_total", self._prof)
|
||||
num_tab_cells = seq.count(4) + seq.count(5)
|
||||
num_rows = seq.count(9)
|
||||
self._log().info(
|
||||
"OTSL predicted table cells#: {}; rows#: {}".format(num_tab_cells, num_rows)
|
||||
)
|
||||
return seq, outputs_class, outputs_coord
|
||||
@@ -0,0 +1,203 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
import docling_ibm_models.tableformer.utils.utils as u
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, dropout=0.1, max_len=1024):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class TMTransformerDecoder(nn.TransformerDecoder):
|
||||
def forward(
|
||||
self,
|
||||
tgt: Tensor,
|
||||
memory: Optional[Tensor] = None,
|
||||
cache: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
tgt (Tensor): encoded tags. (tags_len,bsz,hidden_dim)
|
||||
memory (Tensor): encoded image (enc_image_size,bsz,hidden_dim)
|
||||
cache (Optional[Tensor]): None during training, only used during inference.
|
||||
Returns:
|
||||
output (Tensor): (tags_len,bsz,hidden_dim)
|
||||
"""
|
||||
|
||||
output = tgt
|
||||
|
||||
# cache
|
||||
tag_cache = []
|
||||
for i, mod in enumerate(self.layers):
|
||||
output = mod(output, memory)
|
||||
tag_cache.append(output)
|
||||
if cache is not None:
|
||||
output = torch.cat([cache[i], output], dim=0)
|
||||
|
||||
if cache is not None:
|
||||
out_cache = torch.cat([cache, torch.stack(tag_cache, dim=0)], dim=1)
|
||||
else:
|
||||
out_cache = torch.stack(tag_cache, dim=0)
|
||||
|
||||
return output, out_cache
|
||||
|
||||
|
||||
class TMTransformerDecoderLayer(nn.TransformerDecoderLayer):
|
||||
def forward(
|
||||
self,
|
||||
tgt: Tensor,
|
||||
memory: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
same as TMTransformerDecoder
|
||||
Returns:
|
||||
Tensor:
|
||||
During training (seq_len,bsz,hidden_dim)
|
||||
If eval mode: embedding of last tag: (1,bsz,hidden_dim)
|
||||
"""
|
||||
|
||||
# From PyTorch but modified to only use the last tag
|
||||
tgt_last_tok = tgt[-1:, :, :]
|
||||
|
||||
tmp_tgt = self.self_attn(
|
||||
tgt_last_tok,
|
||||
tgt,
|
||||
tgt,
|
||||
attn_mask=None, # None, because we only care about the last tag
|
||||
key_padding_mask=tgt_key_padding_mask,
|
||||
)[0]
|
||||
tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
|
||||
tgt_last_tok = self.norm1(tgt_last_tok)
|
||||
|
||||
if memory is not None:
|
||||
tmp_tgt = self.multihead_attn(
|
||||
tgt_last_tok,
|
||||
memory,
|
||||
memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
|
||||
tgt_last_tok = self.norm2(tgt_last_tok)
|
||||
|
||||
tmp_tgt = self.linear2(
|
||||
self.dropout(self.activation(self.linear1(tgt_last_tok)))
|
||||
)
|
||||
tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
|
||||
tgt_last_tok = self.norm3(tgt_last_tok)
|
||||
return tgt_last_tok
|
||||
|
||||
|
||||
class Tag_Transformer(nn.Module):
|
||||
"""
|
||||
"Attention Is All You Need" - https://arxiv.org/abs/1706.03762
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
vocab_size,
|
||||
td_encode,
|
||||
embed_dim,
|
||||
encoder_layers,
|
||||
decoder_layers,
|
||||
enc_image_size,
|
||||
dropout=0.1,
|
||||
n_heads=4,
|
||||
dim_ff=1024,
|
||||
):
|
||||
|
||||
super(Tag_Transformer, self).__init__()
|
||||
|
||||
self._device = device
|
||||
self._n_heads = n_heads
|
||||
self._embedding = nn.Embedding(vocab_size, embed_dim)
|
||||
self._positional_encoding = PositionalEncoding(embed_dim)
|
||||
self._td_encode = td_encode
|
||||
|
||||
self._encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff
|
||||
),
|
||||
num_layers=encoder_layers,
|
||||
)
|
||||
|
||||
self._decoder = TMTransformerDecoder(
|
||||
TMTransformerDecoderLayer(
|
||||
d_model=embed_dim,
|
||||
nhead=n_heads,
|
||||
dim_feedforward=dim_ff,
|
||||
),
|
||||
num_layers=decoder_layers,
|
||||
)
|
||||
|
||||
self._decoder_dim = embed_dim
|
||||
self._enc_image_size = enc_image_size
|
||||
self._input_filter = u.resnet_block(stride=1)
|
||||
self._fc = nn.Linear(embed_dim, vocab_size)
|
||||
|
||||
def inference(self, enc_inputs, tags, tag_lens, num_cells):
|
||||
# CNN backbone image encoding
|
||||
enc_inputs = self._input_filter(enc_inputs.permute(0, 3, 1, 2)).permute(
|
||||
0, 2, 3, 1
|
||||
)
|
||||
|
||||
batch_size = enc_inputs.size(0)
|
||||
encoder_dim = enc_inputs.size(-1)
|
||||
|
||||
enc_inputs = enc_inputs.view(batch_size, -1, encoder_dim).to(self._device)
|
||||
|
||||
enc_inputs = enc_inputs.permute(1, 0, 2)
|
||||
positions = enc_inputs.shape[0]
|
||||
# Transformer Encoder Encoded Image mask need to check if its useful
|
||||
encoder_mask = torch.zeros(
|
||||
(batch_size * self._n_heads, positions, positions), device=self._device
|
||||
) == torch.ones(
|
||||
(batch_size * self._n_heads, positions, positions), device=self._device
|
||||
)
|
||||
|
||||
# Transformer Encoder
|
||||
encoder_out = self._encoder(enc_inputs, mask=encoder_mask)
|
||||
|
||||
decode_lengths = (tag_lens - 1).tolist()
|
||||
|
||||
tgt = self._positional_encoding(self._embedding(tags).permute(1, 0, 2))
|
||||
|
||||
decoded = self._decoder(tgt, memory=encoder_out)
|
||||
decoded = decoded.permute(1, 0, 2)
|
||||
predictions = self._fc(decoded)
|
||||
return predictions, decode_lengths
|
||||
@@ -0,0 +1,541 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
from itertools import groupby
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
|
||||
png_files = {} # Evaluation files
|
||||
total_pics = 0
|
||||
|
||||
|
||||
class bcolors:
|
||||
HEADER = "\033[95m"
|
||||
OKBLUE = "\033[94m"
|
||||
OKCYAN = "\033[96m"
|
||||
OKGREEN = "\033[92m"
|
||||
WARNING = "\033[93m"
|
||||
FAIL = "\033[91m"
|
||||
ENDC = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
UNDERLINE = "\033[4m"
|
||||
|
||||
|
||||
def otsl_clean(rs_list):
|
||||
new_rs_list = []
|
||||
stop_list = ["<pad>", "<unk>", "<start>", "<end>"]
|
||||
for tag in rs_list:
|
||||
if tag not in stop_list:
|
||||
new_rs_list.append(tag)
|
||||
return new_rs_list
|
||||
|
||||
|
||||
def otsl_sqr_chk(rs_list, name, logdebug):
|
||||
rs_list_split = [
|
||||
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
|
||||
]
|
||||
isSquare = True
|
||||
if len(rs_list_split) > 0:
|
||||
init_tag_len = len(rs_list_split[0]) + 1
|
||||
for ind, ln in enumerate(rs_list_split):
|
||||
ln.append("nl")
|
||||
if len(ln) != init_tag_len:
|
||||
isSquare = False
|
||||
if isSquare:
|
||||
if logdebug:
|
||||
print(
|
||||
"{}*OK* Table is square! *OK*{}".format(
|
||||
bcolors.OKGREEN, bcolors.ENDC
|
||||
)
|
||||
)
|
||||
else:
|
||||
err_name = "{}*ERR* " + name + " *ERR*{}"
|
||||
print(err_name.format(bcolors.FAIL, bcolors.ENDC))
|
||||
print(
|
||||
"{}*ERR* Table is not square! *ERR*{}".format(
|
||||
bcolors.FAIL, bcolors.ENDC
|
||||
)
|
||||
)
|
||||
return isSquare
|
||||
|
||||
|
||||
def otsl_pad_to_sqr(rs_list, pad_tag):
|
||||
new_list = []
|
||||
rs_list_split = [
|
||||
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
|
||||
]
|
||||
max_row_len = 0
|
||||
for ind, ln in enumerate(rs_list_split):
|
||||
if len(ln) > max_row_len:
|
||||
max_row_len = len(ln)
|
||||
for ind, ln in enumerate(rs_list_split):
|
||||
ln += [pad_tag] * (max_row_len - len(ln))
|
||||
ln.append("nl")
|
||||
new_list.extend(ln)
|
||||
return new_list
|
||||
|
||||
|
||||
def otsl_tags_cells_sync_chk(rs_list, cells, name, logdebug):
|
||||
countCellTags = 0
|
||||
isGood = True
|
||||
for rsTag in rs_list:
|
||||
if rsTag in ["fcel", "ched", "rhed", "srow", "ecel"]:
|
||||
countCellTags += 1
|
||||
if countCellTags != len(cells):
|
||||
err_name = "{}*!ERR* " + name + " *ERR!*{}"
|
||||
print(err_name.format(bcolors.FAIL, bcolors.ENDC))
|
||||
err_msg = "{}*!ERR* Tags are not in sync with cells! *ERR!*{}"
|
||||
print(err_msg.format(bcolors.FAIL, bcolors.ENDC))
|
||||
isGood = False
|
||||
return isGood
|
||||
|
||||
|
||||
def otsl_check_down(rs_split, x, y):
|
||||
distance = 1
|
||||
elem = "ucel"
|
||||
goodlist = ["fcel", "ched", "rhed", "srow", "ecel", "lcel", "nl"]
|
||||
while elem not in goodlist and y < len(rs_split) - 1:
|
||||
y += 1
|
||||
distance += 1
|
||||
elem = rs_split[y][x]
|
||||
if elem in goodlist:
|
||||
distance -= 1
|
||||
return distance
|
||||
|
||||
|
||||
def otsl_check_right(rs_split, x, y):
|
||||
distance = 1
|
||||
elem = "lcel"
|
||||
goodlist = ["fcel", "ched", "rhed", "srow", "ecel", "ucel", "nl"]
|
||||
while elem not in goodlist and x < (len(rs_split[y]) - 1):
|
||||
x += 1
|
||||
distance += 1
|
||||
elem = rs_split[y][x]
|
||||
if elem in goodlist:
|
||||
distance -= 1
|
||||
return distance
|
||||
|
||||
|
||||
def otsl_to_html(rs_list, logdebug):
|
||||
if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]:
|
||||
# Most likely already HTML...
|
||||
return rs_list
|
||||
html_table = []
|
||||
if logdebug:
|
||||
print("{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC))
|
||||
|
||||
if not otsl_sqr_chk(rs_list, "---", logdebug):
|
||||
# PAD TABLE TO SQUARE
|
||||
print("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
|
||||
rs_list = otsl_pad_to_sqr(rs_list, "lcel")
|
||||
|
||||
# 2D structure, line by line:
|
||||
rs_list_split = [
|
||||
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
|
||||
]
|
||||
|
||||
if logdebug:
|
||||
print("")
|
||||
|
||||
# Sequentially store indexes of 2D spans that were registered to avoid re-registering them
|
||||
registry_2d_span = []
|
||||
|
||||
# Iterate all elements in the rs line, and look right / down to detect spans
|
||||
# If span detected - run function to find size of the span
|
||||
# repeat with all cells
|
||||
thead_present = False
|
||||
|
||||
for rs_row_ind, rs_row in enumerate(rs_list_split):
|
||||
html_list = []
|
||||
|
||||
if not thead_present:
|
||||
if "ched" in rs_list_split[rs_row_ind]:
|
||||
html_list.append("<thead>")
|
||||
thead_present = True
|
||||
|
||||
if thead_present:
|
||||
if "ched" not in rs_list_split[rs_row_ind]:
|
||||
html_list.append("</thead>")
|
||||
thead_present = False
|
||||
|
||||
html_list.append("<tr>")
|
||||
for rs_cell_ind, rs_cell in enumerate(rs_list_split[rs_row_ind]):
|
||||
if rs_cell in ["fcel", "ched", "rhed", "srow", "ecel"]:
|
||||
rdist = 0
|
||||
ddist = 0
|
||||
xrdist = 0
|
||||
xddist = 0
|
||||
span = False
|
||||
# Check if it has horizontal span:
|
||||
if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
|
||||
if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "lcel":
|
||||
rdist = otsl_check_right(rs_list_split, rs_cell_ind, rs_row_ind)
|
||||
span = True
|
||||
# Check if it has vertical span:
|
||||
if rs_row_ind + 1 < len(rs_list_split):
|
||||
# print(">>>")
|
||||
# print(rs_list_split[rs_row_ind + 1])
|
||||
# print(">>> rs_cell_ind = {}".format(rs_cell_ind))
|
||||
if rs_list_split[rs_row_ind + 1][rs_cell_ind] == "ucel":
|
||||
ddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
|
||||
span = True
|
||||
# Check if it has 2D span:
|
||||
if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
|
||||
if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "xcel":
|
||||
xrdist = otsl_check_right(
|
||||
rs_list_split, rs_cell_ind, rs_row_ind
|
||||
)
|
||||
xddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
|
||||
span = True
|
||||
# Check if this 2D span was already registered,
|
||||
# If not - register, if yes - cancel span
|
||||
# print("rs_cell_ind: {}, xrdist:{}".format(rs_cell_ind, xrdist))
|
||||
# print("rs_row_ind: {}, xddist:{}".format(rs_cell_ind, xrdist))
|
||||
for x in range(rs_cell_ind, xrdist + rs_cell_ind):
|
||||
for y in range(rs_row_ind, xddist + rs_row_ind):
|
||||
reg2dind = str(x) + "_" + str(y)
|
||||
# print(reg2dind)
|
||||
if reg2dind in registry_2d_span:
|
||||
# Cell of the span is already in, cancel current span
|
||||
span = False
|
||||
if span:
|
||||
# None of the span cells were previously registered
|
||||
# Register an entire span
|
||||
for x in range(rs_cell_ind, xrdist + rs_cell_ind):
|
||||
for y in range(rs_row_ind, xddist + rs_row_ind):
|
||||
reg2dind = str(x) + "_" + str(y)
|
||||
registry_2d_span.append(reg2dind)
|
||||
if span:
|
||||
html_list.append("<td")
|
||||
if rdist > 1:
|
||||
html_list.append(' colspan="' + str(rdist) + '"')
|
||||
if ddist > 1:
|
||||
html_list.append(' rowspan="' + str(ddist) + '"')
|
||||
if xrdist > 1:
|
||||
html_list.append(' rowspan="' + str(xddist) + '"')
|
||||
html_list.append(' colspan="' + str(xrdist) + '"')
|
||||
html_list.append(">")
|
||||
html_list.append("</td>")
|
||||
else:
|
||||
html_list.append("<td>")
|
||||
html_list.append("</td>")
|
||||
html_list.append("</tr>")
|
||||
html_table.extend(html_list)
|
||||
|
||||
if logdebug:
|
||||
print("*********************** registry_2d_span ***************************")
|
||||
print(registry_2d_span)
|
||||
print("********************************************************************")
|
||||
|
||||
return html_table
|
||||
|
||||
|
||||
def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer):
|
||||
r"""
|
||||
Converts table structure from HTML to RS
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table : json
|
||||
line from jsonl
|
||||
writer : writer
|
||||
Writes lines into output jsonl
|
||||
"""
|
||||
|
||||
table_html_structure = copy.deepcopy(table["html"]["structure"])
|
||||
out_line = table
|
||||
if include_html:
|
||||
out_line["html"]["html_structure"] = table_html_structure
|
||||
out_line["html"]["html_restored_structure"] = {"tokens": []}
|
||||
|
||||
out_line["html"]["structure"] = {"tokens": []}
|
||||
# possible colspans
|
||||
pos_colspans = {
|
||||
' colspan="20"': 20,
|
||||
' colspan="19"': 19,
|
||||
' colspan="18"': 18,
|
||||
' colspan="17"': 17,
|
||||
' colspan="16"': 16,
|
||||
' colspan="15"': 15,
|
||||
' colspan="14"': 14,
|
||||
' colspan="13"': 13,
|
||||
' colspan="12"': 12,
|
||||
' colspan="11"': 11,
|
||||
' colspan="10"': 10,
|
||||
' colspan="2"': 2,
|
||||
' colspan="3"': 3,
|
||||
' colspan="4"': 4,
|
||||
' colspan="5"': 5,
|
||||
' colspan="6"': 6,
|
||||
' colspan="7"': 7,
|
||||
' colspan="8"': 8,
|
||||
' colspan="9"': 9,
|
||||
}
|
||||
# possible rowspans
|
||||
pos_rowspans = {
|
||||
' rowspan="20"': 20,
|
||||
' rowspan="19"': 19,
|
||||
' rowspan="18"': 18,
|
||||
' rowspan="17"': 17,
|
||||
' rowspan="16"': 16,
|
||||
' rowspan="15"': 15,
|
||||
' rowspan="14"': 14,
|
||||
' rowspan="13"': 13,
|
||||
' rowspan="12"': 12,
|
||||
' rowspan="11"': 11,
|
||||
' rowspan="10"': 10,
|
||||
' rowspan="2"': 2,
|
||||
' rowspan="3"': 3,
|
||||
' rowspan="4"': 4,
|
||||
' rowspan="5"': 5,
|
||||
' rowspan="6"': 6,
|
||||
' rowspan="7"': 7,
|
||||
' rowspan="8"': 8,
|
||||
' rowspan="9"': 9,
|
||||
}
|
||||
|
||||
t_cells = [] # 2D structure
|
||||
tl_cells = [] # 1D structure
|
||||
t_expands = [] # 2D structure
|
||||
tl_spans = {} # MAP, POPULATE WITH ACTUAL SPANS VALUES, IN SYNC WITH tl_cells
|
||||
|
||||
current_line = 0
|
||||
current_column = 0
|
||||
current_html_cell_ind = 0
|
||||
|
||||
current_line_tags = []
|
||||
current_line_expands = []
|
||||
|
||||
if logdebug:
|
||||
print("")
|
||||
print("*** {}: {} ***".format(table["split"], table["filename"]))
|
||||
|
||||
colnum = 0
|
||||
|
||||
if extra_debug:
|
||||
print("========================== Input HTML ============================")
|
||||
print(table_html_structure["tokens"])
|
||||
print("==================================================================")
|
||||
|
||||
if logdebug:
|
||||
print("********")
|
||||
print("* OTSL *")
|
||||
print("********")
|
||||
|
||||
for i in range(len(table_html_structure["tokens"])):
|
||||
html_tag = table_html_structure["tokens"][i]
|
||||
prev_html_tag = ""
|
||||
next_html_tag = ""
|
||||
if i > 0:
|
||||
prev_html_tag = table_html_structure["tokens"][i - 1]
|
||||
if i < len(table_html_structure["tokens"]) - 1:
|
||||
next_html_tag = table_html_structure["tokens"][i + 1]
|
||||
|
||||
if html_tag not in ["<thead>", "<tbody>"]:
|
||||
# Then check the next tag...
|
||||
# rules of conversion
|
||||
# Check up-cell in t_expands, in case row-spans have to be inserted
|
||||
if html_tag in ["<td>", "<td", "</tr>"]:
|
||||
if current_line > 0:
|
||||
if current_column >= len(t_expands[current_line - 1]):
|
||||
# !!!
|
||||
return False, {}
|
||||
up_expand = t_expands[current_line - 1][current_column]
|
||||
|
||||
while up_expand[1] > 0:
|
||||
if up_expand[0] == 0:
|
||||
# ucel
|
||||
current_line_tags.append("ucel")
|
||||
current_line_expands.append([0, up_expand[1] - 1])
|
||||
current_column += 1
|
||||
else:
|
||||
# xcel
|
||||
for ci in range(up_expand[0]):
|
||||
current_line_tags.append("xcel")
|
||||
current_line_expands.append(
|
||||
[up_expand[0] - ci, up_expand[1] - 1]
|
||||
)
|
||||
current_column += 1
|
||||
up_expand = t_expands[current_line - 1][current_column]
|
||||
# ======================================================================================
|
||||
# Fix for trailing "ucel" in a row
|
||||
if html_tag in ["</tr>"]:
|
||||
if current_line > 0:
|
||||
cur_line_len = len(current_line_expands)
|
||||
pre_line_len = len(t_expands[current_line - 1])
|
||||
|
||||
if cur_line_len < pre_line_len:
|
||||
extra_columns = pre_line_len - cur_line_len - 1
|
||||
if extra_columns > 0:
|
||||
if extra_debug:
|
||||
print(
|
||||
"Extra columns needed in row: {}".format(
|
||||
extra_columns
|
||||
)
|
||||
)
|
||||
|
||||
for clm in range(extra_columns):
|
||||
up_expand = t_expands[current_line - 1][
|
||||
cur_line_len + clm
|
||||
]
|
||||
if up_expand[0] == 0:
|
||||
# ucel
|
||||
current_line_tags.append("ucel")
|
||||
current_line_expands.append([0, up_expand[1] - 1])
|
||||
else:
|
||||
# xcel
|
||||
current_line_tags.append("xcel")
|
||||
current_line_expands.append(
|
||||
[up_expand[0], up_expand[1] - 1]
|
||||
)
|
||||
# ======================================================================================
|
||||
|
||||
# 1. Opening cell tags
|
||||
if html_tag in ["<td>", "<td"]:
|
||||
# check if cell is empty...
|
||||
cell_is_empty = True
|
||||
if "cells" in table["html"]:
|
||||
cell_tokens = table["html"]["cells"][current_html_cell_ind][
|
||||
"tokens"
|
||||
]
|
||||
else:
|
||||
cell_tokens = "f"
|
||||
|
||||
# Clean cell_tokens from trash:
|
||||
cell_tokens = list(filter(lambda a: a != "<i>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "<I>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "<b>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "<B>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != " ", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "</b>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "</B>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "</i>", cell_tokens))
|
||||
cell_tokens = list(filter(lambda a: a != "</I>", cell_tokens))
|
||||
|
||||
# Check if empty
|
||||
if len(cell_tokens) > 0:
|
||||
cell_is_empty = False
|
||||
if cell_is_empty:
|
||||
out_line["html"]["cells"][current_html_cell_ind]["tokens"] = []
|
||||
current_line_tags.append("ecel")
|
||||
current_line_expands.append([0, 0])
|
||||
else:
|
||||
current_line_tags.append("fcel")
|
||||
current_line_expands.append([0, 0])
|
||||
current_html_cell_ind += 1
|
||||
current_column += 1
|
||||
|
||||
# 2. Closing row tags
|
||||
if html_tag == "</tr>":
|
||||
if len(current_line_tags) > colnum:
|
||||
colnum = len(current_line_tags)
|
||||
# Save everything we read about the line to t_cells
|
||||
current_line_tags.append("nl")
|
||||
t_cells.append(copy.deepcopy(current_line_tags))
|
||||
tl_cells.extend(copy.deepcopy(current_line_tags))
|
||||
if logdebug:
|
||||
print(current_line_tags)
|
||||
current_line_tags = []
|
||||
|
||||
# Deal with expands
|
||||
current_line_expands.append([-1, -1])
|
||||
# Output spans metadata
|
||||
t_expands.append(copy.deepcopy(current_line_expands))
|
||||
current_line_expands = []
|
||||
|
||||
current_column = 0
|
||||
current_line += 1
|
||||
# 3. Colspans only
|
||||
if html_tag in pos_colspans:
|
||||
if prev_html_tag not in pos_rowspans:
|
||||
if next_html_tag not in pos_rowspans:
|
||||
colspan_len = pos_colspans[html_tag]
|
||||
tl_spans[current_html_cell_ind - 1] = [colspan_len, 1]
|
||||
current_line_expands[len(current_line_expands) - 1] = [
|
||||
colspan_len,
|
||||
0,
|
||||
]
|
||||
for ci in range(colspan_len - 1):
|
||||
current_line_tags.append("lcel")
|
||||
current_line_expands.append([colspan_len - ci - 1, 0])
|
||||
current_column += 1
|
||||
|
||||
# 4. Rowspans only
|
||||
if html_tag in pos_rowspans:
|
||||
if prev_html_tag not in pos_colspans:
|
||||
if next_html_tag not in pos_colspans:
|
||||
rowspan_len = pos_rowspans[html_tag]
|
||||
tl_spans[current_html_cell_ind - 1] = [1, rowspan_len]
|
||||
current_line_expands[len(current_line_expands) - 1] = [
|
||||
0,
|
||||
rowspan_len - 1,
|
||||
]
|
||||
|
||||
# 5. 2D spans
|
||||
if html_tag in pos_rowspans:
|
||||
rowspan_len = pos_rowspans[html_tag]
|
||||
if prev_html_tag in pos_colspans:
|
||||
colspan_len = pos_colspans[prev_html_tag]
|
||||
tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
|
||||
newexp = [colspan_len, rowspan_len - 1]
|
||||
current_line_expands[len(current_line_expands) - 1] = newexp
|
||||
for ci in range(colspan_len - 1):
|
||||
current_line_tags.append("xcel")
|
||||
current_line_expands.append(
|
||||
[colspan_len - ci - 1, rowspan_len - 1]
|
||||
)
|
||||
if next_html_tag in pos_colspans:
|
||||
colspan_len = pos_colspans[next_html_tag]
|
||||
tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
|
||||
newexp = [colspan_len, rowspan_len - 1]
|
||||
current_line_expands[len(current_line_expands) - 1] = newexp
|
||||
for ci in range(colspan_len - 1):
|
||||
current_line_tags.append("xcel")
|
||||
current_line_expands.append(
|
||||
[colspan_len - ci - 1, rowspan_len - 1]
|
||||
)
|
||||
|
||||
t_name = "*** {}: {} ***".format(table["split"], table["filename"])
|
||||
# check if square
|
||||
isSquare = otsl_sqr_chk(tl_cells, t_name, logdebug)
|
||||
# TODO: pad if not square?
|
||||
if not isSquare:
|
||||
tl_cells = otsl_pad_to_sqr(tl_cells, "fcel")
|
||||
# check if cells (bboxes) in sync:
|
||||
if "cells" in out_line["html"]:
|
||||
isGood = otsl_tags_cells_sync_chk(
|
||||
tl_cells, out_line["html"]["cells"], t_name, logdebug
|
||||
)
|
||||
# convert back to HTML
|
||||
rHTML = []
|
||||
if isSquare:
|
||||
rHTML = otsl_to_html(tl_cells, logdebug)
|
||||
out_line["html"]["html_restored_structure"]["tokens"] = rHTML
|
||||
|
||||
out_line["html"]["structure"]["tokens"] = tl_cells
|
||||
out_line["otsl_spans"] = tl_spans
|
||||
out_line["cols"] = colnum
|
||||
out_line["rows"] = len(t_cells)
|
||||
out_line["html_len"] = len(table_html_structure["tokens"])
|
||||
out_line["rs_len"] = len(tl_cells)
|
||||
# save converted line
|
||||
if use_writer:
|
||||
if isSquare:
|
||||
if isGood:
|
||||
writer.write(out_line)
|
||||
|
||||
if logdebug:
|
||||
print("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
|
||||
print(rHTML)
|
||||
# original HTML
|
||||
oHTML = out_line["html"]["html_structure"]
|
||||
print("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
|
||||
print(oHTML)
|
||||
|
||||
return True, out_line
|
||||
@@ -0,0 +1,90 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def get_custom_logger(logger_name, level, stream=sys.stdout):
|
||||
r"""
|
||||
Create a custom logger with a standard formatting
|
||||
|
||||
Inputs:
|
||||
- logger_name: Name of the logger. You can get the class name as self.__class__.__name__
|
||||
- level: logging level (e.g. logging.INFO, logging.DEBUG, etc.)
|
||||
- stream: One of sys.stdout or sys.stderr
|
||||
|
||||
Outputs:
|
||||
logger
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# Set the handler
|
||||
if not logger.hasHandlers():
|
||||
handler = logging.StreamHandler(stream)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
###################################################################################
|
||||
# System constants
|
||||
#
|
||||
|
||||
r"""
|
||||
This is a "generic" logger available to all scripts.
|
||||
It is encouraged that each class has it's own custom logger with the name of the class.
|
||||
You can use the "get_custom_logger" function to build a custom logger with a standard format.
|
||||
"""
|
||||
LOGGER = get_custom_logger("docling-pm", logging.INFO)
|
||||
|
||||
# Supported dataset types
|
||||
supported_datasets = ["TF_prepared"] # TF prepared dataset
|
||||
|
||||
# Split names
|
||||
TRAIN_SPLIT = "train"
|
||||
VAL_SPLIT = "val"
|
||||
TEST_SPLIT = "test"
|
||||
|
||||
# Prepared data parts and filename templates
|
||||
PREPARED_DATA_PARTS = {
|
||||
# Array with the bboxes (x1y1x2y2) for all cells of the images across all splits.
|
||||
# The bboxes are indexed with the filename.
|
||||
# Notices:
|
||||
# - The bboxes are NOT transformed.
|
||||
# - If the image filenames are the same across splits, there will be one one entry in the file
|
||||
"BBOXES": "BBOXES.json",
|
||||
# Image filenames used for train and val
|
||||
"IMAGES": "IMAGES.json",
|
||||
# Mean, std, variance as arrays of 3 (for each color)
|
||||
"STATISTICS": "STATISTICS_<POSTFIX>.json", # PRECOMPUTED
|
||||
# Bboxes of the cells in the form [1, x1, x2, y1, y2] or [0, 0, 0, 0, 0] in case of no box.
|
||||
"TRAIN_CELLBBOXES": "TRAIN_CELLBBOXES_<POSTFIX>.json", # NOT USED.
|
||||
# Array with arrays of the length + 2 of the original cells per image.
|
||||
"TRAIN_CELLLENS": "TRAIN_CELLLENS_<POSTFIX>.json",
|
||||
# Indices of the cells between <start> <end> and <pad> at the end.
|
||||
"TRAIN_CELLS": "TRAIN_CELLS_<POSTFIX>.json",
|
||||
# Array with the length + 2 of the original tags per image.
|
||||
"TRAIN_TAGLENS": "TRAIN_TAGLENS_<POSTFIX>.json",
|
||||
# Indices of the tags between <start> <end> and <pad> at the end.
|
||||
"TRAIN_TAGS": "TRAIN_TAGS_<POSTFIX>.json",
|
||||
# Ground truth for the evaluation dataset per eval image.
|
||||
"VAL": "VAL.json",
|
||||
# Vocabulary: Indices of the word_map_cells and word_map_tags
|
||||
"WORDMAP": "WORDMAP_<POSTFIX>.json", # PRECOMPUTED
|
||||
}
|
||||
|
||||
# Purposes
|
||||
TRAIN_PURPOSE = "train"
|
||||
VAL_PURPOSE = "val"
|
||||
TEST_PURPOSE = "test"
|
||||
PREDICT_PURPOSE = "predict"
|
||||
|
||||
# The DDP world size when we train in CPU with DDP enabled
|
||||
DDP_CPU_WORLD_SIZE = 2
|
||||
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
# LOG_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
def dataset_test(config):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
config : dictionary
|
||||
The configuration settings
|
||||
"""
|
||||
|
||||
# model_type = config["model"]["type"]
|
||||
# Create the device and the Dataset
|
||||
device = "cpu"
|
||||
dataset = TFDataset(config, "train")
|
||||
dataset.set_device(device)
|
||||
|
||||
# Loop over the data
|
||||
dataset.reset()
|
||||
dataset.shuffle()
|
||||
for i, batch in enumerate(dataset):
|
||||
print("Loading batch: {}".format(i))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = c.parse_arguments()
|
||||
dataset_test(config)
|
||||
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
from docling_ibm_models.tableformer.data_management.data_transformer import (
|
||||
DataTransformer,
|
||||
)
|
||||
|
||||
|
||||
def dump_np(img_np: np.array, fn, n=6):
|
||||
# Expect to receive a numpy array for an image with the shape [channels, rows, columns]
|
||||
s = img_np.shape
|
||||
if s[0] not in [1, 2, 3, 4] or len(s) != 3:
|
||||
print("Image of invalid shape: {}".format(s))
|
||||
return
|
||||
|
||||
channels = s[0]
|
||||
rows = s[1]
|
||||
cols = s[2]
|
||||
w = n + 6
|
||||
with open(fn, "w") as fd:
|
||||
for r in range(rows):
|
||||
for col in range(cols):
|
||||
for ch in range(channels):
|
||||
x = img_np[ch][r][col]
|
||||
if isinstance(x, np.float32):
|
||||
f_str = "0:>{}.{}f".format(w, n)
|
||||
elif isinstance(x, np.uint8):
|
||||
f_str = "0:>{}".format(w)
|
||||
else:
|
||||
return False
|
||||
|
||||
x_str = ("{" + f_str + "}").format(x)
|
||||
fd.write(x_str)
|
||||
if ch < channels - 1:
|
||||
fd.write(" ")
|
||||
fd.write("\n")
|
||||
return True
|
||||
|
||||
|
||||
def dump_channels(save_dir, fn_prefix, img_np: np.array):
|
||||
# Dump the np array into 3 files per channel
|
||||
img_np_ch0 = img_np[0, :, :]
|
||||
img_np_ch1 = img_np[1, :, :]
|
||||
img_np_ch2 = img_np[2, :, :]
|
||||
txt_ch0_fn = os.path.join(save_dir, fn_prefix + "_ch0.txt")
|
||||
txt_ch1_fn = os.path.join(save_dir, fn_prefix + "_ch1.txt")
|
||||
txt_ch2_fn = os.path.join(save_dir, fn_prefix + "_ch2.txt")
|
||||
np.savetxt(txt_ch0_fn, img_np_ch0)
|
||||
np.savetxt(txt_ch1_fn, img_np_ch1)
|
||||
np.savetxt(txt_ch2_fn, img_np_ch2)
|
||||
print(f"{txt_ch0_fn}")
|
||||
print(f"{txt_ch1_fn}")
|
||||
print(f"{txt_ch2_fn}")
|
||||
|
||||
|
||||
def prepare_image(config):
|
||||
transformer = DataTransformer(config)
|
||||
predict_dir = config["predict"]["predict_dir"]
|
||||
use_normalization = config["dataset"]["image_normalization"]["state"]
|
||||
|
||||
pattern = os.path.join(predict_dir, "*.png")
|
||||
for img_fn in glob.glob(pattern):
|
||||
print(f"img_fn: {img_fn}")
|
||||
|
||||
with Image.open(img_fn) as img:
|
||||
# Dump the initial image in txt files
|
||||
img_np = np.array(img)
|
||||
|
||||
# Reshape the image in order to print it
|
||||
img_np_m = np.moveaxis(img_np, 2, 0)
|
||||
print(
|
||||
"orig. img_np.shape: {}, reshaped image: {}".format(
|
||||
img_np.shape, img_np_m.shape
|
||||
)
|
||||
)
|
||||
original_fn = img_fn + "_python.txt"
|
||||
dump_np(img_np_m, original_fn)
|
||||
|
||||
r_img_ten = transformer.rescale_in_memory(img, use_normalization)
|
||||
print("npimgc: {} - {}".format(r_img_ten.type(), r_img_ten.size()))
|
||||
|
||||
# Dump the processed image tensor in txt files
|
||||
r_img_np = r_img_ten.numpy()
|
||||
|
||||
prepared_fn = img_fn + "_python_prepared.txt"
|
||||
dump_np(r_img_np, prepared_fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = c.parse_arguments()
|
||||
prepare_image(config)
|
||||
@@ -0,0 +1,243 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import time
|
||||
from collections import deque
|
||||
from statistics import mean, median
|
||||
|
||||
|
||||
class SingletonClass(type):
|
||||
r"""
|
||||
Generic singleton metaclass
|
||||
"""
|
||||
|
||||
def __init__(self, name, bases, dic):
|
||||
self._instance = None
|
||||
super().__init__(name, bases, dic)
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
# Create a singleton if needed
|
||||
if cls._instance is None:
|
||||
singleton = cls.__new__(cls)
|
||||
singleton.__init__(*args, **kwargs)
|
||||
cls._instance = singleton
|
||||
return cls._instance
|
||||
|
||||
|
||||
class Profiler:
|
||||
r"""
|
||||
Application specific profiler
|
||||
Decompose the application into "sections". Each section is a label.
|
||||
The total time a section consumes is split into "intervals"
|
||||
Use the `begin`, `end` methods to mark the begining and end of an interval for
|
||||
a certain section
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._section_dts = {} # section name -> sum(section intervals)
|
||||
self._section_calls = {} # section name -> number of invocations
|
||||
self._section_kB = {} # section name -> max kB of used heap
|
||||
|
||||
# section name -> beginning of the last interval
|
||||
self._last_begin = {}
|
||||
|
||||
def begin(self, section_name, enable=True):
|
||||
r"""
|
||||
Mark the beginning of an interval
|
||||
|
||||
Parameters
|
||||
----------
|
||||
section_name : string
|
||||
Name of the section
|
||||
enable : bool
|
||||
The actual interval entry takes place only if enable is true
|
||||
|
||||
Return
|
||||
------
|
||||
True if the interval has actuall begun
|
||||
"""
|
||||
if not enable:
|
||||
return False
|
||||
self._last_begin[section_name] = time.time()
|
||||
return True
|
||||
|
||||
def end(self, section_name, enable=True):
|
||||
r"""
|
||||
Mark the end of an interval for a certain section
|
||||
|
||||
Parameters
|
||||
----------
|
||||
section_name : string
|
||||
Name of the section
|
||||
enable : bool
|
||||
The actual interval entry takes place only if enable is true
|
||||
|
||||
Return
|
||||
------
|
||||
True if the section name is valid and an interval for this section has already begun
|
||||
False otherwise
|
||||
"""
|
||||
if not enable:
|
||||
return False
|
||||
if section_name not in self._last_begin:
|
||||
return False
|
||||
|
||||
dt = time.time() - self._last_begin[section_name]
|
||||
if section_name not in self._section_dts:
|
||||
self._section_dts[section_name] = dt
|
||||
self._section_calls[section_name] = 1
|
||||
else:
|
||||
self._section_dts[section_name] += dt
|
||||
self._section_calls[section_name] += 1
|
||||
|
||||
return True
|
||||
|
||||
def get_data(self, section_names=None):
|
||||
r"""
|
||||
Return a dict with profiling data for the specified sections.
|
||||
|
||||
Parameter
|
||||
---------
|
||||
section_names : list of string
|
||||
List with the section names to get their accumulative dt
|
||||
If it is None, all sections are returned
|
||||
|
||||
Return
|
||||
------
|
||||
dict of dicts
|
||||
Outer key: section name
|
||||
Inner keys: "dt": Accumulative time for that section, "cells": Number of calls
|
||||
"""
|
||||
# Filter the section names to apply
|
||||
filtered_names = list(
|
||||
filter(lambda x: x in section_names, self._section_dts.keys())
|
||||
if section_names is not None
|
||||
else self._section_dts.keys()
|
||||
)
|
||||
data = {}
|
||||
for section_name in filtered_names:
|
||||
data[section_name] = {
|
||||
"dt": self._section_dts[section_name],
|
||||
"calls": self._section_calls[section_name],
|
||||
"kB": self._section_kB[section_name],
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class AppProfiler(Profiler, metaclass=SingletonClass):
|
||||
r"""
|
||||
AppProfiler is a singleton of the Profiler for application wide usage
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(AppProfiler, self).__init__()
|
||||
|
||||
|
||||
class AggProfiler(metaclass=SingletonClass):
|
||||
r"""
|
||||
Generic wrapper of Profiler that enables aggregation of profiling statistics around Cycles
|
||||
|
||||
- When a new cycle begins a new Profiler is created to keep the profiling data per section
|
||||
- Keep the last n cycles in a sliding window manner
|
||||
- At every time we can get profiling data about the last cycle and statistics over the last n
|
||||
cycles
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20):
|
||||
self._window_size = window_size
|
||||
# deque with up to the last "window_size" Profilers. The newest at index 0
|
||||
self._cycles = deque()
|
||||
|
||||
def start_agg(self, enable=True):
|
||||
r"""
|
||||
Returns
|
||||
-------
|
||||
0: not enabled
|
||||
1: a new scope has started
|
||||
"""
|
||||
if not enable:
|
||||
return 0
|
||||
|
||||
# Add a new profiler
|
||||
self._cycles.appendleft(Profiler())
|
||||
# In case the deque has grown too much, remove the oldest Profiler
|
||||
if len(self._cycles) > self._window_size:
|
||||
self._cycles.pop()
|
||||
return 1
|
||||
|
||||
def begin(self, section_name, enable=True):
|
||||
if not enable:
|
||||
return False
|
||||
if len(self._cycles) == 0:
|
||||
print("AggProfiler begin | Start Aggregator not initialized.")
|
||||
return False
|
||||
profiler = self._cycles[0]
|
||||
return profiler.begin(section_name)
|
||||
|
||||
def end(self, section_name, enable=True):
|
||||
if not enable:
|
||||
return False
|
||||
if len(self._cycles) == 0:
|
||||
print("AggProfiler end | Start Aggregator not initialized.")
|
||||
return False
|
||||
profiler = self._cycles[0]
|
||||
return profiler.end(section_name)
|
||||
|
||||
def get_data(self):
|
||||
r"""
|
||||
Get profiling data for:
|
||||
- The last cycle
|
||||
- Aggragated statistics (avg, median) per section and per metric across all cycles
|
||||
- The dt numbers for the mean/median is the average time for each section ACROSS the cycle
|
||||
- There is NO need to compute average by yourself.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict with the structure:
|
||||
- window: int with the size of the time sliding window
|
||||
- last: dict with the metrics for the last cycle (as provided by the Profiler)
|
||||
- mean: dict with the mean metrics per section across the cycle
|
||||
- section_name
|
||||
- metric_name: mean of the metric values
|
||||
- median: dict with the median metrics per section across the cycle
|
||||
- section_name
|
||||
- metric_name: median of the metric values
|
||||
"""
|
||||
last_data = self._cycles[0].get_data()
|
||||
data = {
|
||||
"window": len(self._cycles),
|
||||
"last": last_data,
|
||||
"mean": {},
|
||||
"median": {},
|
||||
}
|
||||
|
||||
# Section -> metric -> [values]
|
||||
section_metric_values = {}
|
||||
|
||||
# Collect the metrics
|
||||
for i, p in enumerate(self._cycles):
|
||||
p_data = p.get_data()
|
||||
for section_name, m_dict in p_data.items():
|
||||
for m_name, m_val in m_dict.items():
|
||||
if section_name not in section_metric_values:
|
||||
section_metric_values[section_name] = {}
|
||||
s_metrics = section_metric_values[section_name]
|
||||
if m_name not in s_metrics:
|
||||
s_metrics[m_name] = []
|
||||
s_metrics[m_name].append(m_val)
|
||||
|
||||
# Aggregate the metrics
|
||||
for section_name, m_dict in section_metric_values.items():
|
||||
for m_name, m_values in m_dict.items():
|
||||
if section_name not in data["mean"]:
|
||||
data["mean"][section_name] = {}
|
||||
if section_name not in data["median"]:
|
||||
data["median"][section_name] = {}
|
||||
|
||||
mean_v = mean(m_values)
|
||||
median_v = median(m_values)
|
||||
data["mean"][section_name][m_name] = mean_v
|
||||
data["median"][section_name][m_name] = median_v
|
||||
|
||||
return data
|
||||
@@ -0,0 +1,216 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import torch
|
||||
|
||||
|
||||
def model_info(model, verbose=False):
|
||||
# Plots a line-by-line description of a PyTorch model
|
||||
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
||||
n_g = sum(
|
||||
x.numel() for x in model.parameters() if x.requires_grad
|
||||
) # number gradients
|
||||
if verbose:
|
||||
print(
|
||||
"%5s %40s %9s %12s %20s %10s %10s"
|
||||
% ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
|
||||
)
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace("module_list.", "")
|
||||
print(
|
||||
"%5g %40s %9s %12g %20s %10.3g %10.3g"
|
||||
% (
|
||||
i,
|
||||
name,
|
||||
p.requires_grad,
|
||||
p.numel(),
|
||||
list(p.shape),
|
||||
p.mean(),
|
||||
p.std(),
|
||||
)
|
||||
)
|
||||
|
||||
try: # FLOPS
|
||||
from thop import profile
|
||||
|
||||
macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
|
||||
fs = ", %.1f GFLOPS" % (macs / 1e9 * 2)
|
||||
except Exception:
|
||||
fs = ""
|
||||
|
||||
print(
|
||||
"Model Summary: %g layers, %g parameters, %g gradients%s"
|
||||
% (len(list(model.parameters())), n_p, n_g, fs)
|
||||
)
|
||||
|
||||
|
||||
# def init_seeds(seed=0):
|
||||
# torch.manual_seed(seed)
|
||||
#
|
||||
# # Reduce randomness (may be slower on Tesla GPUs)
|
||||
# # https://pytorch.org/docs/stable/notes/randomness.html
|
||||
# if seed == 0:
|
||||
# cudnn.deterministic = False
|
||||
# cudnn.benchmark = True
|
||||
#
|
||||
#
|
||||
# def select_device(device='', apex=False, batch_size=None):
|
||||
# # device = 'cpu' or '0' or '0,1,2,3'
|
||||
# cpu_request = device.lower() == 'cpu'
|
||||
# if device and not cpu_request: # if device requested other than 'cpu'
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
||||
# # check availablity
|
||||
# assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
|
||||
#
|
||||
# cuda = False if cpu_request else torch.cuda.is_available()
|
||||
# if cuda:
|
||||
# c = 1024 ** 2 # bytes to MB
|
||||
# ng = torch.cuda.device_count()
|
||||
# if ng > 1 and batch_size: # check that batch_size is compatible with device_count
|
||||
# assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % \
|
||||
# (batch_size, ng)
|
||||
# x = [torch.cuda.get_device_properties(i) for i in range(ng)]
|
||||
# # apex for mixed precision https://github.com/NVIDIA/apex
|
||||
# s = 'Using CUDA ' + ('Apex ' if apex else '')
|
||||
# for i in range(0, ng):
|
||||
# if i == 1:
|
||||
# s = ' ' * len(s)
|
||||
# print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
|
||||
# (s, i, x[i].name, x[i].total_memory / c))
|
||||
# else:
|
||||
# print('Using CPU')
|
||||
#
|
||||
# print('') # skip a line
|
||||
# return torch.device('cuda:0' if cuda else 'cpu')
|
||||
#
|
||||
#
|
||||
# def time_synchronized():
|
||||
# torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
# return time.time()
|
||||
#
|
||||
#
|
||||
# def initialize_weights(model):
|
||||
# for m in model.modules():
|
||||
# t = type(m)
|
||||
# if t is nn.Conv2d:
|
||||
# pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
# elif t is nn.BatchNorm2d:
|
||||
# m.eps = 1e-4
|
||||
# m.momentum = 0.03
|
||||
# elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
||||
# m.inplace = True
|
||||
#
|
||||
#
|
||||
# def find_modules(model, mclass=nn.Conv2d):
|
||||
# # finds layer indices matching module class 'mclass'
|
||||
# return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
||||
#
|
||||
#
|
||||
# def fuse_conv_and_bn(conv, bn):
|
||||
# # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||
# with torch.no_grad():
|
||||
# # init
|
||||
# fusedconv = torch.nn.Conv2d(conv.in_channels,
|
||||
# conv.out_channels,
|
||||
# kernel_size=conv.kernel_size,
|
||||
# stride=conv.stride,
|
||||
# padding=conv.padding,
|
||||
# bias=True)
|
||||
#
|
||||
# # prepare filters
|
||||
# w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
# w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||
# fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
||||
#
|
||||
# # prepare spatial bias
|
||||
# if conv.bias is not None:
|
||||
# b_conv = conv.bias
|
||||
# else:
|
||||
# b_conv = torch.zeros(conv.weight.size(0))
|
||||
# b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
||||
# fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
||||
#
|
||||
# return fusedconv
|
||||
#
|
||||
#
|
||||
# def load_classifier(name='resnet101', n=2):
|
||||
# # Loads a pretrained model reshaped to n-class output
|
||||
# import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
|
||||
# model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
|
||||
#
|
||||
# # Display model properties
|
||||
# for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean',
|
||||
# 'model.std']:
|
||||
# print(x + ' =', eval(x))
|
||||
#
|
||||
# # Reshape output to n classes
|
||||
# filters = model.last_linear.weight.shape[1]
|
||||
# model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
|
||||
# model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
||||
# model.last_linear.out_features = n
|
||||
# return model
|
||||
#
|
||||
#
|
||||
# def scale_img(img, ratio=1.0, same_shape=True): # img(16,3,256,416), r=ratio
|
||||
# # scales img(bs,3,y,x) by ratio
|
||||
# h, w = img.shape[2:]
|
||||
# s = (int(h * ratio), int(w * ratio)) # new size
|
||||
# img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
||||
# if not same_shape: # pad/crop img
|
||||
# gs = 64 # (pixels) grid size
|
||||
# h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
|
||||
# return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
||||
#
|
||||
#
|
||||
# class ModelEMA:
|
||||
# """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
||||
# Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
# This is intended to allow functionality like
|
||||
# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
# A smoothed version of the weights is necessary for some training schemes to perform well.
|
||||
# E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
|
||||
# RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
|
||||
# smoothing of weights to match results. Pay attention to the decay constant you are using
|
||||
# relative to your update count per epoch.
|
||||
# To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
|
||||
# disable validation of the EMA weights. Validation will have to be done manually in a separate
|
||||
# process, or after the training stops converging.
|
||||
# This class is sensitive where it is initialized in the sequence of model init,
|
||||
# GPU assignment and distributed training wrappers.
|
||||
# I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and
|
||||
# single-GPU.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, model, decay=0.9999, device=''):
|
||||
# # make a copy of the model for accumulating moving average of weights
|
||||
# self.ema = deepcopy(model)
|
||||
# self.ema.eval()
|
||||
# self.updates = 0 # number of EMA updates
|
||||
# # decay exponential ramp (to help early epochs)
|
||||
# self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
|
||||
# self.device = device # perform ema on different device from model if set
|
||||
# if device:
|
||||
# self.ema.to(device=device)
|
||||
# for p in self.ema.parameters():
|
||||
# p.requires_grad_(False)
|
||||
#
|
||||
# def update(self, model):
|
||||
# self.updates += 1
|
||||
# d = self.decay(self.updates)
|
||||
# with torch.no_grad():
|
||||
# if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||
# msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
||||
# else:
|
||||
# msd, esd = model.state_dict(), self.ema.state_dict()
|
||||
#
|
||||
# for k, v in esd.items():
|
||||
# if v.dtype.is_floating_point:
|
||||
# v *= d
|
||||
# v += (1. - d) * msd[k].detach()
|
||||
#
|
||||
# def update_attr(self, model):
|
||||
# # Assign attributes (which may change during training)
|
||||
# for k in model.__dict__.keys():
|
||||
# if not k.startswith('_'):
|
||||
# setattr(self.ema, k, getattr(model, k))
|
||||
@@ -0,0 +1,376 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision.models.resnet import BasicBlock, conv1x1
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def remove_padding(seq):
|
||||
r"""
|
||||
Remove the trailing zeros from the provided input
|
||||
|
||||
Parameters
|
||||
----------
|
||||
list: List of integers
|
||||
Predicted sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
list: List of integers
|
||||
The part of the input before the zero padding
|
||||
|
||||
"""
|
||||
pad_len = 0
|
||||
for x in reversed(seq):
|
||||
if x != 0:
|
||||
break
|
||||
pad_len += 1
|
||||
if pad_len == 0:
|
||||
return seq, 0
|
||||
|
||||
un_padded = seq[:-pad_len]
|
||||
return un_padded, pad_len
|
||||
|
||||
|
||||
def probabilities_to_predictions(probabilities):
|
||||
r"""
|
||||
Convert probabilities to predictions
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probabilities : Tensor[batch_size, vocab_size, seq_len]
|
||||
All log probabilities coming out at the last stage of the decoder
|
||||
|
||||
Returns
|
||||
-------
|
||||
predictions : tensor [batch_size, output_sequence_length]
|
||||
The prediceted trags
|
||||
|
||||
"""
|
||||
# max_idx: [batch_size, seq_len]
|
||||
max_idx = torch.argmax(probabilities, dim=1)
|
||||
return max_idx
|
||||
|
||||
|
||||
def print_target_predict(target, predictions, filenames=None, batch_idx=0):
|
||||
r"""
|
||||
For the Tags, print the target and predicted tensors for the specified batch index
|
||||
|
||||
We expect to have the batch size as the first dimension.
|
||||
Only the specified batch is extractred and the remaining dimenions are flattened.
|
||||
The results are printed as 2 lists with the target on top and the predictions below underlined
|
||||
|
||||
Parameters
|
||||
---------
|
||||
target : tensor [batch_size, output_sequence_length]
|
||||
The ground truth tags
|
||||
|
||||
predictions : tensor [batch_size, output_sequence_length]
|
||||
The prediceted trags
|
||||
|
||||
filenames : list of string
|
||||
The actual filename that provides the data
|
||||
|
||||
batch_idx : int
|
||||
Which index in the batch dimension will be printed
|
||||
"""
|
||||
target_flat = target[batch_idx].flatten()
|
||||
predictions_flat = predictions[batch_idx].flatten()
|
||||
target_label = "target"
|
||||
predict_label = "predict"
|
||||
if filenames is not None:
|
||||
target_label = filenames[batch_idx]
|
||||
label_len = max(len(target_label), len(predict_label))
|
||||
print("{}: {}".format(target_label.ljust(label_len, " "), target_flat.tolist()))
|
||||
print(
|
||||
"{}: {}".format(predict_label.ljust(label_len, " "), predictions_flat.tolist())
|
||||
)
|
||||
|
||||
|
||||
def load_image(full_fn):
|
||||
r"""
|
||||
Load an image from the disk as a numpy array
|
||||
|
||||
Parameters
|
||||
----------
|
||||
full_fn : string
|
||||
The full path filename of the image
|
||||
|
||||
Results
|
||||
-------
|
||||
img : numpy array: (channels, width, height)
|
||||
The loaded image as a numpy array
|
||||
"""
|
||||
with Image.open(full_fn) as f:
|
||||
img = np.asarray(f) # (width, height, channels)
|
||||
img = img.transpose(2, 0, 1) # (channels, width, height)
|
||||
return img
|
||||
|
||||
|
||||
def resnet_block(stride=1):
|
||||
layers = []
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(256, 512, stride),
|
||||
nn.BatchNorm2d(512),
|
||||
)
|
||||
layers.append(BasicBlock(256, 512, stride, downsample))
|
||||
layers.append(BasicBlock(512, 512, 1))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def repackage_hidden(h):
|
||||
r"""
|
||||
Wraps hidden states in new Tensors, to detach them from their history.
|
||||
"""
|
||||
if isinstance(h, torch.Tensor):
|
||||
return h.detach()
|
||||
else:
|
||||
return tuple(repackage_hidden(v) for v in h)
|
||||
|
||||
|
||||
def accuracy(scores, targets, k):
|
||||
"""
|
||||
Computes top-k accuracy, from predicted and true labels.
|
||||
|
||||
:param scores: scores from the model
|
||||
:param targets: true labels
|
||||
:param k: k in top-k accuracy
|
||||
:return: top-k accuracy
|
||||
"""
|
||||
|
||||
batch_size = targets.size(0)
|
||||
_, ind = scores.topk(k, 1, True, True)
|
||||
correct = ind.eq(targets.view(-1, 1).expand_as(ind))
|
||||
correct_total = correct.view(-1).float().sum() # 0D tensor
|
||||
return correct_total.item() * (100.0 / batch_size)
|
||||
|
||||
|
||||
def clip_gradient(optimizer, grad_clip):
|
||||
"""
|
||||
Clips gradients computed during backpropagation to avoid explosion of gradients.
|
||||
|
||||
:param optimizer: optimizer with the gradients to be clipped
|
||||
:param grad_clip: clip value
|
||||
"""
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
if param.grad is not None:
|
||||
param.grad.data.clamp_(-grad_clip, grad_clip)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""
|
||||
Keeps track of most recent, average, sum, and count of a metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def bip_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Very simple multi-layer perceptron (also called FFN)"""
|
||||
|
||||
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(
|
||||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor:
|
||||
"""Generate the attention mask for causal decoding"""
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = (
|
||||
mask.float()
|
||||
.masked_fill(mask == 0, float("-inf"))
|
||||
.masked_fill(mask == 1, float(0.0))
|
||||
).to(device=device)
|
||||
return mask
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
"""Early stops the training if validation loss doesn't improve after a given patience.
|
||||
Source from: https://github.com/Bjarten/early-stopping-pytorch
|
||||
"""
|
||||
|
||||
def __init__(self, patience=2, verbose=False, delta=0, trace_func=print):
|
||||
"""
|
||||
Args:
|
||||
patience (int): How long to wait after last time validation loss improved.
|
||||
Default: 7
|
||||
verbose (bool): If True, prints a message for each validation loss improvement.
|
||||
Default: False
|
||||
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
|
||||
Default: 0
|
||||
path (str): Path for the checkpoint to be saved to.
|
||||
Default: 'checkpoint.pt'
|
||||
trace_func (function): trace print function.
|
||||
Default: print
|
||||
"""
|
||||
self._patience = patience
|
||||
self._verbose = verbose
|
||||
self._counter = 0
|
||||
self._best_score = None
|
||||
self._early_stop = False
|
||||
self._val_loss_min = np.Inf
|
||||
self._delta = delta
|
||||
self._trace_func = trace_func
|
||||
|
||||
def __call__(self, val_loss):
|
||||
score = -val_loss
|
||||
save_checkpoint = True
|
||||
if self._best_score is None:
|
||||
self._best_score = score
|
||||
save_checkpoint = True
|
||||
if self._verbose:
|
||||
verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})."
|
||||
self._trace_func(verb)
|
||||
self._val_loss_min = val_loss
|
||||
elif score < self._best_score + self._delta:
|
||||
self._counter += 1
|
||||
self._trace_func(
|
||||
f"EarlyStopping counter: {self._counter} out of {self._patience}"
|
||||
)
|
||||
if self._counter >= self._patience:
|
||||
self._early_stop = True
|
||||
save_checkpoint = False
|
||||
else:
|
||||
self._best_score = score
|
||||
save_checkpoint = True
|
||||
self._counter = 0
|
||||
if self._verbose:
|
||||
verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})."
|
||||
self._trace_func(verb)
|
||||
self._val_loss_min = val_loss
|
||||
return save_checkpoint
|
||||
|
||||
|
||||
def print_dict(m: dict):
|
||||
r"""
|
||||
Print dict elements in separate lines sorted by keys
|
||||
"""
|
||||
if len(m) == 0:
|
||||
return
|
||||
|
||||
# Check if the key is a stringified integer
|
||||
first_key = next(iter(m))
|
||||
is_numeric = isinstance(first_key, str) and first_key.isnumeric()
|
||||
if is_numeric:
|
||||
keys = sorted([int(k) for k in m.keys()])
|
||||
else:
|
||||
keys = sorted([k for k in m.keys()])
|
||||
|
||||
for k in keys:
|
||||
if is_numeric:
|
||||
v = m[str(k)]
|
||||
else:
|
||||
v = m[k]
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
|
||||
def print_list(lst: list):
|
||||
r"""
|
||||
Print list elements in separate lines
|
||||
"""
|
||||
for i, elm in enumerate(lst):
|
||||
if isinstance(elm, list):
|
||||
print("{}: ({}) - {}".format(i, len(elm), elm))
|
||||
else:
|
||||
print("{}: {}".format(i, elm))
|
||||
@@ -0,0 +1,175 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
import docling_ibm_models.tableformer.settings as s
|
||||
|
||||
LOG_LEVEL = logging.INFO
|
||||
|
||||
|
||||
class MyWelford:
|
||||
r"""
|
||||
Running computation of the sample mean and sample variance using Welford's algorithm
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._i = 0 # Running index
|
||||
self._m = 0 # Running mean
|
||||
self._s = 0 # (n - 1) * variance
|
||||
|
||||
def reset(self):
|
||||
r"""
|
||||
Reset the object
|
||||
"""
|
||||
self._i = 0
|
||||
self._m = 0
|
||||
self._s = 0
|
||||
|
||||
def add(self, xi):
|
||||
r"""
|
||||
Invoke add each time a new sample arrives
|
||||
|
||||
Inputs:
|
||||
xi: The next sample of data
|
||||
"""
|
||||
self._i += 1
|
||||
old_m = self._m
|
||||
self._m = self._m + (xi - self._m) / self._i
|
||||
self._s = self._s + (xi - self._m) * (xi - old_m)
|
||||
|
||||
def results(self):
|
||||
r"""
|
||||
Get the computed mean, variance and standard deviation up to now
|
||||
|
||||
Outputs:
|
||||
m: Sample mean
|
||||
v: Sample variance
|
||||
std: Sample standard deviation
|
||||
"""
|
||||
if self._i <= 1:
|
||||
return None, None, None
|
||||
|
||||
# v = self._s / (self._i - 1) # Sample variance
|
||||
v = self._s / (self._i) # Population variance
|
||||
std = np.sqrt(v)
|
||||
return self._m, v, std
|
||||
|
||||
|
||||
class MyWelfordImg(MyWelford):
|
||||
r"""
|
||||
Welford algorithm to calculate running mean and sample variance for images
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MyWelfordImg, self).__init__()
|
||||
|
||||
def add(self, img):
|
||||
r"""
|
||||
Input:
|
||||
img: An image numpy array (channel, width, height). The only requirement is to have the
|
||||
channels as the first dimension and have 3 dimensions in total
|
||||
"""
|
||||
channels = img.shape[0]
|
||||
flat_dim = img.shape[1] * img.shape[2]
|
||||
img_r = img.reshape(channels, flat_dim)
|
||||
|
||||
for i in range(flat_dim):
|
||||
super(MyWelfordImg, self).add(img_r[:, i])
|
||||
|
||||
|
||||
class ChanVarianceImg:
|
||||
r"""
|
||||
Chan's algorithm to compute a running variance with support of sub-samples
|
||||
In this implementation each sub-sample is an images
|
||||
|
||||
Math for the original paper:
|
||||
https://github.ibm.com/nli/variance_formulae
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
r""" """
|
||||
self._first = True
|
||||
# Size of the calculated dataset
|
||||
self._n = 0
|
||||
# Sum of the samples for the 3 image channels
|
||||
self._t = 0
|
||||
# Sum of the square differences of the deviations of the samples from the mean
|
||||
self._s = 0
|
||||
|
||||
def add(self, img):
|
||||
r"""
|
||||
Add the provided image to the computation of the dataset statistics
|
||||
|
||||
Input:
|
||||
img: An image numpy array (channel, width, height). The only requirement is to have the
|
||||
channels as the first dimension and have 3 dimensions in total
|
||||
"""
|
||||
ch = img.shape[0]
|
||||
n = img.shape[1] * img.shape[2]
|
||||
img = img.reshape(ch, n)
|
||||
img_t = img.sum(axis=1)
|
||||
img_t_v = img_t.reshape(ch, 1)
|
||||
diff = (img - (img_t_v / n)) ** 2
|
||||
img_s = diff.sum(axis=1)
|
||||
|
||||
if not self._first:
|
||||
c = (self._n / (n * (self._n + n))) * (
|
||||
((n / self._n) * self._t - img_t) ** 2
|
||||
)
|
||||
self._s += img_s + c
|
||||
self._t += img_t
|
||||
else:
|
||||
self._s = img_s
|
||||
self._t = img_t
|
||||
self._first = False
|
||||
self._n += n
|
||||
|
||||
def results(self):
|
||||
r"""
|
||||
Get the computed statistics
|
||||
|
||||
Output:
|
||||
mean: Mean for the complete dataset
|
||||
var: Population variance for the complete dataset
|
||||
std: Population standard deviation for the complete dataset
|
||||
"""
|
||||
mean = list(self._t / self._n)
|
||||
var = list(self._s / self._n) # Population variance
|
||||
std = list(np.sqrt(var))
|
||||
|
||||
return mean, var, std
|
||||
|
||||
def reset(self):
|
||||
r"""
|
||||
Reset the object to start over again
|
||||
"""
|
||||
self._n = 0
|
||||
self._t = 0
|
||||
self._s = 0
|
||||
self._first = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = s.get_custom_logger("variance", LOG_LEVEL)
|
||||
|
||||
n = 50000
|
||||
channels = 3
|
||||
width = 448
|
||||
height = 448
|
||||
|
||||
my = ChanVarianceImg()
|
||||
# Generate random images
|
||||
for i in range(n):
|
||||
logger.info(i)
|
||||
img = 255 * np.random.rand(channels, width, height)
|
||||
my.add(img)
|
||||
|
||||
# Calculate the statistics
|
||||
m, v, std = my.results()
|
||||
assert m.shape == (3,), "Wrong mean dimension"
|
||||
assert v.shape == (3,), "Wrong variance dimension"
|
||||
assert std.shape == (3,), "Wrong std dimension"
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 237 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 200 KiB |
Generated
+2787
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,66 @@
|
||||
[tool.poetry]
|
||||
name = "docling-ibm-models"
|
||||
version = "0.2.0"
|
||||
description = "This package contains the AI models used by the Docling PDF conversion package"
|
||||
authors = ["Nikos Livathinos <nli@zurich.ibm.com>", "Maxim Lysak <mly@zurich.ibm.com>", "Ahmed Nassar <ahn@zurich.ibm.com>", "Christoph Auer <cau@zurich.ibm.com>", "Michele Dolfi <dol@zurich.ibm.com>", "Peter Staar <taa@zurich.ibm.com>"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
keywords= ["docling", "convert", "document", "pdf", "layout model", "segmentation", "table structure", "table former"]
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: MacOS :: MacOS X",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Programming Language :: Python :: 3"
|
||||
]
|
||||
packages = [
|
||||
{ include = "docling_ibm_models" },
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
torch = "2.2.2"
|
||||
torchvision = "0.17.2"
|
||||
onnxruntime = "^1.16.2"
|
||||
numpy = "^1.24.4"
|
||||
lxml = "^4.9.1"
|
||||
jsonlines = "^3.1.0"
|
||||
Pillow = "^10.0.0"
|
||||
tqdm = "^4.64.0"
|
||||
apted = "^1.0.3"
|
||||
Distance = "^0.1.3"
|
||||
mean_average_precision = "^2021.4.26.0"
|
||||
opencv-python-headless = { version = "^4.9.0.80", markers = 'sys_platform=="linux"'}
|
||||
opencv-python = { version = "^4.9.0.80", markers = 'sys_platform!="linux"'}
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = {extras = ["jupyter"], version = "^24.4.2"}
|
||||
pytest = "^7.2.2"
|
||||
pre-commit = "^3.7.1"
|
||||
mypy = "^1.10.1"
|
||||
isort = "^5.10.1"
|
||||
python-semantic-release = "^7.32.2"
|
||||
flake8 = "^6.0.0"
|
||||
pyproject-flake8 = "^6.0.0"
|
||||
pytest-xdist = "^3.3.1"
|
||||
pytest-flake8 = "^1.1.0"
|
||||
types-requests = "^2.31.0.2"
|
||||
flake8-pyproject = "^1.2.3"
|
||||
pylint = "^2.17.5"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py311"]
|
||||
include = '\.pyi?$'
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 88
|
||||
py_version=311
|
||||
@@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import docling_ibm_models.tableformer.common as c
|
||||
|
||||
test_config_a = {
|
||||
"base_dir": "./tests/test_data/",
|
||||
"curr_dir": "./tests/test_data/test_common/",
|
||||
"data_top_dir": "./tests/test_data/",
|
||||
"dataset": {
|
||||
"name": ["PhysRevB"],
|
||||
"limit": 10,
|
||||
"split": {"test": 0.2, "train": 0.5, "evaluate": 0.3},
|
||||
},
|
||||
"features": {
|
||||
"name": "Data2Features03b",
|
||||
"parameters": {
|
||||
"normalize_features": True,
|
||||
"normalize_features_method": "Z-Score",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
test_config_b = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 30}}
|
||||
|
||||
test_config_c = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 302}}
|
||||
|
||||
test_config_d = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 303}}
|
||||
|
||||
|
||||
def test_safe_get_parameters():
|
||||
val = c.safe_get_parameter(None, None, 10)
|
||||
assert val == 10, "Failed with null objects"
|
||||
|
||||
index_path = ["features", "parameters", "normalize_features_method"]
|
||||
val = c.safe_get_parameter(test_config_a, index_path, None)
|
||||
assert val == "Z-Score", "Cannot find existing parameter"
|
||||
|
||||
index_path = ["features", "parameters", "wrong"]
|
||||
val = c.safe_get_parameter(test_config_a, index_path, "hello")
|
||||
assert val == "hello", "Default value should be here"
|
||||
|
||||
index_path = ["features", "wrong", "normalize_features_method"]
|
||||
val = c.safe_get_parameter(test_config_a, index_path, 10)
|
||||
assert val == 10, "Default value should be here"
|
||||
|
||||
index_path = ["model", "parameters", "normalize_features_method"]
|
||||
val = c.safe_get_parameter(test_config_a, index_path, "hello")
|
||||
assert val == "hello", "Default value should be here"
|
||||
|
||||
# Test exception throwing
|
||||
exRaised = False
|
||||
try:
|
||||
index_path = ["missing"]
|
||||
val = c.safe_get_parameter(test_config_a, index_path, required=True)
|
||||
except ValueError:
|
||||
exRaised = True
|
||||
assert exRaised, "Exception should had been raised here"
|
||||
|
||||
|
||||
def test_config_validation():
|
||||
configs = [test_config_b, test_config_c, test_config_d]
|
||||
|
||||
for i, config in enumerate(configs):
|
||||
try:
|
||||
val = c.validate_config(config)
|
||||
if i == 0 or i == 1:
|
||||
assert val, "Valid configuration didn't pass the validation test"
|
||||
except AssertionError:
|
||||
assert i == 2, "Configuration validation error"
|
||||
@@ -0,0 +1 @@
|
||||
Put model check: "otslp_all_fast_clean.check" in this directory
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 389 KiB |
File diff suppressed because one or more lines are too long
Binary file not shown.
|
After Width: | Height: | Size: 170 KiB |
File diff suppressed because one or more lines are too long
@@ -0,0 +1,88 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import docling_ibm_models.layoutmodel.layout_predictor as lp
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def init() -> dict:
|
||||
r"""
|
||||
Initialize the testing environment
|
||||
"""
|
||||
init = {
|
||||
"artifact_path": "tests/test_data/model_artifacts/",
|
||||
"num_threads": 1,
|
||||
"test_imgs": [
|
||||
"tests/test_data/samples/ADS.2007.page_123.png",
|
||||
],
|
||||
"info1": {
|
||||
"onnx_file": os.path.join(
|
||||
"tests/test_data/model_artifacts/", lp.MODEL_CHECKPOINT_FN
|
||||
),
|
||||
"intra_op_num_threads": 2,
|
||||
"providers": ["CPUExecutionProvider"],
|
||||
"use_cpu_only": True,
|
||||
"image_size": 640,
|
||||
"threshold": 0.6,
|
||||
},
|
||||
"info2": {
|
||||
"onnx_file": os.path.join(
|
||||
"tests/test_data/model_artifacts/", lp.MODEL_CHECKPOINT_FN
|
||||
),
|
||||
"intra_op_num_threads": 1,
|
||||
"providers": ["CPUExecutionProvider"],
|
||||
"use_cpu_only": True,
|
||||
"image_size": 640,
|
||||
"threshold": 0.6,
|
||||
},
|
||||
"pred_bboxes": 9,
|
||||
}
|
||||
return init
|
||||
|
||||
|
||||
def test_layoutpredictor(init: dict):
|
||||
r"""
|
||||
Unit test for the LayoutPredictor
|
||||
"""
|
||||
# Initialize LayoutPredictor with envvars
|
||||
os.environ["USE_CPU_ONLY"] = ""
|
||||
os.environ["OMP_NUM_THREADS"] = "2"
|
||||
lpredictor = LayoutPredictor(init["artifact_path"])
|
||||
assert init["info1"] == lpredictor.info()
|
||||
|
||||
# Initialize LayoutPredictor with optional parameters
|
||||
lpredictor = LayoutPredictor(
|
||||
init["artifact_path"], num_threads=init["num_threads"], use_cpu_only=True
|
||||
)
|
||||
assert init["info2"] == lpredictor.info()
|
||||
|
||||
# Unsupported input image
|
||||
is_exception = False
|
||||
try:
|
||||
for pred in lpredictor.predict("wrong"):
|
||||
pass
|
||||
except TypeError:
|
||||
is_exception = True
|
||||
assert is_exception
|
||||
|
||||
# Predict on the test image
|
||||
for img_fn in init["test_imgs"]:
|
||||
with Image.open(img_fn) as img:
|
||||
# Load images as PIL objects
|
||||
for i, pred in enumerate(lpredictor.predict(img)):
|
||||
print("PIL pred: {}".format(pred))
|
||||
assert i + 1 == init["pred_bboxes"]
|
||||
|
||||
# Load images as numpy arrays
|
||||
np_arr = np.asarray(img)
|
||||
for i, pred in enumerate(lpredictor.predict(np_arr)):
|
||||
print("numpy pred: {}".format(pred))
|
||||
assert i + 1 == init["pred_bboxes"]
|
||||
@@ -0,0 +1,627 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
import docling_ibm_models.tableformer.data_management.tf_predictor as tf_predictor
|
||||
from docling_ibm_models.tableformer.data_management.tf_predictor import \
|
||||
TFPredictor
|
||||
|
||||
"""
|
||||
- Implements TF predictor to accept the input format from IOCR, e.g.
|
||||
"./tests/test_data/samples/tf_table_example_0.json" (trivial table crop)
|
||||
- Shape the output format like GTE would: e.g.
|
||||
"tests/test_data/samples/tf_gte_output_2.json" (Note: full form image)
|
||||
"""
|
||||
|
||||
docling_api_data = {
|
||||
"table_jsons": [
|
||||
"./tests/test_data/samples/ADS.2007.page_123.png_iocr.parse_format.json",
|
||||
"./tests/test_data/samples/PHM.2013.page_30.png_iocr.parse_format.json",
|
||||
],
|
||||
"png_images": [
|
||||
"./tests/test_data/samples/ADS.2007.page_123.png",
|
||||
"./tests/test_data/samples/PHM.2013.page_30.png",
|
||||
],
|
||||
"table_bboxes": [
|
||||
[[178, 748, 1061, 976], [177, 1163, 1062, 1329]],
|
||||
[[100, 186, 1135, 525]],
|
||||
],
|
||||
}
|
||||
|
||||
test_config = {
|
||||
"dataset": {
|
||||
"type": "TF_prepared",
|
||||
"name": "TF",
|
||||
"raw_data_dir": "./tests/test_data/model_artifacts/",
|
||||
"load_cells": True,
|
||||
"bbox_format": "5plet",
|
||||
"resized_image": 448,
|
||||
"keep_AR": False,
|
||||
"up_scaling_enabled": True,
|
||||
"down_scaling_enabled": True,
|
||||
"padding_mode": "null",
|
||||
"padding_color": [0, 0, 0],
|
||||
"image_normalization": {
|
||||
"state": True,
|
||||
"mean": [0.94247851, 0.94254675, 0.94292611],
|
||||
"std": [0.17910956, 0.17940403, 0.17931663],
|
||||
},
|
||||
"color_jitter": True,
|
||||
"rand_crop": True,
|
||||
"rand_pad": True,
|
||||
"image_grayscale": False,
|
||||
},
|
||||
"model": {
|
||||
"type": "TableModel04_rs",
|
||||
"name": "14_128_256_4_true",
|
||||
"save_dir": "./tests/test_data/model_artifacts/",
|
||||
"backbone": "resnet18",
|
||||
"enc_image_size": 28,
|
||||
"tag_embed_dim": 16,
|
||||
"hidden_dim": 512,
|
||||
"tag_decoder_dim": 512,
|
||||
"bbox_embed_dim": 256,
|
||||
"tag_attention_dim": 256,
|
||||
"bbox_attention_dim": 512,
|
||||
"enc_layers": 4, # 6
|
||||
"dec_layers": 2, # 6
|
||||
"nheads": 8,
|
||||
"dropout": 0.1,
|
||||
"bbox_classes": 2,
|
||||
},
|
||||
"train": {
|
||||
"save_periodicity": 1,
|
||||
"disable_cuda": False,
|
||||
"epochs": 23,
|
||||
"batch_size": 50,
|
||||
"clip_gradient": 0.1,
|
||||
"clip_max_norm": 0.1,
|
||||
"bbox": True,
|
||||
"validation": False,
|
||||
},
|
||||
"predict": {
|
||||
"max_steps": 1024,
|
||||
"beam_size": 5,
|
||||
"bbox": True,
|
||||
"predict_dir": "./tests/test_data/samples",
|
||||
"pdf_cell_iou_thres": 0.05,
|
||||
"padding": False,
|
||||
"padding_size": 50,
|
||||
"disable_post_process": False,
|
||||
"profiling": False,
|
||||
"device_mode": "auto",
|
||||
},
|
||||
"dataset_wordmap": {
|
||||
"word_map_tag": {
|
||||
"<pad>": 0,
|
||||
"<unk>": 1,
|
||||
"<start>": 2,
|
||||
"<end>": 3,
|
||||
"ecel": 4,
|
||||
"fcel": 5,
|
||||
"lcel": 6,
|
||||
"ucel": 7,
|
||||
"xcel": 8,
|
||||
"nl": 9,
|
||||
"ched": 10,
|
||||
"rhed": 11,
|
||||
"srow": 12,
|
||||
},
|
||||
"word_map_cell": {
|
||||
" ": 13,
|
||||
"!": 179,
|
||||
'"': 126,
|
||||
"#": 101,
|
||||
"$": 119,
|
||||
"%": 18,
|
||||
"&": 114,
|
||||
"'": 108,
|
||||
"(": 29,
|
||||
")": 32,
|
||||
"*": 26,
|
||||
"+": 97,
|
||||
",": 71,
|
||||
"-": 63,
|
||||
".": 34,
|
||||
"/": 66,
|
||||
"0": 33,
|
||||
"1": 36,
|
||||
"2": 43,
|
||||
"3": 41,
|
||||
"4": 45,
|
||||
"5": 17,
|
||||
"6": 37,
|
||||
"7": 35,
|
||||
"8": 40,
|
||||
"9": 16,
|
||||
":": 88,
|
||||
";": 92,
|
||||
"<": 73,
|
||||
"</b>": 9,
|
||||
"</i>": 23,
|
||||
"</overline>": 219,
|
||||
"</strike>": 233,
|
||||
"</sub>": 94,
|
||||
"</sup>": 77,
|
||||
"</underline>": 151,
|
||||
"<b>": 1,
|
||||
"<end>": 280,
|
||||
"<i>": 21,
|
||||
"<overline>": 218,
|
||||
"<pad>": 0,
|
||||
"<start>": 279,
|
||||
"<strike>": 232,
|
||||
"<sub>": 93,
|
||||
"<sup>": 75,
|
||||
"<underline>": 150,
|
||||
"<unk>": 278,
|
||||
"=": 99,
|
||||
">": 39,
|
||||
"?": 96,
|
||||
"@": 125,
|
||||
"A": 27,
|
||||
"B": 86,
|
||||
"C": 19,
|
||||
"D": 57,
|
||||
"E": 64,
|
||||
"F": 47,
|
||||
"G": 44,
|
||||
"H": 10,
|
||||
"I": 20,
|
||||
"J": 80,
|
||||
"K": 81,
|
||||
"L": 52,
|
||||
"M": 46,
|
||||
"N": 69,
|
||||
"O": 65,
|
||||
"P": 62,
|
||||
"Q": 59,
|
||||
"R": 60,
|
||||
"S": 58,
|
||||
"T": 48,
|
||||
"U": 55,
|
||||
"V": 2,
|
||||
"W": 83,
|
||||
"X": 104,
|
||||
"Y": 89,
|
||||
"Z": 113,
|
||||
"[": 70,
|
||||
"\\": 165,
|
||||
"]": 72,
|
||||
"^": 132,
|
||||
"_": 84,
|
||||
"`": 196,
|
||||
"a": 3,
|
||||
"b": 6,
|
||||
"c": 54,
|
||||
"d": 12,
|
||||
"e": 8,
|
||||
"f": 50,
|
||||
"g": 28,
|
||||
"h": 56,
|
||||
"i": 5,
|
||||
"j": 82,
|
||||
"k": 95,
|
||||
"l": 7,
|
||||
"m": 30,
|
||||
"n": 31,
|
||||
"o": 15,
|
||||
"p": 22,
|
||||
"q": 67,
|
||||
"r": 4,
|
||||
"s": 51,
|
||||
"t": 14,
|
||||
"u": 25,
|
||||
"v": 24,
|
||||
"w": 53,
|
||||
"x": 61,
|
||||
"y": 49,
|
||||
"z": 11,
|
||||
"{": 158,
|
||||
"|": 139,
|
||||
"}": 159,
|
||||
"~": 147,
|
||||
"\u00a2": 203,
|
||||
"\u00a3": 162,
|
||||
"\u00a4": 220,
|
||||
"\u00a5": 176,
|
||||
"\u00a7": 142,
|
||||
"\u00a9": 268,
|
||||
"\u00ab": 239,
|
||||
"\u00ad": 275,
|
||||
"\u00ae": 130,
|
||||
"\u00b0": 100,
|
||||
"\u00b1": 79,
|
||||
"\u00b6": 171,
|
||||
"\u00b7": 137,
|
||||
"\u00bb": 240,
|
||||
"\u00d7": 118,
|
||||
"\u00d8": 192,
|
||||
"\u00df": 197,
|
||||
"\u00e6": 261,
|
||||
"\u00f7": 225,
|
||||
"\u00f8": 163,
|
||||
"\u0131": 242,
|
||||
"\u0142": 267,
|
||||
"\u01c2": 211,
|
||||
"\u025b": 223,
|
||||
"\u02b9": 248,
|
||||
"\u02c2": 195,
|
||||
"\u02c3": 208,
|
||||
"\u02c6": 253,
|
||||
"\u0300": 209,
|
||||
"\u0301": 131,
|
||||
"\u0302": 138,
|
||||
"\u0303": 156,
|
||||
"\u0304": 152,
|
||||
"\u0306": 222,
|
||||
"\u0307": 247,
|
||||
"\u0308": 103,
|
||||
"\u030a": 102,
|
||||
"\u030c": 254,
|
||||
"\u0327": 155,
|
||||
"\u0328": 269,
|
||||
"\u0338": 170,
|
||||
"\u0391": 173,
|
||||
"\u0392": 169,
|
||||
"\u0393": 180,
|
||||
"\u0394": 85,
|
||||
"\u0398": 243,
|
||||
"\u0399": 271,
|
||||
"\u039b": 272,
|
||||
"\u03a0": 213,
|
||||
"\u03a3": 185,
|
||||
"\u03a6": 148,
|
||||
"\u03a7": 212,
|
||||
"\u03a8": 141,
|
||||
"\u03a9": 161,
|
||||
"\u03b1": 90,
|
||||
"\u03b2": 107,
|
||||
"\u03b3": 110,
|
||||
"\u03b4": 153,
|
||||
"\u03b5": 166,
|
||||
"\u03b6": 178,
|
||||
"\u03b7": 146,
|
||||
"\u03b8": 186,
|
||||
"\u03b9": 229,
|
||||
"\u03ba": 164,
|
||||
"\u03bb": 91,
|
||||
"\u03bc": 78,
|
||||
"\u03bd": 230,
|
||||
"\u03be": 244,
|
||||
"\u03c0": 127,
|
||||
"\u03c1": 149,
|
||||
"\u03c3": 116,
|
||||
"\u03c4": 198,
|
||||
"\u03c5": 189,
|
||||
"\u03c6": 140,
|
||||
"\u03c7": 124,
|
||||
"\u03c8": 216,
|
||||
"\u03c9": 167,
|
||||
"\u0410": 273,
|
||||
"\u0421": 194,
|
||||
"\u115f": 217,
|
||||
"\u200b": 265,
|
||||
"\u2010": 117,
|
||||
"\u2012": 135,
|
||||
"\u2013": 42,
|
||||
"\u2014": 106,
|
||||
"\u2015": 228,
|
||||
"\u2016": 259,
|
||||
"\u2018": 123,
|
||||
"\u2019": 121,
|
||||
"\u201c": 87,
|
||||
"\u201d": 115,
|
||||
"\u201e": 245,
|
||||
"\u2020": 109,
|
||||
"\u2021": 129,
|
||||
"\u2022": 128,
|
||||
"\u2028": 190,
|
||||
"\u2030": 154,
|
||||
"\u2032": 68,
|
||||
"\u203b": 224,
|
||||
"\u2044": 188,
|
||||
"\u204e": 199,
|
||||
"\u2061": 200,
|
||||
"\u20ac": 184,
|
||||
"\u2190": 202,
|
||||
"\u2191": 112,
|
||||
"\u2192": 120,
|
||||
"\u2193": 111,
|
||||
"\u2194": 183,
|
||||
"\u21d1": 266,
|
||||
"\u21d2": 264,
|
||||
"\u21d3": 255,
|
||||
"\u2205": 215,
|
||||
"\u2206": 175,
|
||||
"\u2208": 262,
|
||||
"\u2211": 160,
|
||||
"\u2212": 76,
|
||||
"\u2216": 206,
|
||||
"\u2217": 105,
|
||||
"\u2218": 246,
|
||||
"\u2219": 236,
|
||||
"\u221a": 187,
|
||||
"\u221e": 207,
|
||||
"\u2223": 260,
|
||||
"\u2225": 193,
|
||||
"\u2227": 182,
|
||||
"\u2229": 256,
|
||||
"\u222b": 258,
|
||||
"\u223c": 98,
|
||||
"\u2248": 210,
|
||||
"\u2264": 38,
|
||||
"\u2265": 74,
|
||||
"\u2266": 214,
|
||||
"\u2267": 181,
|
||||
"\u2295": 263,
|
||||
"\u22c5": 174,
|
||||
"\u22c6": 191,
|
||||
"\u22ee": 277,
|
||||
"\u22ef": 270,
|
||||
"\u2500": 205,
|
||||
"\u2551": 231,
|
||||
"\u25a0": 250,
|
||||
"\u25a1": 177,
|
||||
"\u25aa": 145,
|
||||
"\u25b2": 136,
|
||||
"\u25b3": 143,
|
||||
"\u25bc": 251,
|
||||
"\u25c6": 226,
|
||||
"\u25ca": 235,
|
||||
"\u25cb": 227,
|
||||
"\u25cf": 172,
|
||||
"\u25e6": 274,
|
||||
"\u2605": 204,
|
||||
"\u2606": 144,
|
||||
"\u2640": 133,
|
||||
"\u2642": 134,
|
||||
"\u2663": 252,
|
||||
"\u2666": 157,
|
||||
"\u266f": 221,
|
||||
"\u2713": 122,
|
||||
"\u2714": 249,
|
||||
"\u2717": 201,
|
||||
"\u2794": 168,
|
||||
"\u27a2": 276,
|
||||
"\u2a7d": 234,
|
||||
"\u2a7e": 241,
|
||||
"\u3008": 237,
|
||||
"\u3009": 238,
|
||||
"\ufeff": 257,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# ==================================================================================================
|
||||
configs = [test_config]
|
||||
|
||||
|
||||
def combine_checkpoint(save_dir):
|
||||
r"""
|
||||
Check if the checkpoint file is present as one part or 2 splits.
|
||||
Combine parts into one file if needed
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir : string
|
||||
The directory to check for checkpoint files or splits of it
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
0: The full checkpoint file already exists, no combine was needed
|
||||
1: The splits were found, a combine has been done
|
||||
-1: No full checkpoint and no splits exist. Error
|
||||
"""
|
||||
# Check if the full file already exists
|
||||
full_file_pattern = os.path.join(save_dir, "*.check")
|
||||
candidate = glob.glob(full_file_pattern)
|
||||
if len(candidate) == 1:
|
||||
print(
|
||||
"combine_checkpoint: The whole checkpoint file was found: {}".format(
|
||||
candidate[0]
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
# Check for splits
|
||||
splits_pattern = os.path.join(save_dir, "*.check.a[a-z]")
|
||||
splits = glob.glob(splits_pattern)
|
||||
splits.sort()
|
||||
if splits is None or len(splits) == 0:
|
||||
print(
|
||||
"combine_checkpoint: Both the full checkpoint and the splits are missing. Error"
|
||||
)
|
||||
return -1
|
||||
|
||||
# Combine splits
|
||||
full_fn = splits[0].rpartition(".check")[0] + ".check"
|
||||
with open(full_fn, "wb") as f_out:
|
||||
for split_fn in splits:
|
||||
with open(split_fn, "rb") as f_split:
|
||||
print("combine_checkpoint: read split: {}".format(split_fn))
|
||||
f_out.write(f_split.read())
|
||||
|
||||
print("combine_checkpoint: combine splits as: {}".format(full_fn))
|
||||
return 1
|
||||
|
||||
|
||||
def test_tf_predictor():
|
||||
r"""
|
||||
Test the TFPredictor
|
||||
"""
|
||||
viz = True
|
||||
|
||||
# Load the docling_api_data
|
||||
iocr_pages = []
|
||||
for table_json_fn, png_image_fn, table_bboxes_b in zip(
|
||||
docling_api_data["table_jsons"],
|
||||
docling_api_data["png_images"],
|
||||
docling_api_data["table_bboxes"],
|
||||
):
|
||||
with open(table_json_fn, "r") as fp:
|
||||
iocr_page_raw = json.load(fp)
|
||||
iocr_page = iocr_page_raw["pages"][0]
|
||||
iocr_page["image"] = cv2.imread(png_image_fn)
|
||||
# page_image = cv2.imread(png_image_fn)
|
||||
iocr_page["png_image_fn"] = png_image_fn
|
||||
iocr_page["table_bboxes"] = table_bboxes_b
|
||||
iocr_pages.append(iocr_page)
|
||||
|
||||
# Loop over the test configs
|
||||
for test_config in configs:
|
||||
# Check if the checkpoint file should be combined
|
||||
assert (
|
||||
combine_checkpoint(test_config["model"]["save_dir"]) >= 0
|
||||
), "Model checkpoint is missing"
|
||||
|
||||
# Loop over the iocr_pages
|
||||
predictor = TFPredictor(test_config)
|
||||
for iocr_page in iocr_pages:
|
||||
# Prepare "Predict" parameters
|
||||
# iw = iocr_page["width"]
|
||||
# ih = iocr_page["height"]
|
||||
# table_bboxes = [[0, 0, iw, ih]] # just one table per page in our examples
|
||||
table_bboxes = iocr_page["table_bboxes"]
|
||||
|
||||
# for t, table_bbox in enumerate(table_bboxes):
|
||||
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
|
||||
png_img_bfn = os.path.basename(iocr_page["png_image_fn"])
|
||||
print("Predicting image: {}".format(png_img_bfn))
|
||||
|
||||
# Run prediction, post-processing, and cell matching
|
||||
# PARAMETERS:
|
||||
# iocr_page - json received from iocr, augmented with iocr_page["image"]
|
||||
# table_bboxes - list of detected bboxes on page: [[x1, y1, x2, y2], [...]...]
|
||||
# do_matching - Boolean, when True - will match with text cells provided,
|
||||
# when False - returns original cell prediction BBOXes in the same format
|
||||
# OUTPUT:
|
||||
# List of dicts per table: [{"tf_responses":[...], "predict_details": {}}]
|
||||
|
||||
multi_tf_output = predictor.multi_table_predict(
|
||||
iocr_page, table_bboxes, False
|
||||
)
|
||||
|
||||
# Test output for validity, create visualizations...
|
||||
for t, tf_output in enumerate(multi_tf_output):
|
||||
tf_responses = tf_output["tf_responses"]
|
||||
predict_details = tf_output["predict_details"]
|
||||
assert tf_responses is not None, "Empty prediction response"
|
||||
assert isinstance(
|
||||
tf_responses, list
|
||||
), " Wrong response type. It should be a list"
|
||||
|
||||
img = Image.open(iocr_page["png_image_fn"])
|
||||
img1 = ImageDraw.Draw(img)
|
||||
|
||||
xt0 = table_bboxes[t][0]
|
||||
yt0 = table_bboxes[t][1]
|
||||
xt1 = table_bboxes[t][2]
|
||||
yt1 = table_bboxes[t][3]
|
||||
img1.rectangle(((xt0, yt0), (xt1, yt1)), outline="pink", width=5)
|
||||
|
||||
if viz:
|
||||
# Visualize original OCR words:
|
||||
for iocr_word in iocr_page["tokens"]:
|
||||
xi0 = iocr_word["bbox"]["l"]
|
||||
yi0 = iocr_word["bbox"]["t"]
|
||||
xi1 = iocr_word["bbox"]["r"]
|
||||
yi1 = iocr_word["bbox"]["b"]
|
||||
img1.rectangle(((xi0, yi0), (xi1, yi1)), outline="gray")
|
||||
# Visualize original docling_ibm_models.tableformer predictions:
|
||||
for predicted_bbox in predict_details["prediction_bboxes_page"]:
|
||||
xp0 = predicted_bbox[0] - 2
|
||||
yp0 = predicted_bbox[1] - 2
|
||||
xp1 = predicted_bbox[2] + 2
|
||||
yp1 = predicted_bbox[3] + 2
|
||||
img1.rectangle(((xp0, yp0), (xp1, yp1)), outline="green")
|
||||
|
||||
# Check the structure of the list items
|
||||
for i, response in enumerate(tf_responses):
|
||||
assert (
|
||||
"bbox" in response
|
||||
), "bbox field is missing from response: " + str(i)
|
||||
assert (
|
||||
"text_cell_bboxes" in response
|
||||
), "text_cell_bboxes is missing: " + str(i)
|
||||
assert (
|
||||
"row_span" in response
|
||||
), "row_span is missing from resp: " + str(i)
|
||||
assert (
|
||||
"col_span" in response
|
||||
), "col_span is missing from response: " + str(i)
|
||||
# print("*********** column_header: {}".format(response["column_header"]))
|
||||
if viz:
|
||||
# Visualization:
|
||||
for text_cell in response["text_cell_bboxes"]:
|
||||
xc0 = text_cell["l"]
|
||||
yc0 = text_cell["b"]
|
||||
xc1 = text_cell["r"]
|
||||
yc1 = text_cell["t"]
|
||||
img1.rectangle(((xc0, yc0), (xc1, yc1)), outline="red")
|
||||
|
||||
x0 = response["bbox"]["l"] - 6
|
||||
y0 = response["bbox"]["t"] - 6
|
||||
x1 = response["bbox"]["r"] + 6
|
||||
y1 = response["bbox"]["b"] + 6
|
||||
|
||||
if response["column_header"]:
|
||||
img1.rectangle(
|
||||
((x0, y0), (x1, y1)), outline="blue", width=5
|
||||
)
|
||||
elif response["row_header"]:
|
||||
img1.rectangle(
|
||||
((x0, y0), (x1, y1)), outline="magenta", width=5
|
||||
)
|
||||
elif response["row_section"]:
|
||||
img1.rectangle(
|
||||
((x0, y0), (x1, y1)), outline="brown", width=5
|
||||
)
|
||||
else:
|
||||
img1.rectangle(
|
||||
((x0, y0), (x1, y1)), outline="blue", width=1
|
||||
)
|
||||
if viz:
|
||||
viz_root = "./tests/test_data/viz/"
|
||||
Path(viz_root).mkdir(parents=True, exist_ok=True)
|
||||
png_img_bfn1 = png_img_bfn.replace(".png", "." + str(t) + ".png")
|
||||
viz_fn = os.path.join(viz_root, png_img_bfn1)
|
||||
img.save(viz_fn)
|
||||
# assert False
|
||||
|
||||
|
||||
def test_device_mode():
|
||||
r"""
|
||||
Test the "predict.device_mode" parameter
|
||||
"""
|
||||
mini_configs = [
|
||||
{"predict": {}},
|
||||
{"predict": {"device_mode": "cpu"}},
|
||||
{"predict": {"device_mode": "cuda"}},
|
||||
{"predict": {"device_mode": "gpu"}},
|
||||
{"predict": {"device_mode": "wrong"}},
|
||||
]
|
||||
|
||||
for i, config in enumerate(mini_configs):
|
||||
device = tf_predictor.decide_device(config)
|
||||
assert device in ["cpu", "cuda:0"], "Irrelevant device has been returned"
|
||||
|
||||
if i == 0:
|
||||
assert device == "cpu", "By default the 'cpu' device should be used"
|
||||
elif i == 1:
|
||||
assert device == "cpu", "An explicit 'cpu' device was given"
|
||||
elif i == 2 or i == 3:
|
||||
assert device == "cuda:0", "Cuda or gpu should become 'cuda:0'"
|
||||
else:
|
||||
assert (
|
||||
device == "cpu"
|
||||
), "A fall-back to 'cpu' should happen in case of error"
|
||||
Reference in New Issue
Block a user