Initial commit

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos
2024-07-15 10:52:19 +02:00
commit 7445296e6a
50 changed files with 13255 additions and 0 deletions
+60
View File
@@ -0,0 +1,60 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Tmp files and directories
stderr.*
stdout.*
*.tar
test.sh
OutputDecoder
jobs.txt
_std*.*
tests/tmp/*
runs/*
*.onnx
.DS_Store
viz/
# VIM
*.swp
*.swo
*.bak
# Environments
.env
.venv
_venv/
env/
venv/
ENV/
env.bak/
venv.bak/
venv
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# checkpoint file for testing
tests/test_data/model_artifacts/*.check
tests/test_data/model_artifacts/*.json
tests/test_data/model_artifacts/*.pt
# test results
tests/test_data/viz/
+43
View File
@@ -0,0 +1,43 @@
fail_fast: true
repos:
- repo: local
hooks:
- id: system
name: Black
entry: poetry run black docling_ibm_models
pass_filenames: false
language: system
files: '\.py$'
- repo: local
hooks:
- id: system
name: isort
entry: poetry run isort docling_ibm_models
pass_filenames: false
language: system
files: '\.py$'
- repo: local
hooks:
- id: system
name: Poetry check
entry: poetry lock --check
pass_filenames: false
language: system
# Ready to be enabled soon
# - repo: local
# hooks:
# - id: system
# name: flake8
# entry: poetry run flake8 docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
# - repo: local
# hooks:
# - id: system
# name: MyPy
# entry: poetry run mypy docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
+129
View File
@@ -0,0 +1,129 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement using
[deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com).
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html).
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
Homepage: [https://www.contributor-covenant.org](https://www.contributor-covenant.org)
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at
[https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations).
+163
View File
@@ -0,0 +1,163 @@
## Contributing In General
Our project welcomes external contributions. If you have an itch, please feel
free to scratch it.
To contribute code or documentation, please submit a [pull request](https://github.com/DS4SD/docling/pulls).
A good way to familiarize yourself with the codebase and contribution process is
to look for and tackle low-hanging fruit in the [issue tracker](https://github.com/DS4SD/docling/issues).
Before embarking on a more ambitious contribution, please quickly [get in touch](#communication) with us.
For general questions or support requests, please refer to the [discussion section](https://github.com/DS4SD/docling/discussions).
**Note: We appreciate your effort, and want to avoid a situation where a contribution
requires extensive rework (by you or by us), sits in backlog for a long time, or
cannot be accepted at all!**
### Proposing new features
If you would like to implement a new feature, please [raise an issue](https://github.com/DS4SD/docling/issues)
before sending a pull request so the feature can be discussed. This is to avoid
you wasting your valuable time working on a feature that the project developers
are not interested in accepting into the code base.
### Fixing bugs
If you would like to fix a bug, please [raise an issue](https://github.com/DS4SD/docling/issues) before sending a
pull request so it can be tracked.
### Merge approval
The project maintainers use LGTM (Looks Good To Me) in comments on the code
review to indicate acceptance. A change requires LGTMs from two of the
maintainers of each component affected.
For a list of the maintainers, see the [MAINTAINERS.md](MAINTAINERS.md) page.
## Legal
Each source file must include a license header for the MIT
Software. Using the SPDX format is the simplest approach.
e.g.
```
/*
Copyright IBM Inc. All rights reserved.
SPDX-License-Identifier: MIT
*/
```
We have tried to make it as easy as possible to make contributions. This
applies to how we handle the legal aspects of contribution. We use the
same approach - the [Developer's Certificate of Origin 1.1 (DCO)](https://github.com/hyperledger/fabric/blob/master/docs/source/DCO1.1.txt) - that the Linux® Kernel [community](https://elinux.org/Developer_Certificate_Of_Origin)
uses to manage code contributions.
We simply ask that when submitting a patch for review, the developer
must include a sign-off statement in the commit message.
Here is an example Signed-off-by line, which indicates that the
submitter accepts the DCO:
```
Signed-off-by: John Doe <john.doe@example.com>
```
You can include this automatically when you commit a change to your
local git repository using the following command:
```
git commit -s
```
## Communication
Please feel free to connect with us using the [discussion section](https://github.com/DS4SD/docling/discussions).
## Developing
### Usage of Poetry
We use Poetry to manage dependencies.
#### Install
To install, see the documentation here: https://python-poetry.org/docs/master/#installing-with-the-official-installer
1. Install the Poetry globally in your machine
```bash
curl -sSL https://install.python-poetry.org | python3 -
```
The installation script will print the installation bin folder `POETRY_BIN` which you need in the next steps.
2. Make sure Poetry is in your `$PATH`
- for `zsh`
```sh
echo 'export PATH="POETRY_BIN:$PATH"' >> ~/.zshrc
```
- for `bash`
```sh
echo 'export PATH="POETRY_BIN:$PATH"' >> ~/.bashrc
```
3. The official guidelines linked above include useful details on the configuration of autocomplete for most shell environments, e.g. Bash and Zsh.
#### Create a Virtual Environment and Install Dependencies
To activate the Virtual Environment, run:
```bash
poetry shell
```
To spawn a shell with the Virtual Environment activated. If the Virtual Environment doesn't exist, Poetry will create one for you. Then, to install dependencies, run:
```bash
poetry install
```
**(Advanced) Use a Specific Python Version**
If for whatever reason you need to work in a specific (older) version of Python, run:
```bash
poetry env use $(which python3.11)
```
This creates a Virtual Environment with Python 3.11. For other versions, replace `$(which python3.11)` by the path to the interpreter (e.g., `/usr/bin/python3.11`) or use `$(which pythonX.Y)`.
#### Add a new dependency
```bash
poetry add NAME
```
## Coding style guidelines
We use the following tools to enforce code style:
- iSort, to sort imports
- Black, to format code
We run a series of checks on the code base on every commit, using `pre-commit`. To install the hooks, run:
```bash
pre-commit install
```
To run the checks on-demand, run:
```
pre-commit run --all-files
```
Note: Checks like `Black` and `isort` will "fail" if they modify files. This is because `pre-commit` doesn't like to see files modified by their Hooks. In these cases, `git add` the modified files and `git commit` again.
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) [year] [fullname]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+10
View File
@@ -0,0 +1,10 @@
# MAINTAINERS
- Christoph Auer - [@cau-git](https://github.com/cau-git)
- Michele Dolfi - [@dolfim-ibm](https://github.com/dolfim-ibm)
- Maxim Lysak - [@maxmnemonic](https://github.com/maxmnemonic)
- Nikos Livathinos - [@nikos-livathinos](https://github.com/nikos-livathinos)
- Ahmed Nassar - [@nassarofficial](https://github.com/nassarofficial)
- Peter Staar - [@PeterStaar-IBM](https://github.com/PeterStaar-IBM)
Maintainers can be contacted at [deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com).
+127
View File
@@ -0,0 +1,127 @@
# Docling-models
AI modules to support the Dockling PDF document conversion project.
- TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
- Layout model is an AI model that provides among other things ability to detect tables on the page. This package contains inference code for Layout model.
## Installation Instructions
### MacOS / Linux
To install `poetry` locally, use either `pip` or `homebrew`.
To install `poetry` on a docker container, do the following:
```
ENV POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=false
# Install poetry
RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
&& python install-poetry.py \
&& poetry --version \
&& rm install-poetry.py
```
To install and run the package, simply set up a poetry environment
```
poetry env use $(which python3.11)
poetry shell
```
and install all the dependencies,
```
poetry install # this will only install the deps from the poetry.lock
poetry install --no-dev # this will skip installing dev dependencies
```
To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`
```
poetry update
```
## Pipeline Overview
![Architecture](docs/tablemodel_overview_color.png)
## Datasets
Below we list datasets used with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the version of the original format to work with the dataloader out of the box, and to augment the dataset when necassary to add missing groundtruth (bounding boxes for empty cells).
| Name | Description | URL |
| ------------- |:-------------:|----|
| PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables in the PubMed Central Open Access Subset | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
| FinTabNet| A dataset for Financial Report Tables with corresponding ground truth location and structure. 112k+ tables included.| [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
| TableBank| TableBank is a new image-based table detection and recognition dataset built with novel weak supervision from Word and Latex documents on the internet, contains 417K high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
## Models
### TableModel04:
![TableModel04](docs/tbm04.png)
**TableModel04rs (OTSL)** is our SOTA method that using transformers in order to predict table structure and bounding box.
## Configuration file
Example configuration can be seen inside test `tests/test_tf_predictor.py`
These are the main sections of the configuration file:
- `dataset`: The directory for prepared data and the parameters used during the data loading.
- `model`: The type, name and hyperparameters of the model. Also the directory to save/load the
trained checkpoint files.
- `train`: Parameters for the training of the model.
- `predict`: Parameters for the evaluation of the model.
- `dataset_wordmap`: Very important part that contains token maps.
## Model weights
You can download the model weights and config files from the links:
- [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
- [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
Place the downloaded files into `tests/test_data/model_artifacts/` directory.
## Inference Tests
This contains unit tests for Docling models.
First download the model weights (see above), then run:
```
./devtools/check_code.sh
```
This will also generate prediction and matching visualizations that can be found here:
`tests\test_data\viz\`
Visualization outlines:
- `Light Pink`: border of recognized table
- `Grey`: OCR cells
- `Green`: prediction bboxes
- `Red`: OCR cells matched with prediction
- `Blue`: Post processed, match
- `Bold Blue`: column header
- `Bold Magenta`: row header
- `Bold Brown`: section row (if table have one)
## Demo
A demo application allows to apply the `LayoutPredictor` on a directory `<input_dir>` that contains
`png` images and visualize the predictions inside another directory `<viz_dir>`.
First download the model weights (see above), then run:
```
python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
```
e.g.
```
python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
```
+126
View File
@@ -0,0 +1,126 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import argparse
import logging
import os
import sys
import time
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
ARTIFACT_PATH = "tests/test_data/model_artifacts"
def demo(
logger: logging.Logger,
artifact_path: str,
num_threads: int,
img_dir: str,
viz_dir: str,
):
r"""
Apply LayoutPredictor on the input image directory
If you want to load from PDF:
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
"""
# Create the layout predictor
lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
# Predict all test png images
for img_fn in Path(img_dir).rglob("*.png"):
logger.info("Predicting '%s'...", img_fn)
start_t = time.time()
with Image.open(img_fn) as image:
# Predict layout
preds = list(lpredictor.predict(image))
dt_ms = 1000 * (time.time() - start_t)
logger.debug("Time elapsed for prediction(ms): %s", dt_ms)
# Draw predictions
out_img = image.copy()
draw = ImageDraw.Draw(out_img)
for i, pred in enumerate(preds):
scr = pred["confidence"]
lab = pred["label"]
box = [
round(pred["l"]),
round(pred["t"]),
round(pred["r"]),
round(pred["b"]),
]
if lab == "Table":
draw.rectangle(
box,
outline="red",
)
draw.text(
(box[0], box[1]),
text=str(lab),
fill="blue",
)
logger.info("Table %s: bbox=%s", i, box)
save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
out_img.save(save_fn)
logger.info("Saving prediction visualization in: '%s'", save_fn)
def main(args):
r""" """
num_threads = int(args.num_threads) if args.num_threads is not None else None
img_dir = args.img_dir
viz_dir = args.viz_dir
# Initialize logger
logger = logging.getLogger("LayoutPredictor")
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
# Ensure the viz dir
Path(viz_dir).mkdir(parents=True, exist_ok=True)
# Test the LayoutPredictor
demo(logger, ARTIFACT_PATH, num_threads, img_dir, viz_dir)
if __name__ == "__main__":
r"""
python -m demo.demo_layout_predictor -i <images_dir>
"""
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
parser.add_argument(
"-n", "--num_threads", required=False, default=None, help="Number of threads"
)
parser.add_argument(
"-i",
"--img_dir",
required=True,
help="PNG images input directory",
)
parser.add_argument(
"-v",
"--viz_dir",
required=False,
default="viz/",
help="Directory to save prediction visualizations",
)
args = parser.parse_args()
main(args)
+96
View File
@@ -0,0 +1,96 @@
#!/bin/bash
set -e
# Disabled pylint messages
# C0114, C0116: Missing module docstring (missing-module-docstring)
# C0209: Formatting a regular string which could be a f-string (consider-using-f-string)
# C0103: Variable name doesn't conform to snake_case naming style (invalid-name)
# R0801: Similar lines in %s files %s
# W0621: Redefining name from outer scope
# W1514: Using open without explicitly specifying an encoding (unspecified-encoding)
# R0912: Too many branches (too-many-branches)
# R0913: Too many arguments
# R0914: Too many local variables (too-many-locals)
# R0915: Too many statements (too-many-statements)
# R1702: Too many nested blocks (too-many-nested-blocks)
# R0902: Too many instance attributes
# R0903: Too few public methods
# W0221: Arguments differ
# C0415: Import outside toplevel
# C0302: Too many lines in module
# W0718: Catching too general exception Exception
# R0902: Too many instance attributes
# R1702: Too many nested blocks
PYLINT_DISABLED="C0114,C0116,C0209,C0103,R0801,W0621,W1514,R0912,R0913,R0914,R0915,R1702"
PYLINT_DISABLED+=",R0902,R0903,W0221,C0415,C0302,R0401,W0718,R0902,R1702"
readonly MAX_LINE_LENGTH=100
readonly INDENT_SPACES=4
##########################################################################################
# Functions
#
Usage() {
echo "Check codebase with "
echo "Usage:"
echo "$0 [-c] [-h]"
echo
echo "-c: Clear cache before invoking PyTest"
echo "-h: Print this help message"
echo
echo "$0"
}
##########################################################################################
# Main
#
clear_cache=0
while getopts ":hc" option; do
case "${option}" in
c ) clear_cache=1;;
h ) Usage; exit;;
\? ) Usage; exit;;
: ) # Missing required argument
Usage; exit;;
esac
done
# PEP8
echo "Flake8 check:"
flake8 \
--max-line-length=${MAX_LINE_LENGTH} \
--indent-size=${INDENT_SPACES} \
--ignore=E121,E123,E126,E226,E24,E704,W503,W504,W605,E203 \
--extend-exclude '_*' \
docling_ibm_models/ tests/
echo "Flake8 - OK"
echo
# # Pylint
# echo "Pylint check:"
# indent_string=$(printf '%*s' ${INDENT_SPACES} "" | tr ' ' 'n' | tr 'n' ' ')
# # echo "indent_string: '${indent_string}'"
# pylint \
# --max-line-length ${MAX_LINE_LENGTH} \
# --indent-string "${indent_string}" \
# --disable ${PYLINT_DISABLED} \
# --extension-pkg-whitelist='pydantic' \
# --ignore-patterns '[!_]' \
# docling_ibm_models/ tests/
#
# echo "Pylint check - OK"
# echo
# Unit tests with PyTest
echo "PyTest:"
if [ ${clear_cache} -eq 1 ]; then
echo "Clear pytest cache first"
echo
python -m pytest -n auto --cache-clear --ignore=docling_ibm_models/ tests/
else
python -m pytest -n auto --ignore=docling_ibm_models/ tests/
fi
echo "PyTest check - OK"
@@ -0,0 +1,171 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import os
from collections.abc import Iterable
from typing import Union
import numpy as np
import onnxruntime as ort
from PIL import Image
MODEL_CHECKPOINT_FN = "model.pt"
DEFAULT_NUM_THREADS = 4
# Classes:
CLASSES_MAP = {
0: "background",
1: "Caption",
2: "Footnote",
3: "Formula",
4: "List-item",
5: "Page-footer",
6: "Page-header",
7: "Picture",
8: "Section-header",
9: "Table",
10: "Text",
11: "Title",
12: "Document Index",
13: "Code",
14: "Checkbox-Selected",
15: "Checkbox-Unselected",
16: "Form",
17: "Key-Value Region",
}
class LayoutPredictor:
r"""
Document layout prediction using ONNX
"""
def __init__(
self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
):
r"""
Provide the artifact path that contains the LayoutModel ONNX file
The number of threads is decided, in the following order, by:
1. The init method parameter `num_threads`, if it is set.
2. The envvar "OMP_NUM_THREADS", if it is set.
3. The default value DEFAULT_NUM_THREADS.
The execution provided is decided, in the following order:
1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
it uses the "CPUExecutionProvider".
3. Otherwise if the "CUDAExecutionProvider" is present, use:
["CUDAExecutionProvider", "CPUExecutionProvider"]:
Parameters
----------
artifact_path: Path for the model ONNX file.
num_threads: (Optional) Number of threads to run the inference.
use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
Raises
------
FileNotFoundError when the model's ONNX file is missing
"""
# Set basic params
self._threshold = 0.6 # Score threshold
self._image_size = 640
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
# Get env vars
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
if num_threads is None:
num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
self._num_threads = num_threads
# Decide the execution providers
if (
not self._use_cpu_only
and "CUDAExecutionProvider" in ort.get_available_providers()
):
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self._providers = providers
# Model ONNX file
self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
if not os.path.isfile(self._onnx_fn):
raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))
# ONNX options
self._options = ort.SessionOptions()
self._options.intra_op_num_threads = self._num_threads
self.sess = ort.InferenceSession(
self._onnx_fn,
sess_options=self._options,
providers=self._providers,
)
def info(self) -> dict:
r"""
Get information about the configuration of LayoutPredictor
"""
info = {
"onnx_file": self._onnx_fn,
"intra_op_num_threads": self._num_threads,
"use_cpu_only": self._use_cpu_only,
"providers": self._providers,
"image_size": self._image_size,
"threshold": self._threshold,
}
return info
def predict(self, orig_img: Union[Image, np.array]) -> Iterable[dict]:
r"""
Predict bounding boxes for a given image.
The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
[left, top, right, bottom]
Parameter
---------
origin_img: Image to be predicted as a PIL Image object or numpy array.
Yield
-----
Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
Raises
------
TypeError when the input image is not supported
"""
# Convert image format
if isinstance(orig_img, Image.Image):
page_img = orig_img.convert("RGB")
elif isinstance(orig_img, np.ndarray):
page_img = Image.fromarray(orig_img).convert("RGB")
else:
raise TypeError("Not supported input image format")
w, h = page_img.size
page_img = page_img.resize((self._image_size, self._image_size))
page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
# Predict
labels, boxes, scores = self.sess.run(
output_names=None,
input_feed={
"images": page_data,
"orig_target_sizes": self._size,
},
)
# Yield output
for label, box, score in zip(labels[0], boxes[0], scores[0]):
if score > self._threshold:
yield {
"l": box[0] / self._image_size * w,
"t": box[1] / self._image_size * h,
"r": box[2] / self._image_size * w,
"b": box[3] / self._image_size * h,
"label": CLASSES_MAP[label],
"confidence": score,
}
+200
View File
@@ -0,0 +1,200 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import argparse
import json
import logging
import os
import torch
import docling_ibm_models.tableformer.settings as s
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
LOG_LEVEL = logging.DEBUG
logger = s.get_custom_logger("common", LOG_LEVEL)
def validate_config(config):
r"""
Validate the provided configuration file.
A ValueError exception will be thrown in case the config file is invalid
Parameters
----------
config : dictionary
Configuration for the tablemodel
Returns
-------
bool : True on success
"""
if "model" not in config:
return True
if "preparation" not in config:
return True
assert (
"max_tag_len" in config["preparation"]
), "Config error: 'preparation.max_tag_len' parameter is missing"
if "seq_len" in config["model"]:
assert (
config["model"]["seq_len"] > 0
), "Config error: 'model.seq_len' should be positive"
assert config["model"]["seq_len"] <= (
config["preparation"]["max_tag_len"] + 2
), "Config error: 'model.seq_len' should be up to 'preparation.max_tag_len' + 2"
return True
def parse_arguments():
r"""
Parse the input arguments
A ValueError exception will be thrown in case the config file is invalid
"""
parser = argparse.ArgumentParser(description="Train the TableModel")
parser.add_argument(
"-c", "--config", required=True, default=None, help="configuration file (JSON)"
)
args = parser.parse_args()
config_filename = args.config
assert os.path.isfile(config_filename), "FAILURE: Config file not found."
return read_config(config_filename)
def read_config(config_filename):
with open(config_filename, "r") as fd:
config = json.load(fd)
# Validate the config file
validate_config(config)
return config
def safe_get_parameter(input_dict, index_path, default=None, required=False):
r"""
Safe get parameter from a nested dictionary.
Provide a nested dictionary (dictionary of dictionaries) and a list of indices:
- If the whole index path exists the value pointed by it is returned
- Otherwise the default value is returned.
Input:
input_dict: Data structure of nested dictionaries.
index_path: List with the indices path to follow inside the input_dict.
default: Default value to return if the indices path is broken.
required: If true a ValueError exception will be raised in case the parameter does not exist
Output:
The value pointed by the index path or "default".
"""
if input_dict is None or index_path is None:
return default
d = input_dict
for i in index_path[:-1]:
if i not in d:
if required:
raise ValueError("Missing parameter: {}".format(i))
return default
d = d[i]
last_index = index_path[-1]
if last_index not in d:
if required:
raise ValueError("Missing parameter: {}".format(last_index))
return default
return d[last_index]
def get_prepared_data_filename(prepared_data_part, dataset_name):
r"""
Build the full filename of the prepared data part
Parameters
----------
prepared_data_part : string
Part of the prepared data
dataset_name : string
Name of the dataset
Returns
-------
string
The full filename for the prepared file
"""
template = s.PREPARED_DATA_PARTS[prepared_data_part]
if "<POSTFIX>" in template:
template = template.replace("<POSTFIX>", dataset_name)
return template
def create_dataset_and_model(config, purpose, fixed_padding=False):
r"""
Gets a model from configuration
Parameters
---------
config : Dictionary
The configuration of the model
purpose : string
One of "train", "eval", "predict"
fixed_padding : bool
Parameter passed to the constructor of the DataLoader
Returns
-------
In case a Model cannot be initialized return None, None, None. Otherwise:
device : selected device
dataset : Instance of the DataLoader
model : Instance of the model
"""
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
model_type = config["model"]["type"]
model = None
# Get env vars:
use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
use_cuda_only = not use_cpu_only
# Use the cpu for the evaluation
device = "cpu" # Default, run on CPU
num_gpus = torch.cuda.device_count() # Check if GPU is available
if use_cuda_only:
device = "cuda:0" if num_gpus > 0 else "cpu" # Run on first available GPU
else:
device = "cpu"
# Create the DataLoader
# loader = DataLoader(config, purpose, fixed_padding=fixed_padding)
dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
dataset.set_device(device)
dataset_val = None
if config["train"]["validation"] and purpose == "train":
dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
dataset_val.set_device(device)
if model_type == "TableModel04_rs":
from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import ( # noqa: F401
TableModel04_rs,
)
# Find the model class and create an instance of it
for candidate in BaseModel.__subclasses__():
if candidate.__name__ == model_type:
init_data = dataset.get_init_data()
model = candidate(config, init_data, purpose, device)
if model is None:
logger.warn("Not found model: " + str(model_type))
return None, None, None
logger.info("Found model: " + str(model_type))
if purpose == s.PREDICT_PURPOSE:
return device, dataset, model
else:
return device, dataset, dataset_val, model
@@ -0,0 +1,504 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import copy
import logging
import os
import random
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import docling_ibm_models.tableformer.data_management.transforms as T
import docling_ibm_models.tableformer.settings as s
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
class DataTransformer:
r"""
Data transformations for the images and bboxes
Check the "help" fields inside the config file for an explanation of each parameter
"""
def __init__(self, config):
self._config = config
print("DataTransformer Init!")
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def append_id(self, filename):
name, ext = os.path.splitext(filename)
return "{name}_{uid}{ext}".format(name=name, uid="resized", ext=ext)
def load_image(self, img_fn):
r"""
Load an image from the disk
Parameters
----------
img_fn: The filename of the image
Returns
-------
PIL image object
"""
img = Image.open(img_fn)
return img
def load_image_cv2(self, img_fn):
r"""Load an image from the disk
Parameters
----------
img_fn: The filename of the image
Returns
-------
CV2 image object
"""
img = cv2.imread(img_fn)
return img
def save_image(self, img, img_fn):
img.save(self.append_id(img_fn))
def renderbboxes(self, img, bboxes):
draw_img = ImageDraw.Draw(img)
for i in range(len(bboxes)):
draw_img.rectangle(bboxes[i], fill=None, outline=(255, 0, 0))
return img
def get_dataset_settings(self):
dataset = {}
debug = {"save_debug_images": False}
if "dataset" in self._config:
dataset = self._config["dataset"]
if "debug" in self._config:
debug = self._config["debug"]
return dataset, debug
def _prepare_image_from_file(self, image_fn, bboxes, convert_box=True):
r"""
Load the image from file and prepare it
Parameters
----------
image_fn : string
Filename to load the image
bboxes : dict
Bounding boxes of the image
convert_box : bool
If true the bboxes are converted to xcycwh format
Returns
-------
PIL image
A PIL image object with the image prepared according to the settings in the config file
bboxes : dict
Converted bboxes of the image
"""
im = self.load_image(image_fn)
return self._prepare_image(im, bboxes, convert_box, image_fn)
def _prepare_image(self, im, bboxes, convert_box=True, image_fn=None):
r"""
Parameters
----------
im : PIL image object
bboxes : dict
Bounding boxes of the image
convert_box : bool
If true the bboxes are converted to xcycwh format
image_fn : string
Filename of the original image or None. It is used to save augmented image for debugging
Returns
-------
PIL image
A PIL image object with the image prepared according to the settings in the config file
bboxes : dict
Converted bboxes of the image
"""
debug_settings = False
settings, debug_settings = self.get_dataset_settings()
desired_size = settings["resized_image"]
old_size = im.size
# Calculate Aspect Ratio if needed
if settings["keep_AR"]:
ratio = float(desired_size) / max(old_size)
else:
ratio = 1 # Square image
new_size = old_size
if max(old_size) < desired_size:
# Image is smaller than desired
# Upscale?
if settings["up_scaling_enabled"]:
# Calculate new image size, taking into account aspect ratio
new_size = tuple([int(x * ratio) for x in old_size])
else:
new_size = old_size
else:
if settings["keep_AR"]:
new_size = tuple([int(x * ratio) for x in old_size])
else:
new_size = [desired_size, desired_size]
if not settings["keep_AR"]:
if settings["up_scaling_enabled"]:
new_size = [desired_size, desired_size]
######################################################################################
# Use OpenCV to resize the image
#
# im = im.resize(new_size, Image.ANTIALIAS)
import cv2
np_im = np.array(im)
np_resized = cv2.resize(np_im, new_size, interpolation=cv2.INTER_LANCZOS4)
im = Image.fromarray(np_resized)
######################################################################################
new_bboxes = copy.deepcopy(bboxes)
# Resize bboxes (in pixels)
x_scale = new_size[0] / old_size[0]
y_scale = new_size[1] / old_size[1]
# loop over bboxes
for i in range(len(new_bboxes)):
new_bboxes[i][0] = x_scale * bboxes[i][0]
new_bboxes[i][1] = y_scale * bboxes[i][1]
new_bboxes[i][2] = x_scale * bboxes[i][2]
new_bboxes[i][3] = y_scale * bboxes[i][3]
# Set background color for padding
br = settings["padding_color"][0]
bg = settings["padding_color"][1]
bb = settings["padding_color"][2]
bcolor = (br, bg, bb)
# Create empty canvas of background color and desired size
new_im = Image.new(mode="RGB", size=(desired_size, desired_size), color=bcolor)
if "grayscale" in settings:
if settings["grayscale"]:
im = im.convert("LA")
if settings["padding_mode"] == "frame":
# If paddinds are around image, paste resized image in the center
x_pad = (desired_size - new_size[0]) // 2
y_pad = (desired_size - new_size[1]) // 2
# Paste rescaled image
new_im.paste(im, (x_pad, y_pad))
# Offset (pad) bboxes
# loop over bboxes
for i in range(len(new_bboxes)):
new_bboxes[i][0] += x_pad
new_bboxes[i][1] += y_pad
new_bboxes[i][2] += x_pad
new_bboxes[i][3] += y_pad
else:
# Otherwise paste in the 0,0 coordinates
new_im.paste(im, (0, 0))
if debug_settings:
if debug_settings["save_debug_images"]:
aug_im = self.renderbboxes(new_im, bboxes)
if "grayscale" in settings:
if settings["grayscale"]:
aug_im = aug_im.convert("LA")
self.save_image(aug_im, image_fn)
if convert_box:
bboxes = self.xyxy_to_xcycwh(new_bboxes, desired_size)
return new_im, bboxes
# convert bboxes from [x1, y1, x2, y2] format to [xc, yc, w, h] format
def xyxy_to_xcycwh(self, bboxes, size):
# Use the "dataset.bbox_format" parameter to decide which bbox format to use
bbox_format = self._config["dataset"].get("bbox_format", "4plet")
conv_bboxes = []
for i in range(len(bboxes)):
x1 = bboxes[i][0] / size # X1
y1 = bboxes[i][1] / size # Y1
x2 = bboxes[i][2] / size # X2
y2 = bboxes[i][3] / size # Y2
xc = (x1 + x2) / 2
yc = (y1 + y2) / 2
bw = abs(x2 - x1)
bh = abs(y2 - y1)
if bbox_format == "5plet":
cls = bboxes[i][4]
conv_bboxes.append([xc, yc, bw, bh, cls])
else:
conv_bboxes.append([xc, yc, bw, bh])
# conv_bboxes = bboxes
return conv_bboxes
def rescale_in_memory(self, image, normalization):
r"""
Receive image and escale it in memory
Parameters
----------
image : PIL image
The image data to rescale
normalization : dictionary
The normalization information with the format:
"state": "true or false if image normalization is to be enabled",
"mean": "mean values to use if state is true",
"std": "std values to use if state is true"
Returns
-------
npimgc : FloatTensor
The loaded and properly transformed image data
"""
settings, debug_settings = self.get_dataset_settings()
new_image, _ = self._prepare_image(image, {}, convert_box=False, image_fn=None)
# Convert to nparray
npimg = np.asarray(new_image) # (width, height, channels)
# Convert to float?
npimgc = npimg.copy()
# Transpose numpy array (image)
npimgc = npimgc.transpose(2, 0, 1) # (channels, width, height)
npimgc = torch.FloatTensor(npimgc / 255.0)
if normalization:
transform = transforms.Compose(
[
transforms.Normalize(
mean=self._config["dataset"]["image_normalization"]["mean"],
std=self._config["dataset"]["image_normalization"]["std"],
)
]
)
npimgc = transform(npimgc)
return npimgc
def _rescale(self, image_fn, bboxes, normalization):
r"""
Rescale, resize, pad the given image and its associated bboxes according to the settings
from the config
Parameters
----------
image_fn: full image file name
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
bottom-left points
statistics: Dictionary with statistics over the whole image dataset.
The keys are: "mean", "variance", "std" and each value is a list with the
coresponding statistical value for each channel. Normally there are 3
channels.
Returns
-------
npimgc : FloatTensor
The loaded and properly transformed image data
bboxes: List with bboxes in the format [xc, yc, w, h] where xc, yc are the coords of the
center, w, h the width and height of the bbox and all are normalized to the
scaled size of the image
Raises
------
ValueError
In case the configuration and the image dimensions make it impossible to rescale the
image throw a ValueError exception
"""
settings, debug_settings = self.get_dataset_settings()
# new_image is a PIL object
new_image, new_bboxes = self._prepare_image_from_file(image_fn, bboxes)
# Convert to nparray
npimg = np.asarray(new_image) # (width, height, channels)
# Convert to float?
npimgc = npimg.copy()
# Transpose numpy array (image)
npimgc = npimgc.transpose(2, 0, 1) # (channels, width, height)
npimgc = torch.FloatTensor(npimgc / 255.0)
if normalization:
transform = transforms.Compose(
[
transforms.Normalize(
mean=self._config["dataset"]["image_normalization"]["mean"],
std=self._config["dataset"]["image_normalization"]["std"],
)
]
)
npimgc = transform(npimgc)
return npimgc, new_bboxes
def rescale_old(self, image, bboxes, statistics=None):
r"""
Rescale, resize, pad the given image and its associated bboxes according to the settings
from the config
Input:
image: np array (channels, width, height)
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
bottom-left points
statistics: Dictionary with statistics over the whole image dataset.
The keys are: "mean", "variance", "std" and each value is a list with the
coresponding statistical value for each channel. Normally there are 3
channels.
Output:
image: np array (channels, resized_image, resized_image)
bboxes: List with bboxes in the format (xc, yc, w, h) where xc, yc are the coords of the
center, w, h the width and height of the bbox and all are normalized to the
scaled size of the image
Exceptions:
In case the configuration and the image dimensions make it impossible to rescale the
image throw a ValueError exception
"""
image_size = 448
# Convert the image to (width, height, channels)
image = image.transpose(1, 2, 0)
# Convert to PIL format and resize
image = Image.fromarray(image)
image = image.resize((image_size, image_size), Image.ANTIALIAS)
return image, bboxes
def rescale_batch(self, images, bboxes, statistics=None):
r"""
Rescale, resize, pad the given batch of images and its associated bboxes according to the
settings from the config.
Input:
images: np array (batch_size, channels, width, height)
bboxes:
statistics: Dictionary with statistics over the whole image dataset.
The keys are: "mean", "variance", "std" and each value is a list with the
coresponding statistical value for each channel. Normally there are 3
channels.
Output:
image batch: np array (batch_size, channels, resized_image, resized_image)
bboxes:
Exceptions:
In case the configuration and the image dimensions make it impossible to rescale the
image throw a ValueError exception
"""
pass
def sample_preprocessor(self, image_fn, bboxes, purpose, table_bboxes=None):
r"""
Rescale, resize, pad the given image and its associated bboxes according to the settings
from the config
Parameters
----------
image_fn: full image file name
bboxes: List with bboxes in the format [x1, y1, x2, y2] with box's top-right,
bottom-left points
statistics: Dictionary with statistics over the whole image dataset.
The keys are: "mean", "variance", "std" and each value is a list with the
coresponding statistical value for each channel. Normally there are 3
channels.
Returns
-------
npimgc : FloatTensor
The loaded and properly transformed image data
bboxes: List with bboxes in the format [xc, yc, w, h] where xc, yc are the coords of the
center, w, h the width and height of the bbox and all are normalized to the
scaled size of the image
Raises
------
ValueError
In case the configuration and the image dimensions make it impossible to rescale the
image throw a ValueError exception
"""
settings, debug_settings = self.get_dataset_settings()
img = self.load_image_cv2(image_fn)
img = np.ascontiguousarray(img)
target = {
"size": [img.shape[1], img.shape[2]],
"boxes": (
torch.from_numpy(np.array(bboxes)[:, :4])
if purpose != s.PREDICT_PURPOSE
else None
),
"classes": (
np.array(bboxes)[:, -1] if purpose != s.PREDICT_PURPOSE else None
),
"area": img.shape[1] * img.shape[2],
}
optional_transforms = [T.NoTransformation()]
# Necessary preprocessing ends here, experimental options begin below.
# DETR format, might be necessary to keep this structure to share other functions used by
# the community
if purpose == s.TRAIN_PURPOSE:
if self._config["dataset"]["color_jitter"]:
jitter = T.ColorJitter(
brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
)
optional_transforms.append(jitter)
if self._config["dataset"]["rand_pad"]:
pad_val = random.randint(0, 50)
rand_pad = T.RandomPad(pad_val)
optional_transforms.append(rand_pad)
if table_bboxes is not None:
if self._config["dataset"]["rand_crop"]:
w_, h_, _ = img.shape[0], img.shape[1], img.shape[2]
w_c, h_c = table_bboxes[0], table_bboxes[1]
f_w, f_h = random.randint(0, w_c), random.randint(0, h_c)
rand_crop = T.RandomCrop((w_, h_), (f_w, f_h))
optional_transforms.append(rand_crop)
# transform_opt = random.choice(optional_transforms)
normalize = T.Normalize(
mean=self._config["dataset"]["image_normalization"]["mean"],
std=self._config["dataset"]["image_normalization"]["std"],
)
resized_size = self._config["dataset"]["resized_image"]
resize = T.Resize([resized_size, resized_size])
transformations = T.RandomChoice(optional_transforms)
img, target = transformations(img, target)
img, target = normalize(img, target)
img, target = resize(img, target)
img = img.transpose(2, 1, 0) # (channels, width, height)
img = torch.FloatTensor(img / 255.0)
bboxes_ = target["boxes"]
classes_ = target["classes"]
desired_size = img.shape[1]
if purpose != s.PREDICT_PURPOSE:
bboxes_ = np.concatenate(
(bboxes_, np.expand_dims(classes_, axis=1)), axis=1
)
bboxes_ = self.xyxy_to_xcycwh(bboxes_, desired_size)
return img, bboxes_
@@ -0,0 +1,574 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import numbers
from collections.abc import Iterable, Sequence
import cv2
import numpy as np
import torch
from torchvision.transforms import functional
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
INTER_MODE = {
"NEAREST": cv2.INTER_NEAREST,
"BILINEAR": cv2.INTER_LINEAR,
"BICUBIC": cv2.INTER_CUBIC,
}
PAD_MOD = {
"constant": cv2.BORDER_CONSTANT,
"edge": cv2.BORDER_REPLICATE,
"reflect": cv2.BORDER_DEFAULT,
"symmetric": cv2.BORDER_REFLECT,
}
def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
def to_tensor(pic):
"""Converts a numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
Args:
pic (np.ndarray, torch.Tensor): Image to be converted to tensor, (H x W x C[RGB]).
Returns:
Tensor: Converted image.
"""
if _is_numpy_image(pic):
if len(pic.shape) == 2:
pic = cv2.cvtColor(pic, cv2.COLOR_GRAY2RGB)
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
if isinstance(img, torch.ByteTensor) or img.max() > 1:
return img.float().div(255)
else:
return img
elif _is_tensor_image(pic):
return pic
else:
try:
return to_tensor(np.array(pic))
except Exception:
raise TypeError("pic should be ndarray. Got {}".format(type(pic)))
def to_cv_image(pic, mode=None):
"""Convert a tensor to an ndarray.
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
mode (str): color space and pixel depth of input data (optional)
for example: cv2.COLOR_RGB2BGR.
Returns:
np.array: Image converted to PIL Image.
"""
if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
npimg = pic
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.squeeze(np.transpose(pic.numpy(), (1, 2, 0)))
if not isinstance(npimg, np.ndarray):
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray")
if mode is None:
return npimg
else:
return cv2.cvtColor(npimg, mode)
def normalize(tensor, mean, std):
"""Normalize a tensor image with mean and standard deviation.
See ``Normalize`` for more details.
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channely.
Returns:
Tensor: Normalized Tensor image.
"""
if _is_tensor_image(tensor):
for t, m, s in zip(tensor, mean, std, strict=False):
t.sub_(m).div_(s)
return tensor
elif _is_numpy_image(tensor):
return (tensor.astype(np.float32) - 255.0 * np.array(mean)) / np.array(std)
else:
raise RuntimeError("Undefined type")
def resize(img, size, interpolation="BILINEAR"):
"""Resize the input CV Image to the given size.
Args:
img (np.ndarray): Image to be resized.
size (tuple or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing
the aspect ratio. i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (str, optional): Desired interpolation. Default is ``BILINEAR``
Returns:
cv Image: Resized image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
raise TypeError("Got inappropriate size arg: {}".format(size))
if isinstance(size, int):
h, w, c = img.shape
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return cv2.resize(
img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation]
)
else:
oh = size
ow = int(size * w / h)
return cv2.resize(
img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation]
)
else:
oh, ow = size
return cv2.resize(
img, dsize=(int(ow), int(oh)), interpolation=INTER_MODE[interpolation]
)
def to_rgb_bgr(pic):
"""Converts a color image stored in BGR sequence to RGB (BGR to RGB)
or stored in RGB sequence to BGR (RGB to BGR).
Args:
pic (np.ndarray, torch.Tensor): Image to be converted, (H x W x 3).
Returns:
Tensor: Converted image.
"""
if _is_numpy_image(pic) or _is_tensor_image(pic):
img = pic[:, :, [2, 1, 0]]
return img
else:
try:
return to_rgb_bgr(np.array(pic))
except Exception:
raise TypeError("pic should be numpy.ndarray or torch.Tensor.")
def pad(img, padding, fill=(0, 0, 0), padding_mode="constant"):
"""Pad the given CV Image on all sides with speficified padding mode and fill value.
Args:
img (np.ndarray): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill (int, tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
constant: pads with a constant value, this value is specified with fill
edge: pads with the last value on the edge of the image
reflect: pads with reflection of image (without repeating the last value on the edge)
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
symmetric: pads with reflection of image (repeating the last value on the edge)
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
Returns:
CV Image: Padded image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
if not isinstance(padding, (numbers.Number, tuple)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple")
assert padding_mode in [
"constant",
"edge",
"reflect",
"symmetric",
], "Padding mode should be either constant, edge, reflect or symmetric"
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, Sequence) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, Sequence) and len(padding) == 4:
pad_left, pad_top, pad_right, pad_bottom = padding
if isinstance(fill, numbers.Number):
fill = (fill,) * (2 * len(img.shape) - 3)
if padding_mode == "constant":
assert (len(fill) == 3 and len(img.shape) == 3) or (
len(fill) == 1 and len(img.shape) == 2
), "channel of image is {} but length of fill is {}".format(
img.shape[-1], len(fill)
)
img = cv2.copyMakeBorder(
src=img,
top=pad_top,
bottom=pad_bottom,
left=pad_left,
right=pad_right,
borderType=PAD_MOD[padding_mode],
value=fill,
)
return img
def crop(img, x, y, h, w):
"""Crop the given CV Image.
Args:
img (np.ndarray): Image to be cropped.
x: Upper pixel coordinate.
y: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
Returns:
CV Image: Cropped image.
"""
assert _is_numpy_image(img), "img should be CV Image. Got {}".format(type(img))
assert h > 0 and w > 0, "h={} and w={} should greater than 0".format(h, w)
x1, y1, x2, y2 = round(x), round(y), round(x + h), round(y + w)
# try:
# check_point1 = img[x1, y1, ...]
# check_point2 = img[x2-1, y2-1, ...]
# except IndexError:
# img = cv2.copyMakeBorder(img, - min(0, x1), max(x2 - img.shape[0], 0),
# -min(0, y1), max(y2 - img.shape[1], 0),
# cv2.BORDER_CONSTANT, value=[0, 0, 0])
# y2 += -min(0, y1)
# y1 += -min(0, y1)
# x2 += -min(0, x1)
# x1 += -min(0, x1)
#
# finally:
# return img[x1:x2, y1:y2, ...].copy()
return img[x1:x2, y1:y2, ...].copy()
def center_crop(img, output_size):
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
h, w, _ = img.shape
th, tw = output_size
i = int(round((h - th) * 0.5))
j = int(round((w - tw) * 0.5))
return crop(img, i, j, th, tw)
def resized_crop(img, i, j, h, w, size, interpolation="BILINEAR"):
"""Crop the given CV Image and resize it to desired size. Notably used in RandomResizedCrop.
Args:
img (np.ndarray): Image to be cropped.
i: Upper pixel coordinate.
j: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
size (sequence or int): Desired output size. Same semantics as ``scale``.
interpolation (str, optional): Desired interpolation. Default is
``BILINEAR``.
Returns:
np.ndarray: Cropped image.
"""
assert _is_numpy_image(img), "img should be CV Image"
img = crop(img, i, j, h, w)
img = resize(img, size, interpolation)
return img
def hflip(img):
"""Horizontally flip the given PIL Image.
Args:
img (np.ndarray): Image to be flipped.
Returns:
np.ndarray: Horizontall flipped image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
return cv2.flip(img, 1)
def vflip(img):
"""Vertically flip the given PIL Image.
Args:
img (CV Image): Image to be flipped.
Returns:
PIL Image: Vertically flipped image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return cv2.flip(img, 0)
def five_crop(img, size):
"""Crop the given CV Image into four corners and the central crop.
.. Note::
This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
Returns:
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
top right, bottom left, bottom right and center crop.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
h, w, _ = img.shape
crop_h, crop_w = size
if crop_w > w or crop_h > h:
raise ValueError(
"Requested crop size {} is bigger than input size {}".format(size, (h, w))
)
tl = crop(img, 0, 0, crop_h, crop_w)
tr = crop(img, 0, w - crop_w, crop_h, crop_w)
bl = crop(img, h - crop_h, 0, crop_h, crop_w)
br = crop(img, h - crop_h, w - crop_w, crop_h, crop_w)
center = center_crop(img, (crop_h, crop_w))
return (tl, tr, bl, br, center)
def ten_crop(img, size, vertical_flip=False):
"""Crop the given CV Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
.. Note::
This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
vertical_flip (bool): Use vertical flipping instead of horizontal
Returns:
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
br_flip, center_flip) corresponding top left, top right,
bottom left, bottom right and center crop and same for the
flipped image.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
if vertical_flip:
img = vflip(img)
else:
img = hflip(img)
second_five = five_crop(img, size)
return first_five + second_five
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image.
Args:
img (np.ndarray): CV Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
np.ndarray: Brightness adjusted image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
im = img.astype(np.float32) * brightness_factor
im = im.clip(min=0, max=255)
return im.astype(img.dtype)
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (np.ndarray): CV Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
np.ndarray: Contrast adjusted image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
im = img.astype(np.float32)
mean = round(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY).mean())
im = (1 - contrast_factor) * mean + contrast_factor * im
im = im.clip(min=0, max=255)
return im.astype(img.dtype)
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (np.ndarray): CV Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a gray image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
np.ndarray: Saturation adjusted image.
"""
if not _is_numpy_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
im = img.astype(np.float32)
degenerate = cv2.cvtColor(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
im = (1 - saturation_factor) * degenerate + saturation_factor * im
im = im.clip(min=0, max=255)
return im.astype(img.dtype)
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args:
img (np.ndarray): CV Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
np.ndarray: Hue adjusted image.
"""
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError("hue_factor is not in [-0.5, 0.5].")
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
im = img.astype(np.uint8)
hsv = cv2.cvtColor(im, cv2.COLOR_RGB2HSV_FULL)
hsv[..., 0] += np.uint8(hue_factor * 255)
im = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB_FULL)
return im.astype(img.dtype)
def adjust_gamma(img, gamma, gain=1):
"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
I_out = 255 * gain * ((I_in / 255) ** gamma)
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
Args:
img (np.ndarray): CV Image to be adjusted.
gamma (float): Non negative real number. gamma larger than 1 make the
shadows darker, while gamma smaller than 1 make dark regions
lighter.
gain (float): The constant multiplier.
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
im = img.astype(np.float32)
im = 255.0 * gain * np.power(im / 255.0, gamma)
im = im.clip(min=0.0, max=255.0)
return im.astype(img.dtype)
def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.
Args:
img (np.ndarray): Image to be converted to grayscale.
Returns:
CV Image: Grayscale version of the image.
if num_output_channels == 1 : returned image is single channel
if num_output_channels == 3 : returned image is 3 channel with r == g == b
"""
if not _is_numpy_image(img):
raise TypeError("img should be CV Image. Got {}".format(type(img)))
if num_output_channels == 1:
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
elif num_output_channels == 3:
img = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
else:
raise ValueError("num_output_channels should be either 1 or 3")
return img
def gaussian_noise(img: np.ndarray, mean, std):
imgtype = img.dtype
gauss = np.random.normal(mean, std, img.shape).astype(np.float32)
noisy = np.clip((1 + gauss) * img.astype(np.float32), 0, 255)
return noisy.astype(imgtype)
def poisson_noise(img):
imgtype = img.dtype
img = img.astype(np.float32) / 255.0
vals = len(np.unique(img))
vals = 2 ** np.ceil(np.log2(vals))
noisy = 255 * np.clip(
np.random.poisson(img.astype(np.float32) * vals) / float(vals), 0, 1
)
return noisy.astype(imgtype)
def salt_and_pepper(img, prob=0.01):
"""Adds "Salt & Pepper" noise to an image.
prob: probability (threshold) that controls level of noise
"""
imgtype = img.dtype
rnd = np.random.rand(img.shape[0], img.shape[1])
noisy = img.copy()
noisy[rnd < prob / 2] = 0.0
noisy[rnd > 1 - prob / 2] = 255.0
return noisy.astype(imgtype)
def cv_transform(img):
img = salt_and_pepper(img)
return to_tensor(img)
def pil_transform(img):
return functional.to_tensor(img)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,596 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import copy
import logging
import re
import numpy as np
import docling_ibm_models.tableformer.otsl as otsl
import docling_ibm_models.tableformer.settings as s
# LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
LOG_LEVEL = logging.WARN
# Cell labels
BODY = "body"
COL_HEADER = "col_header"
MULTI_COL_HEADER = "multi_col_header"
MULTI_ROW_HEADER = "multi_row_header"
MULTI_ROW = "multi_row"
MULTI_COL = "multi_col"
def validate_bboxes_page(bboxes):
r"""
Useful function for Debugging
Validate that the bboxes have a positive area in the page coordinate system
Parameters
----------
bboxes : list of 4
Each element of the list is expected to be a bbox in the page coordinates system
Returns
-------
int
The number of invalid bboxes.
"""
invalid_counter = 0
for i, bbox in enumerate(bboxes):
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if area < 0:
print("Wrong bbox: {} - {}".format(i, bbox))
invalid_counter += 1
if invalid_counter > 0:
print("Invalid bboxes in total: {}".format(invalid_counter))
return invalid_counter
def find_intersection(b1, b2):
r"""
Compute the intersection between 2 bboxes
Parameters
----------
b1 : list of 4
The page x1y1x2y2 coordinates of the bbox
b2 : list of 4
The page x1y1x2y2 coordinates of the bbox
Returns
-------
The bbox of the intersection or None if there is no intersection
"""
# Check when the bboxes do NOT intersect
if b1[2] < b2[0] or b2[2] < b1[0] or b1[1] > b2[3] or b2[1] > b2[3]:
return None
i_bbox = [
max(b1[0], b2[0]),
max(b1[1], b2[1]),
min(b1[2], b2[2]),
min(b1[3], b2[3]),
]
return i_bbox
class CellMatcher:
r"""
Match the table cells to the pdf page cells.
NOTICE: PDF page coordinate system vs table coordinate system.
In both systems the bboxes are described in as (x1, y1, x2, y2) with the following meaning:
Page coordinate system:
- Origin (0, 0) at the lower-left corner
- (x1, y1) the lower left corner of the box
- (x2, y2) the upper right corner of the box
Table coordinate system:
- Origin (0, 0) at the upper-left corner
- (x1, y1) the upper left corner of the box
- (x2, y2) the lower right corner of the box
"""
def __init__(self, config):
self._config = config
self._iou_thres = config["predict"]["pdf_cell_iou_thres"]
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def match_cells(self, iocr_page, table_bbox, prediction):
r"""
Convert the tablemodel prediction into the Docling format
Parameters
----------
iocr_page : dict
The original Docling provided table data
prediction : dict
The dictionary has the keys:
"tag_seq": The sequence in indices from the WORDMAP
"html_seq": The sequence as html tags
"bboxes": The bounding boxes
Returns
-------
matching_details : dict
Dictionary with all details about the mathings between the table and pdf cells
"""
pdf_cells = copy.deepcopy(iocr_page["tokens"])
for word in pdf_cells:
word["bbox"] = [
word["bbox"]["l"],
word["bbox"]["t"],
word["bbox"]["r"],
word["bbox"]["b"],
]
table_bboxes = prediction["bboxes"]
table_classes = prediction["classes"]
# BBOXES transformed...
table_bboxes_page = self._translate_bboxes(table_bbox, table_bboxes)
# Combine the table tags and bboxes into TableCells
html_seq = prediction["html_seq"]
otsl_seq = prediction["rs_seq"]
table_cells = self._build_table_cells(
html_seq, otsl_seq, table_bboxes_page, table_classes
)
matches, matches_counter = self._intersection_over_pdf_match(
table_cells, pdf_cells
)
self._log().debug("matches_counter: {}".format(matches_counter))
# Build output
matching_details = {
"iou_threshold": self._iou_thres,
"table_bbox": table_bbox,
"prediction_bboxes_page": table_bboxes_page, # Make easier the comparison with c++
"prediction": prediction,
"pdf_cells": pdf_cells,
"page_height": iocr_page["height"],
"page_width": iocr_page["width"],
"table_cells": table_cells,
"pdf_cells": pdf_cells,
"matches": matches,
}
return matching_details
def match_cells_dummy(self, iocr_page, table_bbox, prediction):
r"""
Convert the tablemodel prediction into the Docling format
DUMMY version doesn't do matching with text cells, but propagates predicted bboxes,
respecting the rest of the format
Parameters
----------
iocr_page : dict
The original Docling provided table data
prediction : dict
The dictionary has the keys:
"tag_seq": The sequence in indices from the WORDMAP
"html_seq": The sequence as html tags
"bboxes": The bounding boxes
Returns
-------
matching_details : dict
Dictionary with all details about the mathings between the table and pdf cells
"""
pdf_cells = copy.deepcopy(iocr_page["tokens"])
for word in pdf_cells:
word["bbox"] = [
word["bbox"]["l"],
word["bbox"]["t"],
word["bbox"]["r"],
word["bbox"]["b"],
]
table_bboxes = prediction["bboxes"]
table_classes = prediction["classes"]
# BBOXES transformed...
table_bboxes_page = self._translate_bboxes(table_bbox, table_bboxes)
# Combine the table tags and bboxes into TableCells
html_seq = prediction["html_seq"]
otsl_seq = prediction["rs_seq"]
table_cells = self._build_table_cells(
html_seq, otsl_seq, table_bboxes_page, table_classes
)
# Build output
matching_details = {
"iou_threshold": self._iou_thres,
"table_bbox": table_bbox,
"prediction_bboxes_page": table_bboxes_page,
"prediction": prediction,
"pdf_cells": pdf_cells,
"page_height": iocr_page["height"],
"page_width": iocr_page["width"],
"table_cells": table_cells,
"pdf_cells": pdf_cells,
"matches": {},
}
return matching_details
def _build_table_cells(self, html_seq, otsl_seq, bboxes, table_classes):
r"""
Combine the tags and bboxes of the table into unified TableCell objects.
Each TableCell takes a row_id, column_id index based on the html structure provided by
html_seq.
It is assumed that the bboxes are in sync with the appearence of the closing </td>
Parameters
----------
html_seq : list
List of html tags
bboxes : list of lists of 4
Bboxes for the table cells at the page origin
Returns
-------
list of dict
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
"""
table_html_structure = {
"html": {"structure": {"tokens": html_seq}},
"split": "predict",
"filename": "memory",
}
otsl_spans = {}
# r, o = otsl.html_to_otsl(table, writer, true, extra_debug, include_html)
r, o = otsl.html_to_otsl(table_html_structure, None, False, False, True, False)
if not r:
ermsg = "ERR#: COULD NOT CONVERT TO RS THIS TABLE TO COMPUTE SPANS"
print(ermsg)
else:
otsl_spans = o["otsl_spans"]
table_cells = []
# It is assumed that the bboxes appear in sync (at the same order) as the TDs
cell_id = 0
row_id = -1
column_id = -1
in_header = False
in_body = False
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
mode = "OTSL"
if mode == "HTML":
for tag in html_seq:
label = None
if tag == "<thead>":
in_header = True
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
elif tag == "</thead>":
in_header = False
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
elif tag == "<tbody>":
in_body = True
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
elif tag == "</tbody>":
in_body = False
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
elif tag == "<td>" or tag == "<td":
column_id += 1
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
if tag == "<td":
multicol_tag = tag
elif tag == "<tr>":
row_id += 1
column_id = -1
multicol_tag = ""
colspan_val = 0
rowspan_val = 0
elif "colspan" in tag:
label = MULTI_COL
multicol_tag += tag
colspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
elif "rowspan" in tag:
label = MULTI_ROW
multicol_tag += tag
rowspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
elif tag == "</td>": # Create a TableCell on each closing td
if len(multicol_tag) > 0:
multicol_tag += tag
if in_header:
if label is None:
label = COL_HEADER
elif label == MULTI_COL:
label = MULTI_COL_HEADER
elif label == MULTI_ROW:
label = MULTI_ROW_HEADER
if label is None and in_body:
label = BODY
err_mismatch = "Mismatching bboxes with closing TDs {} < {}".format(
cell_id, len(bboxes)
)
assert cell_id < len(bboxes), err_mismatch
bbox = bboxes[cell_id]
cell_class = table_classes[cell_id]
table_cell = {}
table_cell["cell_id"] = cell_id
table_cell["row_id"] = row_id
table_cell["column_id"] = column_id
table_cell["bbox"] = bbox
table_cell["cell_class"] = cell_class
table_cell["label"] = label
table_cell["multicol_tag"] = multicol_tag
if colspan_val > 0:
table_cell["colspan_val"] = colspan_val
column_id += (
colspan_val - 1
) # Shift column index to account for span
if rowspan_val > 0:
table_cell["rowspan_val"] = rowspan_val
table_cells.append(table_cell)
cell_id += 1
if mode == "OTSL":
row_id = 0
column_id = 0
multicol_tag = ""
otsl_line = []
cell_id_line = []
for tag in otsl_seq:
otsl_line.append(tag)
if tag == "nl":
row_id += 1
column_id = 0
otsl_line = []
cell_id_line = []
if tag in ["fcel", "ecel", "xcel", "ched", "rhed", "srow"]:
cell_id_line.append(cell_id)
bbox = [0.0, 0.0, 0.0, 0.0]
if cell_id < len(bboxes):
bbox = bboxes[cell_id]
cell_class = 2
if cell_id < len(table_classes):
cell_class = table_classes[cell_id]
label = tag
table_cell = {}
table_cell["cell_id"] = cell_id
table_cell["row_id"] = row_id
table_cell["column_id"] = column_id
table_cell["bbox"] = bbox
table_cell["cell_class"] = cell_class
table_cell["label"] = label
table_cell["multicol_tag"] = multicol_tag
colspan_val = 0
rowspan_val = 0
if cell_id in otsl_spans:
colspan_val = otsl_spans[cell_id][0]
rowspan_val = otsl_spans[cell_id][1]
if colspan_val > 0:
table_cell["colspan_val"] = colspan_val
if rowspan_val > 0:
table_cell["rowspan_val"] = rowspan_val
table_cells.append(table_cell)
cell_id += 1
if tag != "nl":
column_id += 1
return table_cells
def _translate_bboxes(self, table_bbox, cell_bboxes):
r"""
Translate table cell bboxes to the lower-left corner of the page.
The cells of the table are given:
- Origin at the top left corner
- Point A: Top left corner
- Point B: Low right corner
- Coordinate values are normalized to the table width/height
Parameters
----------
table_bbox : list of 4
The whole table bbox page coordinates
cell_bboxes : list of lists of 4
The bboxes of the table cells
Returns
-------
list of 4
The translated bboxes of the table cells
"""
W = table_bbox[2] - table_bbox[0]
H = table_bbox[3] - table_bbox[1]
b = np.asarray(cell_bboxes)
t_mask = np.asarray(
[table_bbox[0], table_bbox[3], table_bbox[0], table_bbox[3]]
)
m = np.asarray([W, -H, W, -H])
page_bboxes_y_flipped = t_mask + m * b
page_bboxes = page_bboxes_y_flipped[:, [0, 3, 2, 1]] # Flip y1' with y2'
page_bboxes_list = page_bboxes.tolist()
t_height = table_bbox[3]
page_bboxes_list1 = []
for page_bbox in page_bboxes_list:
page_bbox1 = [
page_bbox[0],
t_height - page_bbox[3] + table_bbox[1],
page_bbox[2],
t_height - page_bbox[1] + table_bbox[1],
]
page_bboxes_list1.append(page_bbox1)
return page_bboxes_list1
def _intersection_over_pdf_match(self, table_cells, pdf_cells):
r"""
Compute Intersection between table cells and pdf cells,
match 1 pdf cell with highest intersection with only 1 table cell.
First compute and cache the areas for all involved bboxes.
Then compute the pairwise intersections
Parameters
----------
table_cells : list of dict
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
pdf_cells : list of dict
Each element of the list is a dictionary which should have the keys: "id", "bbox"
Returns
-------
dictionary of lists of table_cells
Return a dictionary which is indexed by the pdf_cell_id as key and the value is a list
of the table_cells that fall inside that pdf cell
int
Number of total matches
"""
pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])
pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
)
# key: pdf_cell_id, value: list of TableCell that fall inside that pdf_cell
matches = {}
matches_counter = 0
# Compute Intersections and build matches
for i, table_cell in enumerate(table_cells):
table_cell_id = table_cell["cell_id"]
t_bbox = table_cell["bbox"]
for j, pdf_cell in enumerate(pdf_cells):
pdf_cell_id = pdf_cell["id"]
p_bbox = pdf_cell["bbox"]
# Compute intersection
i_bbox = find_intersection(t_bbox, p_bbox)
if i_bbox is None:
continue
# Compute IOU and filter on threshold
i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
iopdf = 0
if float(pdf_bboxes_areas[j]) > 0:
iopdf = i_bbox_area / float(pdf_bboxes_areas[j])
if iopdf > 0:
match = {"table_cell_id": table_cell_id, "iopdf": iopdf}
if pdf_cell_id not in matches:
matches[pdf_cell_id] = [match]
matches_counter += 1
else:
# Check if the same match was not already counted
if match not in matches[pdf_cell_id]:
matches[pdf_cell_id].append(match)
matches_counter += 1
return matches, matches_counter
def _iou_match(self, table_cells, pdf_cells):
r"""
Use Intersection over Union to decide the matching between table cells and pdf cells
First compute and cache the areas for all involved bboxes.
Then compute the pairwise intersections and IOUs and keep those pairs that exceed the IOU
threshold
Parameters
----------
table_cells : list of dict
Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"
pdf_cells : list of dict
Each element of the list is a dictionary which should have the keys: "id", "bbox"
Returns
-------
dictionary of lists of table_cells
Return a dictionary which is indexed by the pdf_cell_id as key and the value is a list
of the table_cells that fall inside that pdf cell
int
Number of total matches
"""
table_bboxes = np.asarray([t["bbox"] for t in table_cells])
pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])
# Cache the areas for table bboxes and pdf bboxes
table_bboxes_areas = (table_bboxes[:, 2] - table_bboxes[:, 0]) * (
table_bboxes[:, 3] - table_bboxes[:, 1]
)
pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
)
# key: pdf_cell_id, value: list of TableCell that fall inside that pdf_cell
matches = {}
matches_counter = 0
# Compute IOUs and build matches
for i, table_cell in enumerate(table_cells):
table_cell_id = table_cell["cell_id"]
t_bbox = table_cell["bbox"]
for j, pdf_cell in enumerate(pdf_cells):
pdf_cell_id = pdf_cell["id"]
pdf_cell_text = pdf_cell["text"]
p_bbox = pdf_cell["bbox"]
# Compute intersection
i_bbox = find_intersection(t_bbox, p_bbox)
if i_bbox is None:
continue
# Compute IOU and filter on threshold
i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
iou = 0
div_area = float(
table_bboxes_areas[i] + pdf_bboxes_areas[j] - i_bbox_area
)
if div_area > 0:
iou = i_bbox_area / div_area
if iou < self._iou_thres:
continue
if pdf_cell_id not in matches:
matches[pdf_cell_id] = []
match = {
"table_cell_id": table_cell_id,
"iou": iou,
"text": pdf_cell_text,
}
matches[pdf_cell_id].append(match)
matches_counter += 1
return matches, matches_counter
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,396 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
from __future__ import division
import collections
import numbers
import random
import torch
from docling_ibm_models.tableformer.data_management import functional as F
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
class Lambda(object):
"""Apply a user-defined lambda as a transform.
Attention: The multiprocessing used in dataloader of pytorch
is not friendly with lambda function in Windows
Args:
lambd (function): Lambda/function to be used for transform.
"""
def __init__(self, lambd):
# assert isinstance(lambd, types.LambdaType)
self.lambd = lambd
# if 'Windows' in platform.system():
# raise RuntimeError("Can't pickle lambda funciton in windows system")
def __call__(self, img):
return self.lambd(img)
def __repr__(self):
return self.__class__.__name__ + "()"
class RandomTransforms(object):
"""Base class for a list of transformations with randomness
Args:
transforms (list or tuple): list of transformations
"""
def __init__(self, transforms):
assert isinstance(transforms, (list, tuple))
self.transforms = transforms
def __call__(self, *args, **kwargs):
raise NotImplementedError()
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
class RandomChoice(RandomTransforms):
"""Apply single transformation randomly picked from a list"""
def __call__(self, img, target):
t = random.choice(self.transforms)
return t(img, target)
class RandomCrop(object):
def __init__(self, size, margin_crop):
self.size = list(size)
self.margin_crop = list(margin_crop)
# margin_crop: w, h
def __call__(self, img, target):
# img (w,h,ch)
image_height, image_width = img.shape[0], img.shape[1]
"""
img (np.ndarray): Image to be cropped.
x: Upper pixel coordinate.
y: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
"""
if image_width > 0 and image_height > 0:
cropped_image = F.crop(
img,
self.margin_crop[1],
self.margin_crop[0],
image_height - (self.margin_crop[1] * 2),
image_width - (self.margin_crop[0] * 2),
)
target_ = target.copy()
target_["boxes"][:, 0] = target_["boxes"][:, 0] - self.margin_crop[0]
target_["boxes"][:, 1] = target_["boxes"][:, 1] - self.margin_crop[1]
target_["boxes"][:, 2] = target_["boxes"][:, 2] - self.margin_crop[0]
target_["boxes"][:, 3] = target_["boxes"][:, 3] - self.margin_crop[1]
else:
cropped_image = img
return cropped_image, target_
class RandomPad(object):
def __init__(self, max_pad):
self.max_pad = max_pad
def __call__(self, img, target):
pad_x = random.randint(0, self.max_pad)
pad_y = random.randint(0, self.max_pad)
pad_x1 = random.randint(0, self.max_pad)
pad_y1 = random.randint(0, self.max_pad)
img = img.copy()
padded_image = F.pad(img, (pad_x, pad_y, pad_x1, pad_y1), fill=(255, 255, 255))
target_ = target.copy()
if target["boxes"] is not None:
target_["boxes"][:, 0] = target_["boxes"][:, 0] + pad_x
target_["boxes"][:, 1] = target_["boxes"][:, 1] + pad_y
target_["boxes"][:, 2] = target_["boxes"][:, 2] + pad_x
target_["boxes"][:, 3] = target_["boxes"][:, 3] + pad_y
return padded_image, target_
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >=0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
assert isinstance(brightness, float) or (
isinstance(brightness, collections.Iterable) and len(brightness) == 2
)
assert isinstance(contrast, float) or (
isinstance(contrast, collections.Iterable) and len(contrast) == 2
)
assert isinstance(saturation, float) or (
isinstance(saturation, collections.Iterable) and len(saturation) == 2
)
assert isinstance(hue, float) or (
isinstance(hue, collections.Iterable) and len(hue) == 2
)
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
@staticmethod
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []
if isinstance(brightness, numbers.Number):
if brightness > 0:
brightness_factor = random.uniform(
max(0, 1 - brightness), 1 + brightness
)
transforms.append(
Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
)
if contrast > 0:
contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
transforms.append(
Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
)
if saturation > 0:
saturation_factor = random.uniform(
max(0, 1 - saturation), 1 + saturation
)
transforms.append(
Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
)
if hue > 0:
hue_factor = random.uniform(-hue, hue)
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
else:
if brightness[0] > 0 and brightness[1] > 0:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(
Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
)
if contrast[0] > 0 and contrast[1] > 0:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(
Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
)
if saturation[0] > 0 and saturation[1] > 0:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(
Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
)
if hue[0] > 0 and hue[1] > 0:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
random.shuffle(transforms)
transform = ComposeSingle(transforms)
return transform
def __call__(self, img, target):
"""
Args:
img (np.ndarray): Input image.
Returns:
np.ndarray: Color jittered image.
"""
transform = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
return transform(img), target
def __repr__(self):
format_string = self.__class__.__name__ + "("
format_string += "brightness={0}".format(self.brightness)
format_string += ", contrast={0}".format(self.contrast)
format_string += ", saturation={0}".format(self.saturation)
format_string += ", hue={0})".format(self.hue)
return format_string
class Normalize(object):
"""Normalize a tensor image with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
will normalize each channel of the input ``torch.*Tensor`` i.e.
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor, target=None):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std), target
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
class NoTransformation(object):
"""Do Nothing"""
def __call__(self, img, target):
return img, target
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, target):
for t in self.transforms:
img, target = t(img, target)
return img, target
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
class ComposeSingle(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``BILINEAR``
"""
def __init__(self, size, interpolation="BILINEAR"):
self.size = size
self.interpolation = interpolation
def __call__(self, img, target=None):
"""
Args:
img (np.ndarray): Image to be scaled.
Returns:
np.ndarray: Rescaled image.
"""
# Resize bboxes (in pixels)
x_scale = 0
y_scale = 0
if img.shape[1] > 0:
x_scale = self.size[0] / img.shape[1]
if img.shape[0] > 0:
y_scale = self.size[1] / img.shape[0]
# loop over bboxes
if target is not None:
if target["boxes"] is not None:
target_ = target.copy()
target_["boxes"][:, 0] = x_scale * target_["boxes"][:, 0]
target_["boxes"][:, 1] = y_scale * target_["boxes"][:, 1]
target_["boxes"][:, 2] = x_scale * target_["boxes"][:, 2]
target_["boxes"][:, 3] = y_scale * target_["boxes"][:, 3]
return F.resize(img, self.size, self.interpolation), target
def __repr__(self):
interpolate_str = self.interpolation
return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
self.size, interpolate_str
)
@@ -0,0 +1,279 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import glob
import logging
import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
import torch
import docling_ibm_models.tableformer.settings as s
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
class BaseModel(ABC):
r"""
BaseModel provides some common functionality for all models:
- Saves checkpoint files for each epoch
- Loads the model from the best available checkpoint
- Save repository branch and commit
"""
def __init__(self, config, init_data, device):
r"""
Inputs:
config: The configuration file
init_data: Dictionary with initialization data. This dictionary can be used to pass any
kind of initialization data for the models
device: The device used to move the tensors of the model
"""
super(BaseModel, self).__init__()
# Set config and device
self._config = config
self._init_data = init_data
self._device = device
self._save_dir = config["model"]["save_dir"]
self._load_checkpoint = None
if "load_checkpoint" in config["model"]:
self._load_checkpoint = config["model"]["load_checkpoint"]
self._branch_name = "dev/next"
self._commit_sha = "1"
# Keep a dictionary with the starting times per epoch.
# NOTICE: Epochs start from 0
self._epoch_start_ts = {0: time.time()}
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
@abstractmethod
def predict(self, img, max_steps, beam_size, return_attention):
pass
def count_parameters(self):
r"""Counts the number of trainable parameters of this model
Output:
num_parameters: number of trainable parameters
"""
num_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
return num_parameters
def get_code_version(self):
r"""Gets the source control version of this model code
Returns
-------
branch_name : str
The name of the Git branch of this model code
commit_sha : str
The unique identifier of the Git commit of this model code
"""
return self._branch_name, self._commit_sha
def get_save_directory(self):
r"""
Return the save directory
"""
return self._save_dir
def is_saved(self):
r"""
This method returns True if both conditions are met:
1. There is a checkpoint file for the model.
2. The checkpoint file corresponds to the last training epoch set in the configuration file.
"""
# Get the saved_model
saved_model, _ = self._load_best_checkpoint()
if saved_model is None:
return False
epochs = self._config["train"]["epochs"]
self._log().debug(
"Best epoch in saved model: {}; Number of epochs in config: {}".format(
saved_model["epoch"], epochs
)
)
if epochs == saved_model["epoch"] + 1:
return True
return False
def save(self, epoch=None, optimizers=None, losses=None, model_parameters=None):
r"""
Save the model data to the disk as a pickle file.
Parameters
----------
epoch: Training epoch
optimizers: Dictionary with the optimizers. The key specifies what the optimizer is
used for. The 'state_dict' of each optimizer will be saved in the
checkpoint file.
losses: Dictionary with the losses. The key specifies what the loss is used for. Each
value is a list
model_parameters: Dictionary with model specific parameters that we need to save in the
checkpoint file.
Returns
-------
True if success, False otherwise
"""
# Get the checkpoint_filename
c_filename = self._build_checkpoint_filename(epoch)
self._log().debug("Trying to save checkpoint file: {}".format(c_filename))
# Prepare a dictionary with all data we want to save
optimizers_state_dict = None
if optimizers is not None:
optimizers_state_dict = {k: v.state_dict() for k, v in optimizers.items()}
model_data = {
"model_state_dict": self.state_dict(),
"epoch": epoch,
"optimizers": optimizers_state_dict,
"losses": losses,
"model_parameters": model_parameters,
}
# Add the processing time per epoch
now = time.time()
self._epoch_start_ts[epoch + 1] = now
if epoch in self._epoch_start_ts:
dt = now - self._epoch_start_ts[epoch]
model_data["epoch_start_ts"] = self._epoch_start_ts[epoch]
model_data["epoch_dt"] = dt
# Create the save directory
Path(self._save_dir).mkdir(parents=True, exist_ok=True)
# Save the model
torch.save(model_data, c_filename)
# Return true if file is present, otherwise false
if not os.path.isfile(c_filename):
self._log().error("Cannot find the file to save: " + c_filename)
return False
# store code branch name and commit
version_file = os.path.join(self._save_dir, "_version")
with open(version_file, "w") as text_file:
print("Model is using code [commit:branch]", file=text_file)
print("{}:{}".format(self._commit_sha, self._branch_name), file=text_file)
return True
def load(self, optimizers=None):
r"""
Load the model data from the disk.
The method will iterate over all *.check files and try to load the one from the highest
epoch.
Input:
-optimizers: Dictionary with optimizers. If it is not null the keys will be used to
associate the corresponding state_dicts from the checkpoint file and update
the internal states of the provided optimizers.
Output:
- Success: True/ False
- epoch: Loaded epoch or -1 if there are no checkpoint files
- optimizers: Dictionary with loaded optimizers or empty dictionary of there is no
checkpoint file
- losses: Dictionary with loaded losses or empty dictionary of there is no checkpoint
file
- model_parameters: Dictionary with the model parameters or empty dictionary if there
are no checkpoint files
"""
# Get the saved_model
saved_model, _ = self._load_best_checkpoint()
# Restore the model
if saved_model is None:
self._log().debug("No saved model checkpoint found")
return False, -1, optimizers, {}, {}
self._log().debug("Loading model from checkpoint file")
self.load_state_dict(saved_model["model_state_dict"])
epoch = 0
if "epoch" in saved_model:
epoch = saved_model["epoch"]
losses = {}
if "losses" in saved_model:
losses = saved_model["losses"]
model_parameters = saved_model["model_parameters"]
if optimizers is not None:
for key, optimizer_state_dict in saved_model["optimizers"].items():
optimizers[key].load_state_dict(optimizer_state_dict)
# Reset the start_ts of the next epoch
self._epoch_start_ts[epoch + 1] = time.time()
return True, epoch, optimizers, losses, model_parameters
def _load_best_checkpoint(self):
r"""
If a "load_checkpoint" file has been provided, load this one.
Otherwise use the "save_dir" and load the one with the most advanced epoch
Returns
-------
saved_model : dictionary
Checkpoint file contents generated by torch.load, or None
checkpoint_file : string
Filename of the loaded checkpoint, or None
"""
checkpoint_files = []
# If a "load_checkpoint" file is provided, try to load it
if self._load_checkpoint is not None:
if not os.path.exists(self._load_checkpoint):
self._log().error(
"Cannot load the checkpoint: {}".format(self._load_checkpoint)
)
return None, None
checkpoint_files.append(self._load_checkpoint)
else:
# Iterate over all check files from the directory by reverse alphabetical order
# This will get the biggest epoch first
checkpoint_files = glob.glob(os.path.join(self._save_dir, "*.check"))
checkpoint_files.sort(reverse=True)
for checkpoint_file in checkpoint_files:
try:
# Try to load the file
self._log().info(
"Loading model checkpoint file: {}".format(checkpoint_file)
)
saved_model = torch.load(checkpoint_file, map_location=self._device)
return saved_model, checkpoint_file
except RuntimeError:
self._log().error("Cannot load file: {}".format(checkpoint_file))
return None, None
def _build_checkpoint_filename(self, epoch):
r"""
Construct the full path for the filename of this checkpoint
"""
dataset_name = self._config["dataset"]["name"]
model_type = self._config["model"]["type"]
model_name = self._config["model"]["name"]
filename = "{}_{}_{}_{:03}.check".format(
model_type, model_name, dataset_name, epoch
)
c_filename = os.path.join(self._save_dir, filename)
return c_filename
@@ -0,0 +1,163 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import torch
import torch.nn as nn
import docling_ibm_models.tableformer.settings as s
import docling_ibm_models.tableformer.utils.utils as u
# from scipy.optimize import linear_sum_assignment
LOG_LEVEL = logging.INFO
class CellAttention(nn.Module):
"""
Attention Network.
"""
def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim):
"""
:param encoder_dim: feature size of encoded images
:param tag_decoder_dim: size of tag decoder's RNN
:param language_dim: size of language model's RNN
:param attention_dim: size of the attention network
"""
super(CellAttention, self).__init__()
# linear layer to transform encoded image
self._encoder_att = nn.Linear(encoder_dim, attention_dim)
# linear layer to transform tag decoder output
self._tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim)
# linear layer to transform language models output
self._language_att = nn.Linear(language_dim, attention_dim)
# linear layer to calculate values to be softmax-ed
self._full_att = nn.Linear(attention_dim, 1)
self._relu = nn.ReLU()
self._softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def forward(self, encoder_out, decoder_hidden, language_out):
"""
Forward propagation.
:param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim)
:param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells,
tag_decoder_dim)]
:param language_out: language model output, a tensor of dimension (num_cells,
language_dim)
:return: attention weighted encoding, weights
"""
att1 = self._encoder_att(encoder_out) # (1, num_pixels, attention_dim)
att2 = self._tag_decoder_att(decoder_hidden) # (num_cells, tag_decoder_dim)
att3 = self._language_att(language_out) # (num_cells, attention_dim)
att = self._full_att(
self._relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1))
).squeeze(2)
alpha = self._softmax(att) # (num_cells, num_pixels)
# (num_cells, encoder_dim)
attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
return attention_weighted_encoding, alpha
class BBoxDecoder(nn.Module):
"""
CellDecoder generates cell content
"""
def __init__(
self,
device,
attention_dim,
embed_dim,
tag_decoder_dim,
decoder_dim,
num_classes,
encoder_dim=512,
dropout=0.5,
cnn_layer_stride=1,
):
"""
:param attention_dim: size of attention network
:param embed_dim: embedding size
:param tag_decoder_dim: size of tag decoder's RNN
:param decoder_dim: size of decoder's RNN
:param vocab_size: size of vocabulary
:param encoder_dim: feature size of encoded images
:param dropout: dropout
:param mini_batch_size: batch size of cells to reduce GPU memory usage
"""
super(BBoxDecoder, self).__init__()
self._device = device
self._encoder_dim = encoder_dim
self._attention_dim = attention_dim
self._embed_dim = embed_dim
self._decoder_dim = decoder_dim
self._dropout = dropout
self._num_classes = num_classes
if cnn_layer_stride is not None:
self._input_filter = u.resnet_block(stride=cnn_layer_stride)
# attention network
self._attention = CellAttention(
encoder_dim, tag_decoder_dim, decoder_dim, attention_dim
)
# decoder LSTMCell
self._init_h = nn.Linear(encoder_dim, decoder_dim)
# linear layer to create a sigmoid-activated gate
self._f_beta = nn.Linear(decoder_dim, encoder_dim)
self._sigmoid = nn.Sigmoid()
self._dropout = nn.Dropout(p=self._dropout)
self._class_embed = nn.Linear(512, self._num_classes + 1)
self._bbox_embed = u.MLP(512, 256, 4, 3)
def _init_hidden_state(self, encoder_out, batch_size):
mean_encoder_out = encoder_out.mean(dim=1)
h = self._init_h(mean_encoder_out).expand(batch_size, -1)
return h
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def inference(self, encoder_out, tag_H):
"""
Inference on test images with beam search
"""
if hasattr(self, "_input_filter"):
encoder_out = self._input_filter(encoder_out.permute(0, 3, 1, 2)).permute(
0, 2, 3, 1
)
encoder_dim = encoder_out.size(3)
# Flatten encoding (1, num_pixels, encoder_dim)
encoder_out = encoder_out.view(1, -1, encoder_dim)
num_cells = len(tag_H)
predictions_bboxes = []
predictions_classes = []
for c_id in range(num_cells):
# Start decoding
h = self._init_hidden_state(encoder_out, 1)
cell_tag_H = tag_H[c_id]
awe, _ = self._attention(encoder_out, cell_tag_H, h)
gate = self._sigmoid(self._f_beta(h))
awe = gate * awe
h = awe * h
predictions_bboxes.append(self._bbox_embed(h).sigmoid())
predictions_classes.append(self._class_embed(h))
if len(predictions_bboxes) > 0:
predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes])
if len(predictions_classes) > 0:
predictions_classes = torch.stack([x[0] for x in predictions_classes])
return predictions_classes, predictions_bboxes
@@ -0,0 +1,72 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import torch.nn as nn
import torchvision
import docling_ibm_models.tableformer.settings as s
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
class Encoder04(nn.Module):
"""
Encoder based on resnet-18
"""
def __init__(self, enc_image_size, enc_dim=512):
r"""
Parameters
----------
enc_image_size : int
Assuming that the encoded image is a square, this is the length of the image side
"""
super(Encoder04, self).__init__()
self.enc_image_size = enc_image_size
self._encoder_dim = enc_dim
resnet = torchvision.models.resnet18(pretrained=False)
modules = list(resnet.children())[:-3]
self._resnet = nn.Sequential(*modules)
self._adaptive_pool = nn.AdaptiveAvgPool2d(
(self.enc_image_size, self.enc_image_size)
)
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def get_encoder_dim(self):
return self._encoder_dim
def forward(self, images):
"""
Forward propagation
The encoder_dim 512 is decided by the structure of the image network (modified resnet-19)
Parameters
----------
images : tensor (batch_size, image_channels, resized_image, resized_image)
images input
Returns
-------
tensor : (batch_size, enc_image_size, enc_image_size, 256)
encoded images
"""
out = self._resnet(images) # (batch_size, 256, 28, 28)
self._log().debug("forward: resnet out: {}".format(out.size()))
out = self._adaptive_pool(out)
out = out.permute(
0, 2, 3, 1
) # (batch_size, enc_image_size, enc_image_size, 256)
self._log().debug("enc forward: final out: {}".format(out.size()))
return out
@@ -0,0 +1,324 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import torch
import torch.nn as nn
import docling_ibm_models.tableformer.settings as s
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import BBoxDecoder
from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04
from docling_ibm_models.tableformer.models.table04_rs.transformer_rs import (
Tag_Transformer,
)
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
LOG_LEVEL = logging.WARN
# LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
class TableModel04_rs(BaseModel, nn.Module):
r"""
TableNet04Model encoder, dual-decoder model with OTSL+ support
"""
def __init__(self, config, init_data, purpose, device):
super(TableModel04_rs, self).__init__(config, init_data, device)
self._prof = config["predict"].get("profiling", False)
self._device = device
# Extract the word_map from the init_data
word_map = init_data["word_map"]
# Encoder
self._enc_image_size = config["model"]["enc_image_size"]
self._encoder_dim = config["model"]["hidden_dim"]
self._encoder = Encoder04(self._enc_image_size, self._encoder_dim).to(device)
tag_vocab_size = len(word_map["word_map_tag"])
td_encode = []
for t in ["ecel", "fcel", "ched", "rhed", "srow"]:
if t in word_map["word_map_tag"]:
td_encode.append(word_map["word_map_tag"][t])
self._log().debug("td_encode length: {}".format(len(td_encode)))
self._log().debug("td_encode: {}".format(td_encode))
self._tag_attention_dim = config["model"]["tag_attention_dim"]
self._tag_embed_dim = config["model"]["tag_embed_dim"]
self._tag_decoder_dim = config["model"]["tag_decoder_dim"]
self._decoder_dim = config["model"]["hidden_dim"]
self._dropout = config["model"]["dropout"]
self._bbox = config["train"]["bbox"]
self._bbox_attention_dim = config["model"]["bbox_attention_dim"]
self._bbox_embed_dim = config["model"]["bbox_embed_dim"]
self._bbox_decoder_dim = config["model"]["hidden_dim"]
self._enc_layers = config["model"]["enc_layers"]
self._dec_layers = config["model"]["dec_layers"]
self._n_heads = config["model"]["nheads"]
self._num_classes = config["model"]["bbox_classes"]
self._enc_image_size = config["model"]["enc_image_size"]
self._max_pred_len = config["predict"]["max_steps"]
self._tag_transformer = Tag_Transformer(
device,
tag_vocab_size,
td_encode,
self._decoder_dim,
self._enc_layers,
self._dec_layers,
self._enc_image_size,
n_heads=self._n_heads,
).to(device)
self._bbox_decoder = BBoxDecoder(
device,
self._bbox_attention_dim,
self._bbox_embed_dim,
self._tag_decoder_dim,
self._bbox_decoder_dim,
self._num_classes,
self._encoder_dim,
self._dropout,
).to(device)
def _log(self):
# Setup a custom logger
return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
def mergebboxes(self, bbox1, bbox2):
new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)
new_left = bbox1[0] - bbox1[2] / 2
new_top = min((bbox2[1] - bbox2[3] / 2), (bbox1[1] - bbox1[3] / 2))
new_cx = new_left + new_w / 2
new_cy = new_top + new_h / 2
bboxm = torch.tensor([new_cx, new_cy, new_w, new_h])
return bboxm
def predict(self, imgs, max_steps, k, return_attention=False):
r"""
Inference.
The input image must be preprocessed and transformed.
Parameters
----------
img : tensor FloatTensor - torch.Size([1, 3, 448, 448])
Input image for the inference
Returns
-------
seq : list
Predictions for the tags as indices over the word_map
outputs_class : tensor(x, 3)
Classes of predicted bboxes. x is the number of bboxes. There are 3 bbox classes
outputs_coord : tensor(x, 4)
Coords of predicted bboxes. x is the number of bboxes. Each bbox is in [cxcywh] format
"""
AggProfiler().begin("predict_total", self._prof)
# Invoke encoder
self._tag_transformer.eval()
enc_out = self._encoder(imgs)
AggProfiler().end("model_encoder", self._prof)
word_map = self._init_data["word_map"]["word_map_tag"]
n_heads = self._tag_transformer._n_heads
# [1, 28, 28, 512]
encoder_out = self._tag_transformer._input_filter(
enc_out.permute(0, 3, 1, 2)
).permute(0, 2, 3, 1)
batch_size = encoder_out.size(0)
encoder_dim = encoder_out.size(-1)
enc_inputs = encoder_out.view(batch_size, -1, encoder_dim).to(self._device)
enc_inputs = enc_inputs.permute(1, 0, 2)
positions = enc_inputs.shape[0]
encoder_mask = torch.zeros(
(batch_size * n_heads, positions, positions), device=self._device
) == torch.ones(
(batch_size * n_heads, positions, positions), device=self._device
)
# Invoking tag transformer encoder before the loop to save time
AggProfiler().begin("model_tag_transformer_encoder", self._prof)
encoder_out = self._tag_transformer._encoder(enc_inputs, mask=encoder_mask)
AggProfiler().end("model_tag_transformer_encoder", self._prof)
decoded_tags = (
torch.LongTensor([word_map["<start>"]]).to(self._device).unsqueeze(1)
)
output_tags = []
cache = None
tag_H_buf = []
skip_next_tag = True
prev_tag_ucel = False
line_num = 0
# Populate bboxes_to_merge, indexes of first lcel, and last cell in a span
first_lcel = True
bboxes_to_merge = {}
cur_bbox_ind = -1
bbox_ind = 0
# i = 0
while len(output_tags) < self._max_pred_len:
decoded_embedding = self._tag_transformer._embedding(decoded_tags)
decoded_embedding = self._tag_transformer._positional_encoding(
decoded_embedding
)
AggProfiler().begin("model_tag_transformer_decoder", self._prof)
decoded, cache = self._tag_transformer._decoder(
decoded_embedding,
encoder_out,
cache,
memory_key_padding_mask=encoder_mask,
)
AggProfiler().end("model_tag_transformer_decoder", self._prof)
# Grab last feature to produce token
AggProfiler().begin("model_tag_transformer_fc", self._prof)
logits = self._tag_transformer._fc(decoded[-1, :, :]) # 1, vocab_size
AggProfiler().end("model_tag_transformer_fc", self._prof)
new_tag = logits.argmax(1).item()
# STRUCTURE ERROR CORRECTION
# Correction for first line xcel...
if line_num == 0:
if new_tag == word_map["xcel"]:
new_tag = word_map["lcel"]
# Correction for ucel, lcel sequence...
if prev_tag_ucel:
if new_tag == word_map["lcel"]:
new_tag = word_map["fcel"]
# End of generation
if new_tag == word_map["<end>"]:
output_tags.append(new_tag)
decoded_tags = torch.cat(
[
decoded_tags,
torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
],
dim=0,
) # current_output_len, 1
break
output_tags.append(new_tag)
# BBOX PREDICTION
# MAKE SURE TO SYNC NUMBER OF CELLS WITH NUMBER OF BBOXes
if not skip_next_tag:
if new_tag in [
word_map["fcel"],
word_map["ecel"],
word_map["ched"],
word_map["rhed"],
word_map["srow"],
word_map["nl"],
word_map["ucel"],
]:
# GENERATE BBOX HERE TOO (All other cases)...
tag_H_buf.append(decoded[-1, :, :])
if first_lcel is not True:
# Mark end index for horizontal cell bbox merge
bboxes_to_merge[cur_bbox_ind] = bbox_ind
bbox_ind += 1
# Treat horisontal span bboxes...
if new_tag != word_map["lcel"]:
first_lcel = True
else:
if first_lcel:
# GENERATE BBOX HERE (Beginning of horisontal span)...
tag_H_buf.append(decoded[-1, :, :])
first_lcel = False
# Mark start index for cell bbox merge
cur_bbox_ind = bbox_ind
bboxes_to_merge[cur_bbox_ind] = -1
bbox_ind += 1
if new_tag in [word_map["nl"], word_map["ucel"], word_map["xcel"]]:
skip_next_tag = True
else:
skip_next_tag = False
# Register ucel in sequence...
if new_tag == word_map["ucel"]:
prev_tag_ucel = True
else:
prev_tag_ucel = False
decoded_tags = torch.cat(
[
decoded_tags,
torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
],
dim=0,
) # current_output_len, 1
seq = decoded_tags.squeeze().tolist()
if self._bbox:
AggProfiler().begin("model_bbox_decoder", self._prof)
outputs_class, outputs_coord = self._bbox_decoder.inference(
enc_out, tag_H_buf
)
AggProfiler().end("model_bbox_decoder", self._prof)
else:
outputs_class, outputs_coord = None, None
outputs_class.to(self._device)
outputs_coord.to(self._device)
########################################################################################
# Merge First and Last predicted BBOX for each span, according to bboxes_to_merge
########################################################################################
outputs_class1 = []
outputs_coord1 = []
boxes_to_skip = []
for box_ind in range(len(outputs_coord)):
box1 = outputs_coord[box_ind].to(self._device)
cls1 = outputs_class[box_ind].to(self._device)
if box_ind in bboxes_to_merge:
box2 = outputs_coord[bboxes_to_merge[box_ind]].to(self._device)
boxes_to_skip.append(bboxes_to_merge[box_ind])
boxm = self.mergebboxes(box1, box2).to(self._device)
outputs_coord1.append(boxm)
outputs_class1.append(cls1)
else:
if box_ind not in boxes_to_skip:
outputs_coord1.append(box1)
outputs_class1.append(cls1)
if len(outputs_coord1) > 0:
outputs_coord1 = torch.stack(outputs_coord1)
if len(outputs_class1) > 0:
outputs_class1 = torch.stack(outputs_class1)
outputs_class = outputs_class1
outputs_coord = outputs_coord1
# Do the rest of the steps...
AggProfiler().end("predict_total", self._prof)
num_tab_cells = seq.count(4) + seq.count(5)
num_rows = seq.count(9)
self._log().info(
"OTSL predicted table cells#: {}; rows#: {}".format(num_tab_cells, num_rows)
)
return seq, outputs_class, outputs_coord
@@ -0,0 +1,203 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import math
from typing import Optional
import torch
from torch import Tensor, nn
import docling_ibm_models.tableformer.utils.utils as u
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=1024):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class TMTransformerDecoder(nn.TransformerDecoder):
def forward(
self,
tgt: Tensor,
memory: Optional[Tensor] = None,
cache: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
tgt (Tensor): encoded tags. (tags_len,bsz,hidden_dim)
memory (Tensor): encoded image (enc_image_size,bsz,hidden_dim)
cache (Optional[Tensor]): None during training, only used during inference.
Returns:
output (Tensor): (tags_len,bsz,hidden_dim)
"""
output = tgt
# cache
tag_cache = []
for i, mod in enumerate(self.layers):
output = mod(output, memory)
tag_cache.append(output)
if cache is not None:
output = torch.cat([cache[i], output], dim=0)
if cache is not None:
out_cache = torch.cat([cache, torch.stack(tag_cache, dim=0)], dim=1)
else:
out_cache = torch.stack(tag_cache, dim=0)
return output, out_cache
class TMTransformerDecoderLayer(nn.TransformerDecoderLayer):
def forward(
self,
tgt: Tensor,
memory: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
same as TMTransformerDecoder
Returns:
Tensor:
During training (seq_len,bsz,hidden_dim)
If eval mode: embedding of last tag: (1,bsz,hidden_dim)
"""
# From PyTorch but modified to only use the last tag
tgt_last_tok = tgt[-1:, :, :]
tmp_tgt = self.self_attn(
tgt_last_tok,
tgt,
tgt,
attn_mask=None, # None, because we only care about the last tag
key_padding_mask=tgt_key_padding_mask,
)[0]
tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
tgt_last_tok = self.norm1(tgt_last_tok)
if memory is not None:
tmp_tgt = self.multihead_attn(
tgt_last_tok,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
tgt_last_tok = self.norm2(tgt_last_tok)
tmp_tgt = self.linear2(
self.dropout(self.activation(self.linear1(tgt_last_tok)))
)
tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
tgt_last_tok = self.norm3(tgt_last_tok)
return tgt_last_tok
class Tag_Transformer(nn.Module):
"""
"Attention Is All You Need" - https://arxiv.org/abs/1706.03762
"""
def __init__(
self,
device,
vocab_size,
td_encode,
embed_dim,
encoder_layers,
decoder_layers,
enc_image_size,
dropout=0.1,
n_heads=4,
dim_ff=1024,
):
super(Tag_Transformer, self).__init__()
self._device = device
self._n_heads = n_heads
self._embedding = nn.Embedding(vocab_size, embed_dim)
self._positional_encoding = PositionalEncoding(embed_dim)
self._td_encode = td_encode
self._encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff
),
num_layers=encoder_layers,
)
self._decoder = TMTransformerDecoder(
TMTransformerDecoderLayer(
d_model=embed_dim,
nhead=n_heads,
dim_feedforward=dim_ff,
),
num_layers=decoder_layers,
)
self._decoder_dim = embed_dim
self._enc_image_size = enc_image_size
self._input_filter = u.resnet_block(stride=1)
self._fc = nn.Linear(embed_dim, vocab_size)
def inference(self, enc_inputs, tags, tag_lens, num_cells):
# CNN backbone image encoding
enc_inputs = self._input_filter(enc_inputs.permute(0, 3, 1, 2)).permute(
0, 2, 3, 1
)
batch_size = enc_inputs.size(0)
encoder_dim = enc_inputs.size(-1)
enc_inputs = enc_inputs.view(batch_size, -1, encoder_dim).to(self._device)
enc_inputs = enc_inputs.permute(1, 0, 2)
positions = enc_inputs.shape[0]
# Transformer Encoder Encoded Image mask need to check if its useful
encoder_mask = torch.zeros(
(batch_size * self._n_heads, positions, positions), device=self._device
) == torch.ones(
(batch_size * self._n_heads, positions, positions), device=self._device
)
# Transformer Encoder
encoder_out = self._encoder(enc_inputs, mask=encoder_mask)
decode_lengths = (tag_lens - 1).tolist()
tgt = self._positional_encoding(self._embedding(tags).permute(1, 0, 2))
decoded = self._decoder(tgt, memory=encoder_out)
decoded = decoded.permute(1, 0, 2)
predictions = self._fc(decoded)
return predictions, decode_lengths
+541
View File
@@ -0,0 +1,541 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import copy
import logging
from itertools import groupby
import docling_ibm_models.tableformer.settings as s
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
png_files = {} # Evaluation files
total_pics = 0
class bcolors:
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKCYAN = "\033[96m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
def otsl_clean(rs_list):
new_rs_list = []
stop_list = ["<pad>", "<unk>", "<start>", "<end>"]
for tag in rs_list:
if tag not in stop_list:
new_rs_list.append(tag)
return new_rs_list
def otsl_sqr_chk(rs_list, name, logdebug):
rs_list_split = [
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
]
isSquare = True
if len(rs_list_split) > 0:
init_tag_len = len(rs_list_split[0]) + 1
for ind, ln in enumerate(rs_list_split):
ln.append("nl")
if len(ln) != init_tag_len:
isSquare = False
if isSquare:
if logdebug:
print(
"{}*OK* Table is square! *OK*{}".format(
bcolors.OKGREEN, bcolors.ENDC
)
)
else:
err_name = "{}*ERR* " + name + " *ERR*{}"
print(err_name.format(bcolors.FAIL, bcolors.ENDC))
print(
"{}*ERR* Table is not square! *ERR*{}".format(
bcolors.FAIL, bcolors.ENDC
)
)
return isSquare
def otsl_pad_to_sqr(rs_list, pad_tag):
new_list = []
rs_list_split = [
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
]
max_row_len = 0
for ind, ln in enumerate(rs_list_split):
if len(ln) > max_row_len:
max_row_len = len(ln)
for ind, ln in enumerate(rs_list_split):
ln += [pad_tag] * (max_row_len - len(ln))
ln.append("nl")
new_list.extend(ln)
return new_list
def otsl_tags_cells_sync_chk(rs_list, cells, name, logdebug):
countCellTags = 0
isGood = True
for rsTag in rs_list:
if rsTag in ["fcel", "ched", "rhed", "srow", "ecel"]:
countCellTags += 1
if countCellTags != len(cells):
err_name = "{}*!ERR* " + name + " *ERR!*{}"
print(err_name.format(bcolors.FAIL, bcolors.ENDC))
err_msg = "{}*!ERR* Tags are not in sync with cells! *ERR!*{}"
print(err_msg.format(bcolors.FAIL, bcolors.ENDC))
isGood = False
return isGood
def otsl_check_down(rs_split, x, y):
distance = 1
elem = "ucel"
goodlist = ["fcel", "ched", "rhed", "srow", "ecel", "lcel", "nl"]
while elem not in goodlist and y < len(rs_split) - 1:
y += 1
distance += 1
elem = rs_split[y][x]
if elem in goodlist:
distance -= 1
return distance
def otsl_check_right(rs_split, x, y):
distance = 1
elem = "lcel"
goodlist = ["fcel", "ched", "rhed", "srow", "ecel", "ucel", "nl"]
while elem not in goodlist and x < (len(rs_split[y]) - 1):
x += 1
distance += 1
elem = rs_split[y][x]
if elem in goodlist:
distance -= 1
return distance
def otsl_to_html(rs_list, logdebug):
if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]:
# Most likely already HTML...
return rs_list
html_table = []
if logdebug:
print("{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC))
if not otsl_sqr_chk(rs_list, "---", logdebug):
# PAD TABLE TO SQUARE
print("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
rs_list = otsl_pad_to_sqr(rs_list, "lcel")
# 2D structure, line by line:
rs_list_split = [
list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
]
if logdebug:
print("")
# Sequentially store indexes of 2D spans that were registered to avoid re-registering them
registry_2d_span = []
# Iterate all elements in the rs line, and look right / down to detect spans
# If span detected - run function to find size of the span
# repeat with all cells
thead_present = False
for rs_row_ind, rs_row in enumerate(rs_list_split):
html_list = []
if not thead_present:
if "ched" in rs_list_split[rs_row_ind]:
html_list.append("<thead>")
thead_present = True
if thead_present:
if "ched" not in rs_list_split[rs_row_ind]:
html_list.append("</thead>")
thead_present = False
html_list.append("<tr>")
for rs_cell_ind, rs_cell in enumerate(rs_list_split[rs_row_ind]):
if rs_cell in ["fcel", "ched", "rhed", "srow", "ecel"]:
rdist = 0
ddist = 0
xrdist = 0
xddist = 0
span = False
# Check if it has horizontal span:
if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "lcel":
rdist = otsl_check_right(rs_list_split, rs_cell_ind, rs_row_ind)
span = True
# Check if it has vertical span:
if rs_row_ind + 1 < len(rs_list_split):
# print(">>>")
# print(rs_list_split[rs_row_ind + 1])
# print(">>> rs_cell_ind = {}".format(rs_cell_ind))
if rs_list_split[rs_row_ind + 1][rs_cell_ind] == "ucel":
ddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
span = True
# Check if it has 2D span:
if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "xcel":
xrdist = otsl_check_right(
rs_list_split, rs_cell_ind, rs_row_ind
)
xddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
span = True
# Check if this 2D span was already registered,
# If not - register, if yes - cancel span
# print("rs_cell_ind: {}, xrdist:{}".format(rs_cell_ind, xrdist))
# print("rs_row_ind: {}, xddist:{}".format(rs_cell_ind, xrdist))
for x in range(rs_cell_ind, xrdist + rs_cell_ind):
for y in range(rs_row_ind, xddist + rs_row_ind):
reg2dind = str(x) + "_" + str(y)
# print(reg2dind)
if reg2dind in registry_2d_span:
# Cell of the span is already in, cancel current span
span = False
if span:
# None of the span cells were previously registered
# Register an entire span
for x in range(rs_cell_ind, xrdist + rs_cell_ind):
for y in range(rs_row_ind, xddist + rs_row_ind):
reg2dind = str(x) + "_" + str(y)
registry_2d_span.append(reg2dind)
if span:
html_list.append("<td")
if rdist > 1:
html_list.append(' colspan="' + str(rdist) + '"')
if ddist > 1:
html_list.append(' rowspan="' + str(ddist) + '"')
if xrdist > 1:
html_list.append(' rowspan="' + str(xddist) + '"')
html_list.append(' colspan="' + str(xrdist) + '"')
html_list.append(">")
html_list.append("</td>")
else:
html_list.append("<td>")
html_list.append("</td>")
html_list.append("</tr>")
html_table.extend(html_list)
if logdebug:
print("*********************** registry_2d_span ***************************")
print(registry_2d_span)
print("********************************************************************")
return html_table
def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer):
r"""
Converts table structure from HTML to RS
Parameters
----------
table : json
line from jsonl
writer : writer
Writes lines into output jsonl
"""
table_html_structure = copy.deepcopy(table["html"]["structure"])
out_line = table
if include_html:
out_line["html"]["html_structure"] = table_html_structure
out_line["html"]["html_restored_structure"] = {"tokens": []}
out_line["html"]["structure"] = {"tokens": []}
# possible colspans
pos_colspans = {
' colspan="20"': 20,
' colspan="19"': 19,
' colspan="18"': 18,
' colspan="17"': 17,
' colspan="16"': 16,
' colspan="15"': 15,
' colspan="14"': 14,
' colspan="13"': 13,
' colspan="12"': 12,
' colspan="11"': 11,
' colspan="10"': 10,
' colspan="2"': 2,
' colspan="3"': 3,
' colspan="4"': 4,
' colspan="5"': 5,
' colspan="6"': 6,
' colspan="7"': 7,
' colspan="8"': 8,
' colspan="9"': 9,
}
# possible rowspans
pos_rowspans = {
' rowspan="20"': 20,
' rowspan="19"': 19,
' rowspan="18"': 18,
' rowspan="17"': 17,
' rowspan="16"': 16,
' rowspan="15"': 15,
' rowspan="14"': 14,
' rowspan="13"': 13,
' rowspan="12"': 12,
' rowspan="11"': 11,
' rowspan="10"': 10,
' rowspan="2"': 2,
' rowspan="3"': 3,
' rowspan="4"': 4,
' rowspan="5"': 5,
' rowspan="6"': 6,
' rowspan="7"': 7,
' rowspan="8"': 8,
' rowspan="9"': 9,
}
t_cells = [] # 2D structure
tl_cells = [] # 1D structure
t_expands = [] # 2D structure
tl_spans = {} # MAP, POPULATE WITH ACTUAL SPANS VALUES, IN SYNC WITH tl_cells
current_line = 0
current_column = 0
current_html_cell_ind = 0
current_line_tags = []
current_line_expands = []
if logdebug:
print("")
print("*** {}: {} ***".format(table["split"], table["filename"]))
colnum = 0
if extra_debug:
print("========================== Input HTML ============================")
print(table_html_structure["tokens"])
print("==================================================================")
if logdebug:
print("********")
print("* OTSL *")
print("********")
for i in range(len(table_html_structure["tokens"])):
html_tag = table_html_structure["tokens"][i]
prev_html_tag = ""
next_html_tag = ""
if i > 0:
prev_html_tag = table_html_structure["tokens"][i - 1]
if i < len(table_html_structure["tokens"]) - 1:
next_html_tag = table_html_structure["tokens"][i + 1]
if html_tag not in ["<thead>", "<tbody>"]:
# Then check the next tag...
# rules of conversion
# Check up-cell in t_expands, in case row-spans have to be inserted
if html_tag in ["<td>", "<td", "</tr>"]:
if current_line > 0:
if current_column >= len(t_expands[current_line - 1]):
# !!!
return False, {}
up_expand = t_expands[current_line - 1][current_column]
while up_expand[1] > 0:
if up_expand[0] == 0:
# ucel
current_line_tags.append("ucel")
current_line_expands.append([0, up_expand[1] - 1])
current_column += 1
else:
# xcel
for ci in range(up_expand[0]):
current_line_tags.append("xcel")
current_line_expands.append(
[up_expand[0] - ci, up_expand[1] - 1]
)
current_column += 1
up_expand = t_expands[current_line - 1][current_column]
# ======================================================================================
# Fix for trailing "ucel" in a row
if html_tag in ["</tr>"]:
if current_line > 0:
cur_line_len = len(current_line_expands)
pre_line_len = len(t_expands[current_line - 1])
if cur_line_len < pre_line_len:
extra_columns = pre_line_len - cur_line_len - 1
if extra_columns > 0:
if extra_debug:
print(
"Extra columns needed in row: {}".format(
extra_columns
)
)
for clm in range(extra_columns):
up_expand = t_expands[current_line - 1][
cur_line_len + clm
]
if up_expand[0] == 0:
# ucel
current_line_tags.append("ucel")
current_line_expands.append([0, up_expand[1] - 1])
else:
# xcel
current_line_tags.append("xcel")
current_line_expands.append(
[up_expand[0], up_expand[1] - 1]
)
# ======================================================================================
# 1. Opening cell tags
if html_tag in ["<td>", "<td"]:
# check if cell is empty...
cell_is_empty = True
if "cells" in table["html"]:
cell_tokens = table["html"]["cells"][current_html_cell_ind][
"tokens"
]
else:
cell_tokens = "f"
# Clean cell_tokens from trash:
cell_tokens = list(filter(lambda a: a != "<i>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "<I>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "<b>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "<B>", cell_tokens))
cell_tokens = list(filter(lambda a: a != " ", cell_tokens))
cell_tokens = list(filter(lambda a: a != "</b>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "</B>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "</i>", cell_tokens))
cell_tokens = list(filter(lambda a: a != "</I>", cell_tokens))
# Check if empty
if len(cell_tokens) > 0:
cell_is_empty = False
if cell_is_empty:
out_line["html"]["cells"][current_html_cell_ind]["tokens"] = []
current_line_tags.append("ecel")
current_line_expands.append([0, 0])
else:
current_line_tags.append("fcel")
current_line_expands.append([0, 0])
current_html_cell_ind += 1
current_column += 1
# 2. Closing row tags
if html_tag == "</tr>":
if len(current_line_tags) > colnum:
colnum = len(current_line_tags)
# Save everything we read about the line to t_cells
current_line_tags.append("nl")
t_cells.append(copy.deepcopy(current_line_tags))
tl_cells.extend(copy.deepcopy(current_line_tags))
if logdebug:
print(current_line_tags)
current_line_tags = []
# Deal with expands
current_line_expands.append([-1, -1])
# Output spans metadata
t_expands.append(copy.deepcopy(current_line_expands))
current_line_expands = []
current_column = 0
current_line += 1
# 3. Colspans only
if html_tag in pos_colspans:
if prev_html_tag not in pos_rowspans:
if next_html_tag not in pos_rowspans:
colspan_len = pos_colspans[html_tag]
tl_spans[current_html_cell_ind - 1] = [colspan_len, 1]
current_line_expands[len(current_line_expands) - 1] = [
colspan_len,
0,
]
for ci in range(colspan_len - 1):
current_line_tags.append("lcel")
current_line_expands.append([colspan_len - ci - 1, 0])
current_column += 1
# 4. Rowspans only
if html_tag in pos_rowspans:
if prev_html_tag not in pos_colspans:
if next_html_tag not in pos_colspans:
rowspan_len = pos_rowspans[html_tag]
tl_spans[current_html_cell_ind - 1] = [1, rowspan_len]
current_line_expands[len(current_line_expands) - 1] = [
0,
rowspan_len - 1,
]
# 5. 2D spans
if html_tag in pos_rowspans:
rowspan_len = pos_rowspans[html_tag]
if prev_html_tag in pos_colspans:
colspan_len = pos_colspans[prev_html_tag]
tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
newexp = [colspan_len, rowspan_len - 1]
current_line_expands[len(current_line_expands) - 1] = newexp
for ci in range(colspan_len - 1):
current_line_tags.append("xcel")
current_line_expands.append(
[colspan_len - ci - 1, rowspan_len - 1]
)
if next_html_tag in pos_colspans:
colspan_len = pos_colspans[next_html_tag]
tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
newexp = [colspan_len, rowspan_len - 1]
current_line_expands[len(current_line_expands) - 1] = newexp
for ci in range(colspan_len - 1):
current_line_tags.append("xcel")
current_line_expands.append(
[colspan_len - ci - 1, rowspan_len - 1]
)
t_name = "*** {}: {} ***".format(table["split"], table["filename"])
# check if square
isSquare = otsl_sqr_chk(tl_cells, t_name, logdebug)
# TODO: pad if not square?
if not isSquare:
tl_cells = otsl_pad_to_sqr(tl_cells, "fcel")
# check if cells (bboxes) in sync:
if "cells" in out_line["html"]:
isGood = otsl_tags_cells_sync_chk(
tl_cells, out_line["html"]["cells"], t_name, logdebug
)
# convert back to HTML
rHTML = []
if isSquare:
rHTML = otsl_to_html(tl_cells, logdebug)
out_line["html"]["html_restored_structure"]["tokens"] = rHTML
out_line["html"]["structure"]["tokens"] = tl_cells
out_line["otsl_spans"] = tl_spans
out_line["cols"] = colnum
out_line["rows"] = len(t_cells)
out_line["html_len"] = len(table_html_structure["tokens"])
out_line["rs_len"] = len(tl_cells)
# save converted line
if use_writer:
if isSquare:
if isGood:
writer.write(out_line)
if logdebug:
print("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
print(rHTML)
# original HTML
oHTML = out_line["html"]["html_structure"]
print("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
print(oHTML)
return True, out_line
@@ -0,0 +1,90 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import sys
def get_custom_logger(logger_name, level, stream=sys.stdout):
r"""
Create a custom logger with a standard formatting
Inputs:
- logger_name: Name of the logger. You can get the class name as self.__class__.__name__
- level: logging level (e.g. logging.INFO, logging.DEBUG, etc.)
- stream: One of sys.stdout or sys.stderr
Outputs:
logger
"""
logger = logging.getLogger(logger_name)
logger.setLevel(level)
# Set the handler
if not logger.hasHandlers():
handler = logging.StreamHandler(stream)
formatter = logging.Formatter(
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
###################################################################################
# System constants
#
r"""
This is a "generic" logger available to all scripts.
It is encouraged that each class has it's own custom logger with the name of the class.
You can use the "get_custom_logger" function to build a custom logger with a standard format.
"""
LOGGER = get_custom_logger("docling-pm", logging.INFO)
# Supported dataset types
supported_datasets = ["TF_prepared"] # TF prepared dataset
# Split names
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"
# Prepared data parts and filename templates
PREPARED_DATA_PARTS = {
# Array with the bboxes (x1y1x2y2) for all cells of the images across all splits.
# The bboxes are indexed with the filename.
# Notices:
# - The bboxes are NOT transformed.
# - If the image filenames are the same across splits, there will be one one entry in the file
"BBOXES": "BBOXES.json",
# Image filenames used for train and val
"IMAGES": "IMAGES.json",
# Mean, std, variance as arrays of 3 (for each color)
"STATISTICS": "STATISTICS_<POSTFIX>.json", # PRECOMPUTED
# Bboxes of the cells in the form [1, x1, x2, y1, y2] or [0, 0, 0, 0, 0] in case of no box.
"TRAIN_CELLBBOXES": "TRAIN_CELLBBOXES_<POSTFIX>.json", # NOT USED.
# Array with arrays of the length + 2 of the original cells per image.
"TRAIN_CELLLENS": "TRAIN_CELLLENS_<POSTFIX>.json",
# Indices of the cells between <start> <end> and <pad> at the end.
"TRAIN_CELLS": "TRAIN_CELLS_<POSTFIX>.json",
# Array with the length + 2 of the original tags per image.
"TRAIN_TAGLENS": "TRAIN_TAGLENS_<POSTFIX>.json",
# Indices of the tags between <start> <end> and <pad> at the end.
"TRAIN_TAGS": "TRAIN_TAGS_<POSTFIX>.json",
# Ground truth for the evaluation dataset per eval image.
"VAL": "VAL.json",
# Vocabulary: Indices of the word_map_cells and word_map_tags
"WORDMAP": "WORDMAP_<POSTFIX>.json", # PRECOMPUTED
}
# Purposes
TRAIN_PURPOSE = "train"
VAL_PURPOSE = "val"
TEST_PURPOSE = "test"
PREDICT_PURPOSE = "predict"
# The DDP world size when we train in CPU with DDP enabled
DDP_CPU_WORLD_SIZE = 2
@@ -0,0 +1,37 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import docling_ibm_models.tableformer.common as c
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
def dataset_test(config):
r"""
Parameters
----------
config : dictionary
The configuration settings
"""
# model_type = config["model"]["type"]
# Create the device and the Dataset
device = "cpu"
dataset = TFDataset(config, "train")
dataset.set_device(device)
# Loop over the data
dataset.reset()
dataset.shuffle()
for i, batch in enumerate(dataset):
print("Loading batch: {}".format(i))
if __name__ == "__main__":
config = c.parse_arguments()
dataset_test(config)
@@ -0,0 +1,99 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import glob
import os
import numpy as np
from PIL import Image
import docling_ibm_models.tableformer.common as c
from docling_ibm_models.tableformer.data_management.data_transformer import (
DataTransformer,
)
def dump_np(img_np: np.array, fn, n=6):
# Expect to receive a numpy array for an image with the shape [channels, rows, columns]
s = img_np.shape
if s[0] not in [1, 2, 3, 4] or len(s) != 3:
print("Image of invalid shape: {}".format(s))
return
channels = s[0]
rows = s[1]
cols = s[2]
w = n + 6
with open(fn, "w") as fd:
for r in range(rows):
for col in range(cols):
for ch in range(channels):
x = img_np[ch][r][col]
if isinstance(x, np.float32):
f_str = "0:>{}.{}f".format(w, n)
elif isinstance(x, np.uint8):
f_str = "0:>{}".format(w)
else:
return False
x_str = ("{" + f_str + "}").format(x)
fd.write(x_str)
if ch < channels - 1:
fd.write(" ")
fd.write("\n")
return True
def dump_channels(save_dir, fn_prefix, img_np: np.array):
# Dump the np array into 3 files per channel
img_np_ch0 = img_np[0, :, :]
img_np_ch1 = img_np[1, :, :]
img_np_ch2 = img_np[2, :, :]
txt_ch0_fn = os.path.join(save_dir, fn_prefix + "_ch0.txt")
txt_ch1_fn = os.path.join(save_dir, fn_prefix + "_ch1.txt")
txt_ch2_fn = os.path.join(save_dir, fn_prefix + "_ch2.txt")
np.savetxt(txt_ch0_fn, img_np_ch0)
np.savetxt(txt_ch1_fn, img_np_ch1)
np.savetxt(txt_ch2_fn, img_np_ch2)
print(f"{txt_ch0_fn}")
print(f"{txt_ch1_fn}")
print(f"{txt_ch2_fn}")
def prepare_image(config):
transformer = DataTransformer(config)
predict_dir = config["predict"]["predict_dir"]
use_normalization = config["dataset"]["image_normalization"]["state"]
pattern = os.path.join(predict_dir, "*.png")
for img_fn in glob.glob(pattern):
print(f"img_fn: {img_fn}")
with Image.open(img_fn) as img:
# Dump the initial image in txt files
img_np = np.array(img)
# Reshape the image in order to print it
img_np_m = np.moveaxis(img_np, 2, 0)
print(
"orig. img_np.shape: {}, reshaped image: {}".format(
img_np.shape, img_np_m.shape
)
)
original_fn = img_fn + "_python.txt"
dump_np(img_np_m, original_fn)
r_img_ten = transformer.rescale_in_memory(img, use_normalization)
print("npimgc: {} - {}".format(r_img_ten.type(), r_img_ten.size()))
# Dump the processed image tensor in txt files
r_img_np = r_img_ten.numpy()
prepared_fn = img_fn + "_python_prepared.txt"
dump_np(r_img_np, prepared_fn)
if __name__ == "__main__":
config = c.parse_arguments()
prepare_image(config)
@@ -0,0 +1,243 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import time
from collections import deque
from statistics import mean, median
class SingletonClass(type):
r"""
Generic singleton metaclass
"""
def __init__(self, name, bases, dic):
self._instance = None
super().__init__(name, bases, dic)
def __call__(cls, *args, **kwargs):
# Create a singleton if needed
if cls._instance is None:
singleton = cls.__new__(cls)
singleton.__init__(*args, **kwargs)
cls._instance = singleton
return cls._instance
class Profiler:
r"""
Application specific profiler
Decompose the application into "sections". Each section is a label.
The total time a section consumes is split into "intervals"
Use the `begin`, `end` methods to mark the begining and end of an interval for
a certain section
"""
def __init__(self):
self._section_dts = {} # section name -> sum(section intervals)
self._section_calls = {} # section name -> number of invocations
self._section_kB = {} # section name -> max kB of used heap
# section name -> beginning of the last interval
self._last_begin = {}
def begin(self, section_name, enable=True):
r"""
Mark the beginning of an interval
Parameters
----------
section_name : string
Name of the section
enable : bool
The actual interval entry takes place only if enable is true
Return
------
True if the interval has actuall begun
"""
if not enable:
return False
self._last_begin[section_name] = time.time()
return True
def end(self, section_name, enable=True):
r"""
Mark the end of an interval for a certain section
Parameters
----------
section_name : string
Name of the section
enable : bool
The actual interval entry takes place only if enable is true
Return
------
True if the section name is valid and an interval for this section has already begun
False otherwise
"""
if not enable:
return False
if section_name not in self._last_begin:
return False
dt = time.time() - self._last_begin[section_name]
if section_name not in self._section_dts:
self._section_dts[section_name] = dt
self._section_calls[section_name] = 1
else:
self._section_dts[section_name] += dt
self._section_calls[section_name] += 1
return True
def get_data(self, section_names=None):
r"""
Return a dict with profiling data for the specified sections.
Parameter
---------
section_names : list of string
List with the section names to get their accumulative dt
If it is None, all sections are returned
Return
------
dict of dicts
Outer key: section name
Inner keys: "dt": Accumulative time for that section, "cells": Number of calls
"""
# Filter the section names to apply
filtered_names = list(
filter(lambda x: x in section_names, self._section_dts.keys())
if section_names is not None
else self._section_dts.keys()
)
data = {}
for section_name in filtered_names:
data[section_name] = {
"dt": self._section_dts[section_name],
"calls": self._section_calls[section_name],
"kB": self._section_kB[section_name],
}
return data
class AppProfiler(Profiler, metaclass=SingletonClass):
r"""
AppProfiler is a singleton of the Profiler for application wide usage
"""
def __init__(self):
super(AppProfiler, self).__init__()
class AggProfiler(metaclass=SingletonClass):
r"""
Generic wrapper of Profiler that enables aggregation of profiling statistics around Cycles
- When a new cycle begins a new Profiler is created to keep the profiling data per section
- Keep the last n cycles in a sliding window manner
- At every time we can get profiling data about the last cycle and statistics over the last n
cycles
"""
def __init__(self, window_size=20):
self._window_size = window_size
# deque with up to the last "window_size" Profilers. The newest at index 0
self._cycles = deque()
def start_agg(self, enable=True):
r"""
Returns
-------
0: not enabled
1: a new scope has started
"""
if not enable:
return 0
# Add a new profiler
self._cycles.appendleft(Profiler())
# In case the deque has grown too much, remove the oldest Profiler
if len(self._cycles) > self._window_size:
self._cycles.pop()
return 1
def begin(self, section_name, enable=True):
if not enable:
return False
if len(self._cycles) == 0:
print("AggProfiler begin | Start Aggregator not initialized.")
return False
profiler = self._cycles[0]
return profiler.begin(section_name)
def end(self, section_name, enable=True):
if not enable:
return False
if len(self._cycles) == 0:
print("AggProfiler end | Start Aggregator not initialized.")
return False
profiler = self._cycles[0]
return profiler.end(section_name)
def get_data(self):
r"""
Get profiling data for:
- The last cycle
- Aggragated statistics (avg, median) per section and per metric across all cycles
- The dt numbers for the mean/median is the average time for each section ACROSS the cycle
- There is NO need to compute average by yourself.
Returns
-------
dict with the structure:
- window: int with the size of the time sliding window
- last: dict with the metrics for the last cycle (as provided by the Profiler)
- mean: dict with the mean metrics per section across the cycle
- section_name
- metric_name: mean of the metric values
- median: dict with the median metrics per section across the cycle
- section_name
- metric_name: median of the metric values
"""
last_data = self._cycles[0].get_data()
data = {
"window": len(self._cycles),
"last": last_data,
"mean": {},
"median": {},
}
# Section -> metric -> [values]
section_metric_values = {}
# Collect the metrics
for i, p in enumerate(self._cycles):
p_data = p.get_data()
for section_name, m_dict in p_data.items():
for m_name, m_val in m_dict.items():
if section_name not in section_metric_values:
section_metric_values[section_name] = {}
s_metrics = section_metric_values[section_name]
if m_name not in s_metrics:
s_metrics[m_name] = []
s_metrics[m_name].append(m_val)
# Aggregate the metrics
for section_name, m_dict in section_metric_values.items():
for m_name, m_values in m_dict.items():
if section_name not in data["mean"]:
data["mean"][section_name] = {}
if section_name not in data["median"]:
data["median"][section_name] = {}
mean_v = mean(m_values)
median_v = median(m_values)
data["mean"][section_name][m_name] = mean_v
data["median"][section_name][m_name] = median_v
return data
@@ -0,0 +1,216 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import torch
def model_info(model, verbose=False):
# Plots a line-by-line description of a PyTorch model
n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(
x.numel() for x in model.parameters() if x.requires_grad
) # number gradients
if verbose:
print(
"%5s %40s %9s %12s %20s %10s %10s"
% ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
)
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace("module_list.", "")
print(
"%5g %40s %9s %12g %20s %10.3g %10.3g"
% (
i,
name,
p.requires_grad,
p.numel(),
list(p.shape),
p.mean(),
p.std(),
)
)
try: # FLOPS
from thop import profile
macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
fs = ", %.1f GFLOPS" % (macs / 1e9 * 2)
except Exception:
fs = ""
print(
"Model Summary: %g layers, %g parameters, %g gradients%s"
% (len(list(model.parameters())), n_p, n_g, fs)
)
# def init_seeds(seed=0):
# torch.manual_seed(seed)
#
# # Reduce randomness (may be slower on Tesla GPUs)
# # https://pytorch.org/docs/stable/notes/randomness.html
# if seed == 0:
# cudnn.deterministic = False
# cudnn.benchmark = True
#
#
# def select_device(device='', apex=False, batch_size=None):
# # device = 'cpu' or '0' or '0,1,2,3'
# cpu_request = device.lower() == 'cpu'
# if device and not cpu_request: # if device requested other than 'cpu'
# os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
# # check availablity
# assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
#
# cuda = False if cpu_request else torch.cuda.is_available()
# if cuda:
# c = 1024 ** 2 # bytes to MB
# ng = torch.cuda.device_count()
# if ng > 1 and batch_size: # check that batch_size is compatible with device_count
# assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % \
# (batch_size, ng)
# x = [torch.cuda.get_device_properties(i) for i in range(ng)]
# # apex for mixed precision https://github.com/NVIDIA/apex
# s = 'Using CUDA ' + ('Apex ' if apex else '')
# for i in range(0, ng):
# if i == 1:
# s = ' ' * len(s)
# print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
# (s, i, x[i].name, x[i].total_memory / c))
# else:
# print('Using CPU')
#
# print('') # skip a line
# return torch.device('cuda:0' if cuda else 'cpu')
#
#
# def time_synchronized():
# torch.cuda.synchronize() if torch.cuda.is_available() else None
# return time.time()
#
#
# def initialize_weights(model):
# for m in model.modules():
# t = type(m)
# if t is nn.Conv2d:
# pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# elif t is nn.BatchNorm2d:
# m.eps = 1e-4
# m.momentum = 0.03
# elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
# m.inplace = True
#
#
# def find_modules(model, mclass=nn.Conv2d):
# # finds layer indices matching module class 'mclass'
# return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
#
#
# def fuse_conv_and_bn(conv, bn):
# # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
# with torch.no_grad():
# # init
# fusedconv = torch.nn.Conv2d(conv.in_channels,
# conv.out_channels,
# kernel_size=conv.kernel_size,
# stride=conv.stride,
# padding=conv.padding,
# bias=True)
#
# # prepare filters
# w_conv = conv.weight.clone().view(conv.out_channels, -1)
# w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
# fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
#
# # prepare spatial bias
# if conv.bias is not None:
# b_conv = conv.bias
# else:
# b_conv = torch.zeros(conv.weight.size(0))
# b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
# fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
#
# return fusedconv
#
#
# def load_classifier(name='resnet101', n=2):
# # Loads a pretrained model reshaped to n-class output
# import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
# model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
#
# # Display model properties
# for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean',
# 'model.std']:
# print(x + ' =', eval(x))
#
# # Reshape output to n classes
# filters = model.last_linear.weight.shape[1]
# model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
# model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
# model.last_linear.out_features = n
# return model
#
#
# def scale_img(img, ratio=1.0, same_shape=True): # img(16,3,256,416), r=ratio
# # scales img(bs,3,y,x) by ratio
# h, w = img.shape[2:]
# s = (int(h * ratio), int(w * ratio)) # new size
# img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
# if not same_shape: # pad/crop img
# gs = 64 # (pixels) grid size
# h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
# return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
#
#
# class ModelEMA:
# """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
# Keep a moving average of everything in the model state_dict (parameters and buffers).
# This is intended to allow functionality like
# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
# A smoothed version of the weights is necessary for some training schemes to perform well.
# E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
# RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
# smoothing of weights to match results. Pay attention to the decay constant you are using
# relative to your update count per epoch.
# To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
# disable validation of the EMA weights. Validation will have to be done manually in a separate
# process, or after the training stops converging.
# This class is sensitive where it is initialized in the sequence of model init,
# GPU assignment and distributed training wrappers.
# I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and
# single-GPU.
# """
#
# def __init__(self, model, decay=0.9999, device=''):
# # make a copy of the model for accumulating moving average of weights
# self.ema = deepcopy(model)
# self.ema.eval()
# self.updates = 0 # number of EMA updates
# # decay exponential ramp (to help early epochs)
# self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
# self.device = device # perform ema on different device from model if set
# if device:
# self.ema.to(device=device)
# for p in self.ema.parameters():
# p.requires_grad_(False)
#
# def update(self, model):
# self.updates += 1
# d = self.decay(self.updates)
# with torch.no_grad():
# if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
# msd, esd = model.module.state_dict(), self.ema.module.state_dict()
# else:
# msd, esd = model.state_dict(), self.ema.state_dict()
#
# for k, v in esd.items():
# if v.dtype.is_floating_point:
# v *= d
# v += (1. - d) * msd[k].detach()
#
# def update_attr(self, model):
# # Assign attributes (which may change during training)
# for k in model.__dict__.keys():
# if not k.startswith('_'):
# setattr(self.ema, k, getattr(model, k))
@@ -0,0 +1,376 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models.resnet import BasicBlock, conv1x1
from torchvision.ops.boxes import box_area
def remove_padding(seq):
r"""
Remove the trailing zeros from the provided input
Parameters
----------
list: List of integers
Predicted sequence
Returns
-------
list: List of integers
The part of the input before the zero padding
"""
pad_len = 0
for x in reversed(seq):
if x != 0:
break
pad_len += 1
if pad_len == 0:
return seq, 0
un_padded = seq[:-pad_len]
return un_padded, pad_len
def probabilities_to_predictions(probabilities):
r"""
Convert probabilities to predictions
Parameters
----------
probabilities : Tensor[batch_size, vocab_size, seq_len]
All log probabilities coming out at the last stage of the decoder
Returns
-------
predictions : tensor [batch_size, output_sequence_length]
The prediceted trags
"""
# max_idx: [batch_size, seq_len]
max_idx = torch.argmax(probabilities, dim=1)
return max_idx
def print_target_predict(target, predictions, filenames=None, batch_idx=0):
r"""
For the Tags, print the target and predicted tensors for the specified batch index
We expect to have the batch size as the first dimension.
Only the specified batch is extractred and the remaining dimenions are flattened.
The results are printed as 2 lists with the target on top and the predictions below underlined
Parameters
---------
target : tensor [batch_size, output_sequence_length]
The ground truth tags
predictions : tensor [batch_size, output_sequence_length]
The prediceted trags
filenames : list of string
The actual filename that provides the data
batch_idx : int
Which index in the batch dimension will be printed
"""
target_flat = target[batch_idx].flatten()
predictions_flat = predictions[batch_idx].flatten()
target_label = "target"
predict_label = "predict"
if filenames is not None:
target_label = filenames[batch_idx]
label_len = max(len(target_label), len(predict_label))
print("{}: {}".format(target_label.ljust(label_len, " "), target_flat.tolist()))
print(
"{}: {}".format(predict_label.ljust(label_len, " "), predictions_flat.tolist())
)
def load_image(full_fn):
r"""
Load an image from the disk as a numpy array
Parameters
----------
full_fn : string
The full path filename of the image
Results
-------
img : numpy array: (channels, width, height)
The loaded image as a numpy array
"""
with Image.open(full_fn) as f:
img = np.asarray(f) # (width, height, channels)
img = img.transpose(2, 0, 1) # (channels, width, height)
return img
def resnet_block(stride=1):
layers = []
downsample = nn.Sequential(
conv1x1(256, 512, stride),
nn.BatchNorm2d(512),
)
layers.append(BasicBlock(256, 512, stride, downsample))
layers.append(BasicBlock(512, 512, 1))
return nn.Sequential(*layers)
def repackage_hidden(h):
r"""
Wraps hidden states in new Tensors, to detach them from their history.
"""
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
def accuracy(scores, targets, k):
"""
Computes top-k accuracy, from predicted and true labels.
:param scores: scores from the model
:param targets: true labels
:param k: k in top-k accuracy
:return: top-k accuracy
"""
batch_size = targets.size(0)
_, ind = scores.topk(k, 1, True, True)
correct = ind.eq(targets.view(-1, 1).expand_as(ind))
correct_total = correct.view(-1).float().sum() # 0D tensor
return correct_total.item() * (100.0 / batch_size)
def clip_gradient(optimizer, grad_clip):
"""
Clips gradients computed during backpropagation to avoid explosion of gradients.
:param optimizer: optimizer with the gradients to be clipped
:param grad_clip: clip value
"""
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is not None:
param.grad.data.clamp_(-grad_clip, grad_clip)
class AverageMeter(object):
"""
Keeps track of most recent, average, sum, and count of a metric.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
@torch.no_grad()
def bip_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor:
"""Generate the attention mask for causal decoding"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
).to(device=device)
return mask
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience.
Source from: https://github.com/Bjarten/early-stopping-pytorch
"""
def __init__(self, patience=2, verbose=False, delta=0, trace_func=print):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
path (str): Path for the checkpoint to be saved to.
Default: 'checkpoint.pt'
trace_func (function): trace print function.
Default: print
"""
self._patience = patience
self._verbose = verbose
self._counter = 0
self._best_score = None
self._early_stop = False
self._val_loss_min = np.Inf
self._delta = delta
self._trace_func = trace_func
def __call__(self, val_loss):
score = -val_loss
save_checkpoint = True
if self._best_score is None:
self._best_score = score
save_checkpoint = True
if self._verbose:
verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})."
self._trace_func(verb)
self._val_loss_min = val_loss
elif score < self._best_score + self._delta:
self._counter += 1
self._trace_func(
f"EarlyStopping counter: {self._counter} out of {self._patience}"
)
if self._counter >= self._patience:
self._early_stop = True
save_checkpoint = False
else:
self._best_score = score
save_checkpoint = True
self._counter = 0
if self._verbose:
verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})."
self._trace_func(verb)
self._val_loss_min = val_loss
return save_checkpoint
def print_dict(m: dict):
r"""
Print dict elements in separate lines sorted by keys
"""
if len(m) == 0:
return
# Check if the key is a stringified integer
first_key = next(iter(m))
is_numeric = isinstance(first_key, str) and first_key.isnumeric()
if is_numeric:
keys = sorted([int(k) for k in m.keys()])
else:
keys = sorted([k for k in m.keys()])
for k in keys:
if is_numeric:
v = m[str(k)]
else:
v = m[k]
print("{}: {}".format(k, v))
def print_list(lst: list):
r"""
Print list elements in separate lines
"""
for i, elm in enumerate(lst):
if isinstance(elm, list):
print("{}: ({}) - {}".format(i, len(elm), elm))
else:
print("{}: {}".format(i, elm))
@@ -0,0 +1,175 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import numpy as np
import docling_ibm_models.tableformer.settings as s
LOG_LEVEL = logging.INFO
class MyWelford:
r"""
Running computation of the sample mean and sample variance using Welford's algorithm
"""
def __init__(self):
self._i = 0 # Running index
self._m = 0 # Running mean
self._s = 0 # (n - 1) * variance
def reset(self):
r"""
Reset the object
"""
self._i = 0
self._m = 0
self._s = 0
def add(self, xi):
r"""
Invoke add each time a new sample arrives
Inputs:
xi: The next sample of data
"""
self._i += 1
old_m = self._m
self._m = self._m + (xi - self._m) / self._i
self._s = self._s + (xi - self._m) * (xi - old_m)
def results(self):
r"""
Get the computed mean, variance and standard deviation up to now
Outputs:
m: Sample mean
v: Sample variance
std: Sample standard deviation
"""
if self._i <= 1:
return None, None, None
# v = self._s / (self._i - 1) # Sample variance
v = self._s / (self._i) # Population variance
std = np.sqrt(v)
return self._m, v, std
class MyWelfordImg(MyWelford):
r"""
Welford algorithm to calculate running mean and sample variance for images
"""
def __init__(self):
super(MyWelfordImg, self).__init__()
def add(self, img):
r"""
Input:
img: An image numpy array (channel, width, height). The only requirement is to have the
channels as the first dimension and have 3 dimensions in total
"""
channels = img.shape[0]
flat_dim = img.shape[1] * img.shape[2]
img_r = img.reshape(channels, flat_dim)
for i in range(flat_dim):
super(MyWelfordImg, self).add(img_r[:, i])
class ChanVarianceImg:
r"""
Chan's algorithm to compute a running variance with support of sub-samples
In this implementation each sub-sample is an images
Math for the original paper:
https://github.ibm.com/nli/variance_formulae
"""
def __init__(self):
r""" """
self._first = True
# Size of the calculated dataset
self._n = 0
# Sum of the samples for the 3 image channels
self._t = 0
# Sum of the square differences of the deviations of the samples from the mean
self._s = 0
def add(self, img):
r"""
Add the provided image to the computation of the dataset statistics
Input:
img: An image numpy array (channel, width, height). The only requirement is to have the
channels as the first dimension and have 3 dimensions in total
"""
ch = img.shape[0]
n = img.shape[1] * img.shape[2]
img = img.reshape(ch, n)
img_t = img.sum(axis=1)
img_t_v = img_t.reshape(ch, 1)
diff = (img - (img_t_v / n)) ** 2
img_s = diff.sum(axis=1)
if not self._first:
c = (self._n / (n * (self._n + n))) * (
((n / self._n) * self._t - img_t) ** 2
)
self._s += img_s + c
self._t += img_t
else:
self._s = img_s
self._t = img_t
self._first = False
self._n += n
def results(self):
r"""
Get the computed statistics
Output:
mean: Mean for the complete dataset
var: Population variance for the complete dataset
std: Population standard deviation for the complete dataset
"""
mean = list(self._t / self._n)
var = list(self._s / self._n) # Population variance
std = list(np.sqrt(var))
return mean, var, std
def reset(self):
r"""
Reset the object to start over again
"""
self._n = 0
self._t = 0
self._s = 0
self._first = True
if __name__ == "__main__":
logger = s.get_custom_logger("variance", LOG_LEVEL)
n = 50000
channels = 3
width = 448
height = 448
my = ChanVarianceImg()
# Generate random images
for i in range(n):
logger.info(i)
img = 255 * np.random.rand(channels, width, height)
my.add(img)
# Calculate the statistics
m, v, std = my.results()
assert m.shape == (3,), "Wrong mean dimension"
assert v.shape == (3,), "Wrong variance dimension"
assert std.shape == (3,), "Wrong std dimension"
Binary file not shown.

After

Width:  |  Height:  |  Size: 237 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 200 KiB

Generated
+2787
View File
File diff suppressed because it is too large Load Diff
+66
View File
@@ -0,0 +1,66 @@
[tool.poetry]
name = "docling-ibm-models"
version = "0.2.0"
description = "This package contains the AI models used by the Docling PDF conversion package"
authors = ["Nikos Livathinos <nli@zurich.ibm.com>", "Maxim Lysak <mly@zurich.ibm.com>", "Ahmed Nassar <ahn@zurich.ibm.com>", "Christoph Auer <cau@zurich.ibm.com>", "Michele Dolfi <dol@zurich.ibm.com>", "Peter Staar <taa@zurich.ibm.com>"]
license = "MIT"
readme = "README.md"
keywords= ["docling", "convert", "document", "pdf", "layout model", "segmentation", "table structure", "table former"]
classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX :: Linux",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3"
]
packages = [
{ include = "docling_ibm_models" },
]
[tool.poetry.dependencies]
python = "^3.11"
torch = "2.2.2"
torchvision = "0.17.2"
onnxruntime = "^1.16.2"
numpy = "^1.24.4"
lxml = "^4.9.1"
jsonlines = "^3.1.0"
Pillow = "^10.0.0"
tqdm = "^4.64.0"
apted = "^1.0.3"
Distance = "^0.1.3"
mean_average_precision = "^2021.4.26.0"
opencv-python-headless = { version = "^4.9.0.80", markers = 'sys_platform=="linux"'}
opencv-python = { version = "^4.9.0.80", markers = 'sys_platform!="linux"'}
[tool.poetry.dev-dependencies]
black = {extras = ["jupyter"], version = "^24.4.2"}
pytest = "^7.2.2"
pre-commit = "^3.7.1"
mypy = "^1.10.1"
isort = "^5.10.1"
python-semantic-release = "^7.32.2"
flake8 = "^6.0.0"
pyproject-flake8 = "^6.0.0"
pytest-xdist = "^3.3.1"
pytest-flake8 = "^1.1.0"
types-requests = "^2.31.0.2"
flake8-pyproject = "^1.2.3"
pylint = "^2.17.5"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = ["py311"]
include = '\.pyi?$'
[tool.isort]
profile = "black"
line_length = 88
py_version=311
View File
+72
View File
@@ -0,0 +1,72 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import docling_ibm_models.tableformer.common as c
test_config_a = {
"base_dir": "./tests/test_data/",
"curr_dir": "./tests/test_data/test_common/",
"data_top_dir": "./tests/test_data/",
"dataset": {
"name": ["PhysRevB"],
"limit": 10,
"split": {"test": 0.2, "train": 0.5, "evaluate": 0.3},
},
"features": {
"name": "Data2Features03b",
"parameters": {
"normalize_features": True,
"normalize_features_method": "Z-Score",
},
},
}
test_config_b = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 30}}
test_config_c = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 302}}
test_config_d = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 303}}
def test_safe_get_parameters():
val = c.safe_get_parameter(None, None, 10)
assert val == 10, "Failed with null objects"
index_path = ["features", "parameters", "normalize_features_method"]
val = c.safe_get_parameter(test_config_a, index_path, None)
assert val == "Z-Score", "Cannot find existing parameter"
index_path = ["features", "parameters", "wrong"]
val = c.safe_get_parameter(test_config_a, index_path, "hello")
assert val == "hello", "Default value should be here"
index_path = ["features", "wrong", "normalize_features_method"]
val = c.safe_get_parameter(test_config_a, index_path, 10)
assert val == 10, "Default value should be here"
index_path = ["model", "parameters", "normalize_features_method"]
val = c.safe_get_parameter(test_config_a, index_path, "hello")
assert val == "hello", "Default value should be here"
# Test exception throwing
exRaised = False
try:
index_path = ["missing"]
val = c.safe_get_parameter(test_config_a, index_path, required=True)
except ValueError:
exRaised = True
assert exRaised, "Exception should had been raised here"
def test_config_validation():
configs = [test_config_b, test_config_c, test_config_d]
for i, config in enumerate(configs):
try:
val = c.validate_config(config)
if i == 0 or i == 1:
assert val, "Valid configuration didn't pass the validation test"
except AssertionError:
assert i == 2, "Configuration validation error"
@@ -0,0 +1 @@
Put model check: "otslp_all_fast_clean.check" in this directory
Binary file not shown.

After

Width:  |  Height:  |  Size: 389 KiB

File diff suppressed because one or more lines are too long
Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

File diff suppressed because one or more lines are too long
+88
View File
@@ -0,0 +1,88 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import os
import numpy as np
import pytest
from PIL import Image
import docling_ibm_models.layoutmodel.layout_predictor as lp
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
@pytest.fixture(scope="module")
def init() -> dict:
r"""
Initialize the testing environment
"""
init = {
"artifact_path": "tests/test_data/model_artifacts/",
"num_threads": 1,
"test_imgs": [
"tests/test_data/samples/ADS.2007.page_123.png",
],
"info1": {
"onnx_file": os.path.join(
"tests/test_data/model_artifacts/", lp.MODEL_CHECKPOINT_FN
),
"intra_op_num_threads": 2,
"providers": ["CPUExecutionProvider"],
"use_cpu_only": True,
"image_size": 640,
"threshold": 0.6,
},
"info2": {
"onnx_file": os.path.join(
"tests/test_data/model_artifacts/", lp.MODEL_CHECKPOINT_FN
),
"intra_op_num_threads": 1,
"providers": ["CPUExecutionProvider"],
"use_cpu_only": True,
"image_size": 640,
"threshold": 0.6,
},
"pred_bboxes": 9,
}
return init
def test_layoutpredictor(init: dict):
r"""
Unit test for the LayoutPredictor
"""
# Initialize LayoutPredictor with envvars
os.environ["USE_CPU_ONLY"] = ""
os.environ["OMP_NUM_THREADS"] = "2"
lpredictor = LayoutPredictor(init["artifact_path"])
assert init["info1"] == lpredictor.info()
# Initialize LayoutPredictor with optional parameters
lpredictor = LayoutPredictor(
init["artifact_path"], num_threads=init["num_threads"], use_cpu_only=True
)
assert init["info2"] == lpredictor.info()
# Unsupported input image
is_exception = False
try:
for pred in lpredictor.predict("wrong"):
pass
except TypeError:
is_exception = True
assert is_exception
# Predict on the test image
for img_fn in init["test_imgs"]:
with Image.open(img_fn) as img:
# Load images as PIL objects
for i, pred in enumerate(lpredictor.predict(img)):
print("PIL pred: {}".format(pred))
assert i + 1 == init["pred_bboxes"]
# Load images as numpy arrays
np_arr = np.asarray(img)
for i, pred in enumerate(lpredictor.predict(np_arr)):
print("numpy pred: {}".format(pred))
assert i + 1 == init["pred_bboxes"]
+627
View File
@@ -0,0 +1,627 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import glob
import json
import os
from pathlib import Path
import cv2
from PIL import Image, ImageDraw
import docling_ibm_models.tableformer.data_management.tf_predictor as tf_predictor
from docling_ibm_models.tableformer.data_management.tf_predictor import \
TFPredictor
"""
- Implements TF predictor to accept the input format from IOCR, e.g.
"./tests/test_data/samples/tf_table_example_0.json" (trivial table crop)
- Shape the output format like GTE would: e.g.
"tests/test_data/samples/tf_gte_output_2.json" (Note: full form image)
"""
docling_api_data = {
"table_jsons": [
"./tests/test_data/samples/ADS.2007.page_123.png_iocr.parse_format.json",
"./tests/test_data/samples/PHM.2013.page_30.png_iocr.parse_format.json",
],
"png_images": [
"./tests/test_data/samples/ADS.2007.page_123.png",
"./tests/test_data/samples/PHM.2013.page_30.png",
],
"table_bboxes": [
[[178, 748, 1061, 976], [177, 1163, 1062, 1329]],
[[100, 186, 1135, 525]],
],
}
test_config = {
"dataset": {
"type": "TF_prepared",
"name": "TF",
"raw_data_dir": "./tests/test_data/model_artifacts/",
"load_cells": True,
"bbox_format": "5plet",
"resized_image": 448,
"keep_AR": False,
"up_scaling_enabled": True,
"down_scaling_enabled": True,
"padding_mode": "null",
"padding_color": [0, 0, 0],
"image_normalization": {
"state": True,
"mean": [0.94247851, 0.94254675, 0.94292611],
"std": [0.17910956, 0.17940403, 0.17931663],
},
"color_jitter": True,
"rand_crop": True,
"rand_pad": True,
"image_grayscale": False,
},
"model": {
"type": "TableModel04_rs",
"name": "14_128_256_4_true",
"save_dir": "./tests/test_data/model_artifacts/",
"backbone": "resnet18",
"enc_image_size": 28,
"tag_embed_dim": 16,
"hidden_dim": 512,
"tag_decoder_dim": 512,
"bbox_embed_dim": 256,
"tag_attention_dim": 256,
"bbox_attention_dim": 512,
"enc_layers": 4, # 6
"dec_layers": 2, # 6
"nheads": 8,
"dropout": 0.1,
"bbox_classes": 2,
},
"train": {
"save_periodicity": 1,
"disable_cuda": False,
"epochs": 23,
"batch_size": 50,
"clip_gradient": 0.1,
"clip_max_norm": 0.1,
"bbox": True,
"validation": False,
},
"predict": {
"max_steps": 1024,
"beam_size": 5,
"bbox": True,
"predict_dir": "./tests/test_data/samples",
"pdf_cell_iou_thres": 0.05,
"padding": False,
"padding_size": 50,
"disable_post_process": False,
"profiling": False,
"device_mode": "auto",
},
"dataset_wordmap": {
"word_map_tag": {
"<pad>": 0,
"<unk>": 1,
"<start>": 2,
"<end>": 3,
"ecel": 4,
"fcel": 5,
"lcel": 6,
"ucel": 7,
"xcel": 8,
"nl": 9,
"ched": 10,
"rhed": 11,
"srow": 12,
},
"word_map_cell": {
" ": 13,
"!": 179,
'"': 126,
"#": 101,
"$": 119,
"%": 18,
"&": 114,
"'": 108,
"(": 29,
")": 32,
"*": 26,
"+": 97,
",": 71,
"-": 63,
".": 34,
"/": 66,
"0": 33,
"1": 36,
"2": 43,
"3": 41,
"4": 45,
"5": 17,
"6": 37,
"7": 35,
"8": 40,
"9": 16,
":": 88,
";": 92,
"<": 73,
"</b>": 9,
"</i>": 23,
"</overline>": 219,
"</strike>": 233,
"</sub>": 94,
"</sup>": 77,
"</underline>": 151,
"<b>": 1,
"<end>": 280,
"<i>": 21,
"<overline>": 218,
"<pad>": 0,
"<start>": 279,
"<strike>": 232,
"<sub>": 93,
"<sup>": 75,
"<underline>": 150,
"<unk>": 278,
"=": 99,
">": 39,
"?": 96,
"@": 125,
"A": 27,
"B": 86,
"C": 19,
"D": 57,
"E": 64,
"F": 47,
"G": 44,
"H": 10,
"I": 20,
"J": 80,
"K": 81,
"L": 52,
"M": 46,
"N": 69,
"O": 65,
"P": 62,
"Q": 59,
"R": 60,
"S": 58,
"T": 48,
"U": 55,
"V": 2,
"W": 83,
"X": 104,
"Y": 89,
"Z": 113,
"[": 70,
"\\": 165,
"]": 72,
"^": 132,
"_": 84,
"`": 196,
"a": 3,
"b": 6,
"c": 54,
"d": 12,
"e": 8,
"f": 50,
"g": 28,
"h": 56,
"i": 5,
"j": 82,
"k": 95,
"l": 7,
"m": 30,
"n": 31,
"o": 15,
"p": 22,
"q": 67,
"r": 4,
"s": 51,
"t": 14,
"u": 25,
"v": 24,
"w": 53,
"x": 61,
"y": 49,
"z": 11,
"{": 158,
"|": 139,
"}": 159,
"~": 147,
"\u00a2": 203,
"\u00a3": 162,
"\u00a4": 220,
"\u00a5": 176,
"\u00a7": 142,
"\u00a9": 268,
"\u00ab": 239,
"\u00ad": 275,
"\u00ae": 130,
"\u00b0": 100,
"\u00b1": 79,
"\u00b6": 171,
"\u00b7": 137,
"\u00bb": 240,
"\u00d7": 118,
"\u00d8": 192,
"\u00df": 197,
"\u00e6": 261,
"\u00f7": 225,
"\u00f8": 163,
"\u0131": 242,
"\u0142": 267,
"\u01c2": 211,
"\u025b": 223,
"\u02b9": 248,
"\u02c2": 195,
"\u02c3": 208,
"\u02c6": 253,
"\u0300": 209,
"\u0301": 131,
"\u0302": 138,
"\u0303": 156,
"\u0304": 152,
"\u0306": 222,
"\u0307": 247,
"\u0308": 103,
"\u030a": 102,
"\u030c": 254,
"\u0327": 155,
"\u0328": 269,
"\u0338": 170,
"\u0391": 173,
"\u0392": 169,
"\u0393": 180,
"\u0394": 85,
"\u0398": 243,
"\u0399": 271,
"\u039b": 272,
"\u03a0": 213,
"\u03a3": 185,
"\u03a6": 148,
"\u03a7": 212,
"\u03a8": 141,
"\u03a9": 161,
"\u03b1": 90,
"\u03b2": 107,
"\u03b3": 110,
"\u03b4": 153,
"\u03b5": 166,
"\u03b6": 178,
"\u03b7": 146,
"\u03b8": 186,
"\u03b9": 229,
"\u03ba": 164,
"\u03bb": 91,
"\u03bc": 78,
"\u03bd": 230,
"\u03be": 244,
"\u03c0": 127,
"\u03c1": 149,
"\u03c3": 116,
"\u03c4": 198,
"\u03c5": 189,
"\u03c6": 140,
"\u03c7": 124,
"\u03c8": 216,
"\u03c9": 167,
"\u0410": 273,
"\u0421": 194,
"\u115f": 217,
"\u200b": 265,
"\u2010": 117,
"\u2012": 135,
"\u2013": 42,
"\u2014": 106,
"\u2015": 228,
"\u2016": 259,
"\u2018": 123,
"\u2019": 121,
"\u201c": 87,
"\u201d": 115,
"\u201e": 245,
"\u2020": 109,
"\u2021": 129,
"\u2022": 128,
"\u2028": 190,
"\u2030": 154,
"\u2032": 68,
"\u203b": 224,
"\u2044": 188,
"\u204e": 199,
"\u2061": 200,
"\u20ac": 184,
"\u2190": 202,
"\u2191": 112,
"\u2192": 120,
"\u2193": 111,
"\u2194": 183,
"\u21d1": 266,
"\u21d2": 264,
"\u21d3": 255,
"\u2205": 215,
"\u2206": 175,
"\u2208": 262,
"\u2211": 160,
"\u2212": 76,
"\u2216": 206,
"\u2217": 105,
"\u2218": 246,
"\u2219": 236,
"\u221a": 187,
"\u221e": 207,
"\u2223": 260,
"\u2225": 193,
"\u2227": 182,
"\u2229": 256,
"\u222b": 258,
"\u223c": 98,
"\u2248": 210,
"\u2264": 38,
"\u2265": 74,
"\u2266": 214,
"\u2267": 181,
"\u2295": 263,
"\u22c5": 174,
"\u22c6": 191,
"\u22ee": 277,
"\u22ef": 270,
"\u2500": 205,
"\u2551": 231,
"\u25a0": 250,
"\u25a1": 177,
"\u25aa": 145,
"\u25b2": 136,
"\u25b3": 143,
"\u25bc": 251,
"\u25c6": 226,
"\u25ca": 235,
"\u25cb": 227,
"\u25cf": 172,
"\u25e6": 274,
"\u2605": 204,
"\u2606": 144,
"\u2640": 133,
"\u2642": 134,
"\u2663": 252,
"\u2666": 157,
"\u266f": 221,
"\u2713": 122,
"\u2714": 249,
"\u2717": 201,
"\u2794": 168,
"\u27a2": 276,
"\u2a7d": 234,
"\u2a7e": 241,
"\u3008": 237,
"\u3009": 238,
"\ufeff": 257,
},
},
}
# ==================================================================================================
configs = [test_config]
def combine_checkpoint(save_dir):
r"""
Check if the checkpoint file is present as one part or 2 splits.
Combine parts into one file if needed
Parameters
----------
save_dir : string
The directory to check for checkpoint files or splits of it
Returns
-------
int
0: The full checkpoint file already exists, no combine was needed
1: The splits were found, a combine has been done
-1: No full checkpoint and no splits exist. Error
"""
# Check if the full file already exists
full_file_pattern = os.path.join(save_dir, "*.check")
candidate = glob.glob(full_file_pattern)
if len(candidate) == 1:
print(
"combine_checkpoint: The whole checkpoint file was found: {}".format(
candidate[0]
)
)
return 0
# Check for splits
splits_pattern = os.path.join(save_dir, "*.check.a[a-z]")
splits = glob.glob(splits_pattern)
splits.sort()
if splits is None or len(splits) == 0:
print(
"combine_checkpoint: Both the full checkpoint and the splits are missing. Error"
)
return -1
# Combine splits
full_fn = splits[0].rpartition(".check")[0] + ".check"
with open(full_fn, "wb") as f_out:
for split_fn in splits:
with open(split_fn, "rb") as f_split:
print("combine_checkpoint: read split: {}".format(split_fn))
f_out.write(f_split.read())
print("combine_checkpoint: combine splits as: {}".format(full_fn))
return 1
def test_tf_predictor():
r"""
Test the TFPredictor
"""
viz = True
# Load the docling_api_data
iocr_pages = []
for table_json_fn, png_image_fn, table_bboxes_b in zip(
docling_api_data["table_jsons"],
docling_api_data["png_images"],
docling_api_data["table_bboxes"],
):
with open(table_json_fn, "r") as fp:
iocr_page_raw = json.load(fp)
iocr_page = iocr_page_raw["pages"][0]
iocr_page["image"] = cv2.imread(png_image_fn)
# page_image = cv2.imread(png_image_fn)
iocr_page["png_image_fn"] = png_image_fn
iocr_page["table_bboxes"] = table_bboxes_b
iocr_pages.append(iocr_page)
# Loop over the test configs
for test_config in configs:
# Check if the checkpoint file should be combined
assert (
combine_checkpoint(test_config["model"]["save_dir"]) >= 0
), "Model checkpoint is missing"
# Loop over the iocr_pages
predictor = TFPredictor(test_config)
for iocr_page in iocr_pages:
# Prepare "Predict" parameters
# iw = iocr_page["width"]
# ih = iocr_page["height"]
# table_bboxes = [[0, 0, iw, ih]] # just one table per page in our examples
table_bboxes = iocr_page["table_bboxes"]
# for t, table_bbox in enumerate(table_bboxes):
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
png_img_bfn = os.path.basename(iocr_page["png_image_fn"])
print("Predicting image: {}".format(png_img_bfn))
# Run prediction, post-processing, and cell matching
# PARAMETERS:
# iocr_page - json received from iocr, augmented with iocr_page["image"]
# table_bboxes - list of detected bboxes on page: [[x1, y1, x2, y2], [...]...]
# do_matching - Boolean, when True - will match with text cells provided,
# when False - returns original cell prediction BBOXes in the same format
# OUTPUT:
# List of dicts per table: [{"tf_responses":[...], "predict_details": {}}]
multi_tf_output = predictor.multi_table_predict(
iocr_page, table_bboxes, False
)
# Test output for validity, create visualizations...
for t, tf_output in enumerate(multi_tf_output):
tf_responses = tf_output["tf_responses"]
predict_details = tf_output["predict_details"]
assert tf_responses is not None, "Empty prediction response"
assert isinstance(
tf_responses, list
), " Wrong response type. It should be a list"
img = Image.open(iocr_page["png_image_fn"])
img1 = ImageDraw.Draw(img)
xt0 = table_bboxes[t][0]
yt0 = table_bboxes[t][1]
xt1 = table_bboxes[t][2]
yt1 = table_bboxes[t][3]
img1.rectangle(((xt0, yt0), (xt1, yt1)), outline="pink", width=5)
if viz:
# Visualize original OCR words:
for iocr_word in iocr_page["tokens"]:
xi0 = iocr_word["bbox"]["l"]
yi0 = iocr_word["bbox"]["t"]
xi1 = iocr_word["bbox"]["r"]
yi1 = iocr_word["bbox"]["b"]
img1.rectangle(((xi0, yi0), (xi1, yi1)), outline="gray")
# Visualize original docling_ibm_models.tableformer predictions:
for predicted_bbox in predict_details["prediction_bboxes_page"]:
xp0 = predicted_bbox[0] - 2
yp0 = predicted_bbox[1] - 2
xp1 = predicted_bbox[2] + 2
yp1 = predicted_bbox[3] + 2
img1.rectangle(((xp0, yp0), (xp1, yp1)), outline="green")
# Check the structure of the list items
for i, response in enumerate(tf_responses):
assert (
"bbox" in response
), "bbox field is missing from response: " + str(i)
assert (
"text_cell_bboxes" in response
), "text_cell_bboxes is missing: " + str(i)
assert (
"row_span" in response
), "row_span is missing from resp: " + str(i)
assert (
"col_span" in response
), "col_span is missing from response: " + str(i)
# print("*********** column_header: {}".format(response["column_header"]))
if viz:
# Visualization:
for text_cell in response["text_cell_bboxes"]:
xc0 = text_cell["l"]
yc0 = text_cell["b"]
xc1 = text_cell["r"]
yc1 = text_cell["t"]
img1.rectangle(((xc0, yc0), (xc1, yc1)), outline="red")
x0 = response["bbox"]["l"] - 6
y0 = response["bbox"]["t"] - 6
x1 = response["bbox"]["r"] + 6
y1 = response["bbox"]["b"] + 6
if response["column_header"]:
img1.rectangle(
((x0, y0), (x1, y1)), outline="blue", width=5
)
elif response["row_header"]:
img1.rectangle(
((x0, y0), (x1, y1)), outline="magenta", width=5
)
elif response["row_section"]:
img1.rectangle(
((x0, y0), (x1, y1)), outline="brown", width=5
)
else:
img1.rectangle(
((x0, y0), (x1, y1)), outline="blue", width=1
)
if viz:
viz_root = "./tests/test_data/viz/"
Path(viz_root).mkdir(parents=True, exist_ok=True)
png_img_bfn1 = png_img_bfn.replace(".png", "." + str(t) + ".png")
viz_fn = os.path.join(viz_root, png_img_bfn1)
img.save(viz_fn)
# assert False
def test_device_mode():
r"""
Test the "predict.device_mode" parameter
"""
mini_configs = [
{"predict": {}},
{"predict": {"device_mode": "cpu"}},
{"predict": {"device_mode": "cuda"}},
{"predict": {"device_mode": "gpu"}},
{"predict": {"device_mode": "wrong"}},
]
for i, config in enumerate(mini_configs):
device = tf_predictor.decide_device(config)
assert device in ["cpu", "cuda:0"], "Irrelevant device has been returned"
if i == 0:
assert device == "cpu", "By default the 'cpu' device should be used"
elif i == 1:
assert device == "cpu", "An explicit 'cpu' device was given"
elif i == 2 or i == 3:
assert device == "cuda:0", "Cuda or gpu should become 'cuda:0'"
else:
assert (
device == "cpu"
), "A fall-back to 'cpu' should happen in case of error"