mirror of
https://github.com/docling-project/docling-ibm-models.git
synced 2026-05-17 13:10:52 +00:00
feat: TableFormer v2 (#149)
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Co-authored-by: Ahmed Nassar AHN@zurich.ibm.com <AHN@zurich.ibm.com> Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
committed by
GitHub
parent
78deb4be82
commit
b99c955aaf
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import threading
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
from transformers import AutoImageProcessor
|
||||
|
||||
-4
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import threading
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import numbers
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
from __future__ import division
|
||||
|
||||
import collections
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import copy
|
||||
import logging
|
||||
from itertools import groupby
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import time
|
||||
from collections import deque
|
||||
from statistics import mean, median
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from docling_ibm_models.tableformer_v2.model import (
|
||||
TableFormerV2,
|
||||
TableFormerV2Config,
|
||||
TableFormerV2Output,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TableFormerV2",
|
||||
"TableFormerV2Config",
|
||||
"TableFormerV2Output",
|
||||
]
|
||||
@@ -0,0 +1,929 @@
|
||||
r"""
|
||||
TableFormerV2:
|
||||
This module provides a self-contained implementation of the TableFormerV2
|
||||
model for table structure recognition. It uses a lightweight architecture optimized for
|
||||
CPU inference / GPU batch inference while maintaining high accuracy.
|
||||
|
||||
Architecture overview:
|
||||
- Image encoder: EfficientNetV2-S backbone with Squeeze-and-Excitation
|
||||
- Spatial mixer: Depthwise separable convolutions (no self-attention in encoder)
|
||||
- Decoder: Cache-aware Transformer decoder with cross-attention to image features
|
||||
- Bbox head: Multi-layer attention decoder for cell bounding box prediction
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models import efficientnet_v2_s
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
|
||||
|
||||
# =============================================================================
|
||||
# Custom Output Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableFormerV2Output(ModelOutput):
|
||||
r"""
|
||||
Output class for TableFormerV2 inference.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
logits : torch.Tensor, optional
|
||||
Token prediction logits of shape (B, L, vocab_size)
|
||||
hidden_states : torch.Tensor, optional
|
||||
Decoder hidden states of shape (B, L, D)
|
||||
predicted_bboxes : torch.Tensor, optional
|
||||
Predicted bounding boxes in xyxy format [0,1], shape (N, 4)
|
||||
past_key_values : tuple, optional
|
||||
Cached key-value pairs for autoregressive generation
|
||||
"""
|
||||
|
||||
logits: Optional[torch.Tensor] = None
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
predicted_bboxes: Optional[torch.Tensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Building Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
r"""
|
||||
Squeeze-and-Excitation block for channel-wise attention.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, reduction: int = 16):
|
||||
super().__init__()
|
||||
self.se = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * self.se(x)
|
||||
|
||||
|
||||
class DepthwiseSeparableBlock(nn.Module):
|
||||
r"""
|
||||
Depthwise separable convolution block for lightweight spatial mixing.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, expansion: float = 1.0):
|
||||
super().__init__()
|
||||
hidden = int(channels * expansion)
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(hidden),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden),
|
||||
nn.Conv2d(hidden, channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feature Pyramid Network (FPN)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SimpleFPN(nn.Module):
|
||||
r"""
|
||||
Simple Feature Pyramid Network to fuse multi-scale features from EfficientNet.
|
||||
|
||||
EfficientNetV2-S feature map sizes at 448x448 input:
|
||||
- Stage 2: 112x112, 48 channels (fine details: grid lines, small text)
|
||||
- Stage 4: 28x28, 128 channels (medium structures)
|
||||
- Final: 14x14, 1280 channels (after conv head)
|
||||
|
||||
This FPN fuses stages 2, 4, and final to capture multi-scale information.
|
||||
"""
|
||||
|
||||
def __init__(self, out_channels: int = 1280):
|
||||
super().__init__()
|
||||
# EfficientNetV2-S actual channel sizes: idx2=48, idx4=128, final=1280
|
||||
self.in_channels = [48, 128, 1280] # stages 2, 4, final
|
||||
|
||||
# Lateral connections (1x1 conv to match channels)
|
||||
self.lateral_conv_s2 = nn.Conv2d(48, out_channels, kernel_size=1)
|
||||
self.lateral_conv_s4 = nn.Conv2d(128, out_channels, kernel_size=1)
|
||||
# Final already has out_channels
|
||||
|
||||
# Smooth convolutions after upsampling
|
||||
self.smooth_s2 = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
)
|
||||
self.smooth_s4 = nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
# Final fusion
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(out_channels * 3, out_channels, kernel_size=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, feat_s2: torch.Tensor, feat_s4: torch.Tensor, feat_final: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fuse multi-scale features.
|
||||
|
||||
Args:
|
||||
feat_s2: (B, 48, H/4, W/4) - stage 2 features
|
||||
feat_s4: (B, 160, H/16, W/16) - stage 4 features
|
||||
feat_final: (B, 1280, H/32, W/32) - final features
|
||||
|
||||
Returns:
|
||||
fused: (B, 1280, H/32, W/32) - fused features at final resolution
|
||||
"""
|
||||
target_size = feat_final.shape[2:] # (H/32, W/32)
|
||||
|
||||
# Lateral connections
|
||||
p2 = self.lateral_conv_s2(feat_s2)
|
||||
p4 = self.lateral_conv_s4(feat_s4)
|
||||
p_final = feat_final
|
||||
|
||||
# Resize all to final resolution
|
||||
p2 = nn.functional.interpolate(
|
||||
p2, size=target_size, mode="bilinear", align_corners=False
|
||||
)
|
||||
p4 = nn.functional.interpolate(
|
||||
p4, size=target_size, mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
# Smooth
|
||||
p2 = self.smooth_s2(p2)
|
||||
p4 = self.smooth_s4(p4)
|
||||
|
||||
# Concatenate and fuse
|
||||
fused = torch.cat([p2, p4, p_final], dim=1)
|
||||
fused = self.fusion(fused)
|
||||
|
||||
return fused
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Transformer Decoder Components
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CachedTransformerDecoderLayer(nn.Module):
|
||||
r"""
|
||||
Cache-aware Transformer decoder layer for efficient autoregressive generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_model : int
|
||||
Model dimension
|
||||
nhead : int
|
||||
Number of attention heads
|
||||
dim_feedforward : int
|
||||
Feedforward network hidden dimension
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, nhead: int, dim_feedforward: int):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
self.head_dim = d_model // nhead
|
||||
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(0.0)
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def _sa_block(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
key_padding_mask: Optional[torch.Tensor],
|
||||
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
k = v = x
|
||||
if past_kv is not None:
|
||||
past_k, past_v = past_kv
|
||||
k = torch.cat([past_k, k], dim=1)
|
||||
v = torch.cat([past_v, v], dim=1)
|
||||
|
||||
attn_output, _ = self.self_attn(
|
||||
x,
|
||||
k,
|
||||
v,
|
||||
attn_mask=attn_mask,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
|
||||
present = (k, v) if use_cache else None
|
||||
return attn_output, present
|
||||
|
||||
def _ca_block(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mem: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
key_padding_mask: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
attn_output, _ = self.multihead_attn(
|
||||
x,
|
||||
mem,
|
||||
mem,
|
||||
attn_mask=attn_mask,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
tgt_mask: Optional[torch.Tensor] = None,
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
sa_out, present_kv = self._sa_block(
|
||||
tgt,
|
||||
tgt_mask,
|
||||
tgt_key_padding_mask,
|
||||
past_kv=past_key_value,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
tgt = self.norm1(tgt + self.dropout(sa_out))
|
||||
ca_out = self._ca_block(tgt, memory, memory_mask, memory_key_padding_mask)
|
||||
tgt = self.norm2(tgt + self.dropout(ca_out))
|
||||
ff = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = self.norm3(tgt + self.dropout(ff))
|
||||
return tgt, present_kv
|
||||
|
||||
|
||||
class CachedTransformerDecoder(nn.Module):
|
||||
r"""
|
||||
Stack of cache-aware Transformer decoder layers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_model : int
|
||||
Model dimension
|
||||
nhead : int
|
||||
Number of attention heads
|
||||
dim_feedforward : int
|
||||
Feedforward network hidden dimension
|
||||
num_layers : int
|
||||
Number of decoder layers
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, nhead: int, dim_feedforward: int, num_layers: int):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
CachedTransformerDecoderLayer(d_model, nhead, dim_feedforward)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
tgt_mask: Optional[torch.Tensor] = None,
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]]:
|
||||
next_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
|
||||
[] if use_cache else None
|
||||
)
|
||||
for i, layer in enumerate(self.layers):
|
||||
past = past_key_values[i] if past_key_values is not None else None
|
||||
tgt, present = layer(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
past_key_value=past,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
if use_cache and present is not None and next_past is not None:
|
||||
next_past.append(present)
|
||||
return tgt, (tuple(next_past) if use_cache and next_past is not None else None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bounding Box Decoder Components
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
|
||||
cx, cy, w, h = cxcywh.unbind(-1)
|
||||
x1 = (cx - 0.5 * w).clamp(0.0, 1.0)
|
||||
y1 = (cy - 0.5 * h).clamp(0.0, 1.0)
|
||||
x2 = (cx + 0.5 * w).clamp(0.0, 1.0)
|
||||
y2 = (cy + 0.5 * h).clamp(0.0, 1.0)
|
||||
return torch.stack([x1, y1, x2, y2], dim=-1)
|
||||
|
||||
|
||||
class BboxDecoderLayer(nn.Module):
|
||||
r"""
|
||||
Single decoder layer for bounding box prediction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embed_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
embed_dim, num_heads, dropout=dropout, batch_first=True
|
||||
)
|
||||
self.self_attn_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.cross_attn = nn.MultiheadAttention(
|
||||
embed_dim, num_heads, dropout=dropout, batch_first=True
|
||||
)
|
||||
self.cross_attn_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(embed_dim, ff_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ff_dim, embed_dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
self.ffn_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
batch_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
x_unsq = x.unsqueeze(0)
|
||||
|
||||
if batch_mask is not None:
|
||||
attn_mask = ~batch_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
sa_out, _ = self.self_attn(x_unsq, x_unsq, x_unsq, attn_mask=attn_mask)
|
||||
x = self.self_attn_norm(x + sa_out.squeeze(0))
|
||||
|
||||
x_unsq = x.unsqueeze(1)
|
||||
ca_out, _ = self.cross_attn(x_unsq, memory, memory)
|
||||
x = self.cross_attn_norm(x + ca_out.squeeze(1))
|
||||
|
||||
x = self.ffn_norm(x + self.ffn(x))
|
||||
return x
|
||||
|
||||
|
||||
class BboxHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
num_layers: int = 2,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
ff_dim = embed_dim * 4
|
||||
|
||||
self.input_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.input_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.kv_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.kv_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
BboxDecoderLayer(embed_dim, num_heads, ff_dim, dropout)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.bbox_mlp = nn.Sequential(
|
||||
nn.Linear(embed_dim, embed_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(embed_dim, embed_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(embed_dim // 2, 4),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cell_embeddings: torch.Tensor,
|
||||
encoder_hidden: torch.Tensor,
|
||||
cell_batch_indices: torch.Tensor,
|
||||
spatial_size: Optional[Tuple[int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Predict bounding boxes for cell embeddings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cell_embeddings : torch.Tensor
|
||||
Decoder hidden states at cell positions, shape (N, D)
|
||||
encoder_hidden : torch.Tensor
|
||||
Encoder outputs (image features), shape (B, S, D)
|
||||
cell_batch_indices : torch.Tensor
|
||||
Batch index for each cell, shape (N,)
|
||||
spatial_size : tuple, optional
|
||||
Unused, kept for API compatibility
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Predicted bboxes in xyxy format [0, 1], shape (N, 4)
|
||||
"""
|
||||
if cell_embeddings.numel() == 0:
|
||||
return cell_embeddings.new_empty(0, 4)
|
||||
|
||||
batch_mask = cell_batch_indices.unsqueeze(0) == cell_batch_indices.unsqueeze(1)
|
||||
encoder_for_cells = encoder_hidden[cell_batch_indices]
|
||||
|
||||
x = self.input_norm(self.input_proj(cell_embeddings))
|
||||
memory = self.kv_norm(self.kv_proj(encoder_for_cells))
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, memory, batch_mask)
|
||||
|
||||
bbox_cxcywh = torch.sigmoid(self.bbox_mlp(x))
|
||||
bbox_xyxy = _cxcywh_to_xyxy(bbox_cxcywh)
|
||||
return bbox_xyxy
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration Class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TableFormerV2Config(PretrainedConfig):
|
||||
model_type = "TableFormerV2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 512,
|
||||
num_heads: int = 8,
|
||||
ff_dim: int = 2048,
|
||||
num_decoder_layers: int = 2,
|
||||
vocab_size: int = 13,
|
||||
conv_mixer_expansion: float = 1.0,
|
||||
data_cells: Optional[List[int]] = None,
|
||||
use_fpn: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.ff_dim = ff_dim
|
||||
self.num_decoder_layers = num_decoder_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.conv_mixer_expansion = conv_mixer_expansion
|
||||
self.data_cells = data_cells or []
|
||||
self.use_fpn = use_fpn
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Model Class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TableFormerV2(PreTrainedModel):
|
||||
r"""
|
||||
TableFormerV2: CPU-optimized model for table structure recognition (inference only).
|
||||
|
||||
This model uses:
|
||||
- EfficientNetV2-S backbone for image encoding
|
||||
- Depthwise separable convolutions instead of Transformer encoder
|
||||
- Cache-aware Transformer decoder for token generation
|
||||
- Attention-based bbox head for cell localization
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : TableFormerV2Config
|
||||
Model configuration
|
||||
"""
|
||||
|
||||
config_class = TableFormerV2Config # type: ignore[assignment]
|
||||
|
||||
def __init__(self, config: TableFormerV2Config):
|
||||
super().__init__(config)
|
||||
|
||||
# Vision encoder
|
||||
self.feature_extractor = efficientnet_v2_s()
|
||||
self.se_module = SqueezeExcitation(in_channels=1280)
|
||||
self.conv_mixer = DepthwiseSeparableBlock(
|
||||
1280, expansion=config.conv_mixer_expansion
|
||||
)
|
||||
self.feature_to_embedding = nn.Linear(1280, config.embed_dim)
|
||||
|
||||
# Optional FPN for multi-scale features
|
||||
self.use_fpn = getattr(config, "use_fpn", False)
|
||||
if self.use_fpn:
|
||||
self.fpn = SimpleFPN(out_channels=1280)
|
||||
|
||||
# embeddings
|
||||
self.input_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
|
||||
self.positional_encoding = nn.Parameter(torch.randn(1, 512, config.embed_dim))
|
||||
|
||||
# decoder with caching
|
||||
self.transformer_decoder = CachedTransformerDecoder(
|
||||
d_model=config.embed_dim,
|
||||
nhead=config.num_heads,
|
||||
dim_feedforward=config.ff_dim,
|
||||
num_layers=config.num_decoder_layers,
|
||||
)
|
||||
|
||||
# Output heads
|
||||
self.output_projection = nn.Linear(config.embed_dim, config.vocab_size)
|
||||
self.bbox_head = BboxHead(config.embed_dim, config.num_heads)
|
||||
|
||||
self.data_cells = config.data_cells
|
||||
|
||||
def _is_profiling_enabled(self) -> bool:
|
||||
r"""
|
||||
Check if profiling is enabled by checking if AggProfiler has cycles.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if profiling is enabled, False otherwise
|
||||
"""
|
||||
try:
|
||||
profiler = AggProfiler()
|
||||
# Check if profiler has cycles (profiling has been started)
|
||||
return len(profiler._cycles) > 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _positional_encoding(
|
||||
self, batch_size: int, seq_len: int, offset: int = 0
|
||||
) -> torch.Tensor:
|
||||
pos_enc_size = self.positional_encoding.size(1)
|
||||
total_len = offset + seq_len
|
||||
if total_len <= pos_enc_size:
|
||||
return self.positional_encoding[:, offset : offset + seq_len, :].expand(
|
||||
batch_size, seq_len, -1
|
||||
)
|
||||
num_repeats = (total_len + pos_enc_size - 1) // pos_enc_size
|
||||
repeated = self.positional_encoding.repeat(1, num_repeats, 1)
|
||||
return repeated[:, offset : offset + seq_len, :].expand(batch_size, seq_len, -1)
|
||||
|
||||
def encode_images(self, images: torch.Tensor) -> dict:
|
||||
prof_enabled = self._is_profiling_enabled()
|
||||
if prof_enabled:
|
||||
AggProfiler().begin("model_encoder", prof_enabled)
|
||||
|
||||
if self.use_fpn:
|
||||
# Extract multi-scale features for FPN
|
||||
# EfficientNetV2-S stages: 0-1 (stem), 2, 3, 4, 5, 6, 7 (head)
|
||||
x = images
|
||||
feat_s2 = None
|
||||
feat_s4 = None
|
||||
|
||||
for idx, layer in enumerate(self.feature_extractor.features):
|
||||
x = layer(x)
|
||||
if idx == 2: # Stage 2: 48 channels
|
||||
feat_s2 = x
|
||||
elif idx == 4: # Stage 4: 160 channels
|
||||
feat_s4 = x
|
||||
|
||||
feat_final = x # Final: 1280 channels
|
||||
|
||||
# Fuse with FPN
|
||||
features = self.fpn(feat_s2, feat_s4, feat_final)
|
||||
features = self.se_module(features)
|
||||
features = self.conv_mixer(features)
|
||||
else:
|
||||
# Original single-scale path
|
||||
features = self.feature_extractor.features(images)
|
||||
features = self.se_module(features)
|
||||
features = self.conv_mixer(features)
|
||||
|
||||
B, C, H, W = features.shape
|
||||
features = features.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
|
||||
encoded = self.feature_to_embedding(features)
|
||||
|
||||
if prof_enabled:
|
||||
AggProfiler().end("model_encoder", prof_enabled)
|
||||
|
||||
return {"last_hidden_state": encoded, "spatial_size": (H, W)}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[dict] = None,
|
||||
past_key_values: Optional[Tuple] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TableFormerV2Output:
|
||||
r"""
|
||||
Forward pass for inference.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : torch.Tensor, optional
|
||||
Input images of shape (B, 3, H, W)
|
||||
input_ids : torch.Tensor
|
||||
Input token IDs of shape (B, L)
|
||||
attention_mask : torch.Tensor, optional
|
||||
Attention mask of shape (B, L)
|
||||
encoder_outputs : dict, optional
|
||||
Pre-computed encoder outputs
|
||||
past_key_values : tuple, optional
|
||||
Cached key-values for autoregressive decoding
|
||||
use_cache : bool, optional
|
||||
Whether to use KV caching (default: True)
|
||||
return_dict : bool
|
||||
Whether to return a ModelOutput (default: True)
|
||||
|
||||
Returns
|
||||
-------
|
||||
TableFormerV2Output
|
||||
Model outputs including logits and predicted bboxes
|
||||
"""
|
||||
use_cache = True if use_cache is None else use_cache
|
||||
|
||||
if encoder_outputs is None:
|
||||
if images is None:
|
||||
raise ValueError("Either images or encoder_outputs must be provided")
|
||||
encoder_outputs = self.encode_images(images)
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("input_ids must be provided")
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
past_length = 0
|
||||
if (
|
||||
past_key_values is not None
|
||||
and len(past_key_values) > 0
|
||||
and past_key_values[0] is not None
|
||||
):
|
||||
past_length = past_key_values[0][0].shape[1]
|
||||
|
||||
tgt = self.input_embedding(input_ids) + self._positional_encoding(
|
||||
batch_size, seq_len, offset=past_length
|
||||
)
|
||||
|
||||
if past_length > 0:
|
||||
causal_mask = None
|
||||
else:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, device=input_ids.device)
|
||||
).T
|
||||
causal_mask = causal_mask.masked_fill(
|
||||
causal_mask == 0, float("-inf")
|
||||
).masked_fill(causal_mask == 1, 0.0)
|
||||
|
||||
prof_enabled = self._is_profiling_enabled()
|
||||
|
||||
if prof_enabled:
|
||||
AggProfiler().begin("model_tag_transformer_decoder", prof_enabled)
|
||||
decoded, present_kv = self.transformer_decoder(
|
||||
tgt=tgt,
|
||||
memory=encoder_outputs["last_hidden_state"],
|
||||
tgt_mask=causal_mask,
|
||||
tgt_key_padding_mask=None,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
if prof_enabled:
|
||||
AggProfiler().end("model_tag_transformer_decoder", prof_enabled)
|
||||
|
||||
if prof_enabled:
|
||||
AggProfiler().begin("model_tag_transformer_fc", prof_enabled)
|
||||
logits = self.output_projection(decoded)
|
||||
if prof_enabled:
|
||||
AggProfiler().end("model_tag_transformer_fc", prof_enabled)
|
||||
|
||||
# Identify cell positions and predict bboxes
|
||||
cell_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
||||
for cell_id in self.data_cells:
|
||||
cell_mask |= input_ids == cell_id
|
||||
|
||||
cell_positions = torch.nonzero(cell_mask, as_tuple=False)
|
||||
cell_embeddings = decoded[cell_mask]
|
||||
cell_batch_indices = (
|
||||
cell_positions[:, 0]
|
||||
if cell_positions.numel() > 0
|
||||
else cell_positions.new_empty(0)
|
||||
)
|
||||
|
||||
pred_bboxes = self.bbox_head(
|
||||
cell_embeddings,
|
||||
encoder_outputs["last_hidden_state"],
|
||||
cell_batch_indices,
|
||||
spatial_size=encoder_outputs.get("spatial_size", None),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (logits, decoded, pred_bboxes, present_kv) # type: ignore[return-value]
|
||||
|
||||
return TableFormerV2Output(
|
||||
logits=logits,
|
||||
hidden_states=decoded,
|
||||
predicted_bboxes=pred_bboxes,
|
||||
past_key_values=present_kv,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
past: Optional[Tuple] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
tokenizer,
|
||||
max_length: int = 512,
|
||||
generation_config=None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
r"""
|
||||
Autoregressive generation with bounding box prediction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : torch.Tensor
|
||||
Input images of shape (B, 3, H, W)
|
||||
tokenizer : PreTrainedTokenizer
|
||||
Tokenizer with bos_token_id and eos_token_id
|
||||
max_length : int
|
||||
Maximum sequence length (default: 512)
|
||||
generation_config : GenerationConfig, optional
|
||||
HuggingFace generation configuration
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Dictionary containing:
|
||||
- generated_ids: (B, L) token IDs
|
||||
- predicted_bboxes: (B, num_cells, 4) bboxes in xyxy [0, 1]
|
||||
"""
|
||||
if generation_config is not None and hasattr(generation_config, "max_length"):
|
||||
max_length = generation_config.max_length
|
||||
|
||||
prof_enabled = self._is_profiling_enabled()
|
||||
if prof_enabled:
|
||||
AggProfiler().begin("predict_total", prof_enabled)
|
||||
|
||||
self.eval() # type: ignore[attr-defined]
|
||||
with torch.no_grad():
|
||||
# Image encoding (profiling handled inside encode_images)
|
||||
encoder_outputs = self.encode_images(images)
|
||||
batch_size = images.size(0)
|
||||
device = images.device
|
||||
|
||||
generated_ids = torch.full(
|
||||
(batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=device
|
||||
)
|
||||
current_input = generated_ids
|
||||
past_key_values = None
|
||||
|
||||
# Autoregressive generation loop
|
||||
for step in range(max_length):
|
||||
outputs = self.forward(
|
||||
input_ids=current_input,
|
||||
attention_mask=None,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
if outputs.logits is None:
|
||||
raise ValueError("Model forward pass returned None logits")
|
||||
next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
|
||||
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
||||
past_key_values = outputs.past_key_values
|
||||
current_input = next_token
|
||||
|
||||
if torch.all(next_token == tokenizer.eos_token_id):
|
||||
break
|
||||
|
||||
# Final forward pass to get hidden states for bbox prediction
|
||||
final_outputs = self.forward(
|
||||
input_ids=generated_ids,
|
||||
attention_mask=torch.ones_like(generated_ids),
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
hidden_states = final_outputs.hidden_states
|
||||
|
||||
# Find cell positions and predict bboxes
|
||||
pred_bboxes = None
|
||||
max_cells = 0
|
||||
cell_positions_per_batch = []
|
||||
|
||||
for b in range(batch_size):
|
||||
seq = generated_ids[b]
|
||||
positions = []
|
||||
for pos, tok in enumerate(seq.tolist()):
|
||||
if tok in self.data_cells:
|
||||
positions.append(pos)
|
||||
cell_positions_per_batch.append(positions)
|
||||
max_cells = max(max_cells, len(positions))
|
||||
|
||||
if max_cells > 0:
|
||||
pred_bboxes = torch.zeros(batch_size, max_cells, 4, device=device)
|
||||
spatial_size = encoder_outputs.get("spatial_size", None)
|
||||
|
||||
for b in range(batch_size):
|
||||
positions = cell_positions_per_batch[b]
|
||||
if positions:
|
||||
if hidden_states is None:
|
||||
raise ValueError(
|
||||
"Model forward pass returned None hidden_states"
|
||||
)
|
||||
cell_embs = hidden_states[b, positions, :]
|
||||
batch_indices = torch.zeros(
|
||||
len(positions), dtype=torch.long, device=device
|
||||
)
|
||||
enc_out = encoder_outputs["last_hidden_state"][b : b + 1]
|
||||
|
||||
if prof_enabled:
|
||||
AggProfiler().begin("model_bbox_decoder", prof_enabled)
|
||||
bboxes = self.bbox_head(
|
||||
cell_embs, enc_out, batch_indices, spatial_size=spatial_size
|
||||
)
|
||||
if prof_enabled:
|
||||
AggProfiler().end("model_bbox_decoder", prof_enabled)
|
||||
pred_bboxes[b, : len(bboxes)] = bboxes
|
||||
|
||||
if prof_enabled:
|
||||
AggProfiler().end("predict_total", prof_enabled)
|
||||
|
||||
return {
|
||||
"generated_ids": generated_ids,
|
||||
"predicted_bboxes": pred_bboxes,
|
||||
}
|
||||
|
||||
|
||||
AutoConfig.register("TableFormerV2", TableFormerV2Config)
|
||||
AutoModel.register(TableFormerV2Config, TableFormerV2)
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -44,6 +40,7 @@ def init() -> dict:
|
||||
return init
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Legacy code-formula predictor test is disabled (unused path)")
|
||||
def test_code_formula_predictor(init: dict):
|
||||
r"""
|
||||
Unit test for the CodeFormulaPredictor
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
@@ -0,0 +1,529 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image, ImageDraw
|
||||
from transformers import AutoTokenizer
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from docling_ibm_models.tableformer_v2 import TableFormerV2
|
||||
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
|
||||
|
||||
TABLEFORMER_V2_REPO_ID = "docling-project/TableFormerV2"
|
||||
TABLEFORMER_V2_REVISION = "v0.1.0"
|
||||
|
||||
def load_tokenizer(path: str, revision: str | None = None):
|
||||
"""Load tokenizer from local path or HuggingFace repo."""
|
||||
import os
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Check if it's a local path with tokenizer.json
|
||||
tokenizer_file = os.path.join(path, "tokenizer.json")
|
||||
if os.path.exists(tokenizer_file):
|
||||
# Load directly from tokenizer.json
|
||||
backend_tokenizer = Tokenizer.from_file(tokenizer_file)
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=backend_tokenizer,
|
||||
bos_token="<start>",
|
||||
eos_token="<end>",
|
||||
pad_token="<pad>",
|
||||
unk_token="[UNK]",
|
||||
)
|
||||
return tokenizer
|
||||
else:
|
||||
# Try loading from HuggingFace repo using AutoTokenizer
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, revision=revision)
|
||||
# Set special tokens if not already set
|
||||
if tokenizer.bos_token is None:
|
||||
tokenizer.bos_token = "<start>"
|
||||
if tokenizer.eos_token is None:
|
||||
tokenizer.eos_token = "<end>"
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = "<pad>"
|
||||
return tokenizer
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
f"Could not load tokenizer from {path}. "
|
||||
f"Tried local tokenizer.json and HuggingFace AutoTokenizer. Error: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Test data matching test_tf_predictor.py structure exactly
|
||||
test_data = {
|
||||
"png_images": [
|
||||
"tests/test_data/samples/ADS.2007.page_123.png",
|
||||
"tests/test_data/samples/PHM.2013.page_30.png",
|
||||
"tests/test_data/samples/empty_iocr.png",
|
||||
],
|
||||
"table_bboxes": [
|
||||
[[178, 748, 1061, 976], [177, 1163, 1062, 1329]], # ADS.2007 has 2 tables
|
||||
[[100, 186, 1135, 525]], # PHM.2013 has 1 table
|
||||
[[178, 748, 1061, 976], [177, 1163, 1062, 1329]], # empty_iocr has 2 tables
|
||||
],
|
||||
}
|
||||
|
||||
# Test configuration
|
||||
test_config = {
|
||||
"num_threads": 1,
|
||||
"image_size": 448,
|
||||
"max_length": 512,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def init() -> dict:
|
||||
r"""
|
||||
Initialize the testing environment
|
||||
"""
|
||||
init = test_config.copy()
|
||||
init["test_data"] = test_data
|
||||
|
||||
# Download model and tokenizer from HuggingFace Hub
|
||||
artifact_path = snapshot_download(
|
||||
repo_id=TABLEFORMER_V2_REPO_ID,
|
||||
revision=TABLEFORMER_V2_REVISION,
|
||||
)
|
||||
|
||||
# Use local checkpoint with tokenizer files
|
||||
init["artifact_path"] = artifact_path
|
||||
init["artifact_revision"] = TABLEFORMER_V2_REVISION
|
||||
|
||||
return init
|
||||
|
||||
|
||||
def test_tableformer_v2_model_loading(init: dict):
|
||||
r"""
|
||||
Test that the TableFormerV2 model loads correctly
|
||||
"""
|
||||
device = "cpu"
|
||||
|
||||
# Load the model
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Check model attributes
|
||||
assert hasattr(model, "config"), "Model missing config attribute"
|
||||
assert hasattr(model, "generate"), "Model missing generate method"
|
||||
assert hasattr(model, "forward"), "Model missing forward method"
|
||||
assert hasattr(model, "encode_images"), "Model missing encode_images method"
|
||||
assert hasattr(model, "bbox_head"), "Model missing bbox_head"
|
||||
|
||||
# Check config values
|
||||
config = model.config
|
||||
assert config.model_type == "TableFormerV2", "Wrong model type"
|
||||
assert config.vocab_size > 0, "Invalid vocab size"
|
||||
assert config.embed_dim > 0, "Invalid embed_dim"
|
||||
assert len(config.data_cells) > 0, "data_cells should not be empty"
|
||||
|
||||
|
||||
def test_tableformer_v2_tokenizer_loading(init: dict):
|
||||
r"""
|
||||
Test that the tokenizer loads correctly
|
||||
"""
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
|
||||
# Check tokenizer attributes
|
||||
assert tokenizer.bos_token_id is not None, "Missing bos_token_id"
|
||||
assert tokenizer.eos_token_id is not None, "Missing eos_token_id"
|
||||
assert tokenizer.pad_token_id is not None, "Missing pad_token_id"
|
||||
|
||||
# Check OTSL tokens exist
|
||||
vocab = tokenizer.get_vocab()
|
||||
expected_tokens = ["<fcel>", "<ecel>", "<nl>", "<start>", "<end>", "<pad>"]
|
||||
for token in expected_tokens:
|
||||
assert token in vocab, f"Missing expected token: {token}"
|
||||
|
||||
|
||||
def test_tableformer_v2_image_encoding(init: dict):
|
||||
r"""
|
||||
Test image encoding functionality on cropped table
|
||||
"""
|
||||
device = "cpu"
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Prepare transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((init["image_size"], init["image_size"])),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Test with first table crop from first image
|
||||
img_fn = init["test_data"]["png_images"][0]
|
||||
table_bbox = init["test_data"]["table_bboxes"][0][0] # First table
|
||||
|
||||
with Image.open(img_fn) as img:
|
||||
img_rgb = img.convert("RGB")
|
||||
# Crop to table region
|
||||
x1, y1, x2, y2 = table_bbox
|
||||
table_crop = img_rgb.crop((x1, y1, x2, y2))
|
||||
image_tensor = transform(table_crop).unsqueeze(0).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode_images(image_tensor)
|
||||
|
||||
assert "last_hidden_state" in encoder_outputs, "Missing last_hidden_state"
|
||||
assert "spatial_size" in encoder_outputs, "Missing spatial_size"
|
||||
|
||||
hidden_state = encoder_outputs["last_hidden_state"]
|
||||
assert hidden_state.dim() == 3, "Hidden state should be 3D (B, S, D)"
|
||||
assert hidden_state.size(0) == 1, "Batch size should be 1"
|
||||
assert hidden_state.size(2) == model.config.embed_dim, "Wrong embed dim"
|
||||
|
||||
|
||||
def test_tableformer_v2_forward_pass(init: dict):
|
||||
r"""
|
||||
Test forward pass functionality on cropped table
|
||||
"""
|
||||
device = "cpu"
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Prepare transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((init["image_size"], init["image_size"])),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Test with first table crop
|
||||
img_fn = init["test_data"]["png_images"][0]
|
||||
table_bbox = init["test_data"]["table_bboxes"][0][0]
|
||||
|
||||
with Image.open(img_fn) as img:
|
||||
img_rgb = img.convert("RGB")
|
||||
x1, y1, x2, y2 = table_bbox
|
||||
table_crop = img_rgb.crop((x1, y1, x2, y2))
|
||||
image_tensor = transform(table_crop).unsqueeze(0).to(device)
|
||||
|
||||
# Create dummy input_ids (start token)
|
||||
input_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.forward(
|
||||
images=image_tensor,
|
||||
input_ids=input_ids,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
assert outputs.logits is not None, "Missing logits"
|
||||
assert outputs.hidden_states is not None, "Missing hidden_states"
|
||||
assert outputs.logits.size(-1) == model.config.vocab_size, "Wrong vocab size in logits"
|
||||
|
||||
|
||||
def test_tableformer_v2_predict(init: dict):
|
||||
r"""
|
||||
Test TableFormerV2 prediction on cropped table images.
|
||||
Matches the pattern from test_tf_predictor.py.
|
||||
Includes profiling similar to test_tf_predictor.py.
|
||||
"""
|
||||
device = "cpu"
|
||||
viz = True # Save visualizations
|
||||
enable_profiling = True
|
||||
|
||||
# Initialize profiler (cycles started per-table for fair comparison with V1)
|
||||
profiler = AggProfiler()
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Prepare transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((init["image_size"], init["image_size"])),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Process each page and its tables
|
||||
for img_idx, (img_fn, table_bboxes) in enumerate(zip(
|
||||
init["test_data"]["png_images"],
|
||||
init["test_data"]["table_bboxes"]
|
||||
)):
|
||||
print(f"\n{'>'*40}")
|
||||
img_basename = os.path.basename(img_fn)
|
||||
print(f"Processing image: {img_basename}")
|
||||
print(f"Number of tables: {len(table_bboxes)}")
|
||||
|
||||
with Image.open(img_fn) as page_img:
|
||||
page_rgb = page_img.convert("RGB")
|
||||
page_w, page_h = page_rgb.size
|
||||
|
||||
# Process each table on the page
|
||||
for t_idx, table_bbox in enumerate(table_bboxes):
|
||||
x1, y1, x2, y2 = table_bbox
|
||||
print(f"\n Table {t_idx}: bbox=[{x1}, {y1}, {x2}, {y2}]")
|
||||
|
||||
# Crop table from page
|
||||
table_crop = page_rgb.crop((x1, y1, x2, y2))
|
||||
crop_w, crop_h = table_crop.size
|
||||
print(f" Crop size: {crop_w}x{crop_h}")
|
||||
|
||||
# Transform for model
|
||||
image_tensor = transform(table_crop).unsqueeze(0).to(device)
|
||||
|
||||
# Start new profiling cycle for this table (matches V1 behavior)
|
||||
profiler.start_agg(enable=enable_profiling)
|
||||
|
||||
# Run prediction (profiling is handled inside model.generate)
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
images=image_tensor,
|
||||
tokenizer=tokenizer,
|
||||
max_length=init["max_length"]
|
||||
)
|
||||
|
||||
# Check output structure
|
||||
assert "generated_ids" in output, "Missing generated_ids"
|
||||
assert output["generated_ids"] is not None, "generated_ids is None"
|
||||
|
||||
generated_ids = output["generated_ids"]
|
||||
assert generated_ids.dim() == 2, "generated_ids should be 2D"
|
||||
|
||||
# Decode OTSL output
|
||||
decoded_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
print(f" Generated OTSL: {decoded_text[:80]}...")
|
||||
|
||||
# Check bounding boxes
|
||||
pred_bboxes = output.get("predicted_bboxes")
|
||||
num_cells = 0
|
||||
if pred_bboxes is not None and pred_bboxes.numel() > 0:
|
||||
# Handle (B, num_cells, 4) format
|
||||
if pred_bboxes.dim() == 3:
|
||||
pred_bboxes = pred_bboxes[0]
|
||||
# Filter zero-padded boxes
|
||||
valid_mask = pred_bboxes.sum(dim=-1) > 0
|
||||
pred_bboxes = pred_bboxes[valid_mask]
|
||||
num_cells = pred_bboxes.size(0)
|
||||
print(f" Predicted cells: {num_cells}")
|
||||
|
||||
# Validate bbox ranges
|
||||
if num_cells > 0:
|
||||
assert pred_bboxes.min() >= 0.0, "Bboxes should be >= 0"
|
||||
assert pred_bboxes.max() <= 1.0, "Bboxes should be <= 1"
|
||||
|
||||
# Visualization
|
||||
if viz and pred_bboxes is not None and num_cells > 0:
|
||||
viz_root = "./tests/test_data/viz/"
|
||||
Path(viz_root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Draw on the table crop
|
||||
draw_img = table_crop.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
|
||||
for i, bbox in enumerate(pred_bboxes.cpu().tolist()):
|
||||
# Unnormalize to crop coordinates
|
||||
bx1 = bbox[0] * crop_w
|
||||
by1 = bbox[1] * crop_h
|
||||
bx2 = bbox[2] * crop_w
|
||||
by2 = bbox[3] * crop_h
|
||||
draw.rectangle([bx1, by1, bx2, by2], outline="red", width=2)
|
||||
|
||||
# Also draw table bbox on full page for context
|
||||
draw.rectangle([0, 0, crop_w-1, crop_h-1], outline="blue", width=3)
|
||||
|
||||
viz_fn = os.path.join(
|
||||
viz_root,
|
||||
f"tableformer_v2_{img_basename.replace('.png', '')}_table{t_idx}.png"
|
||||
)
|
||||
draw_img.save(viz_fn)
|
||||
print(f" Saved visualization: {viz_fn}")
|
||||
|
||||
# Get and print profiling data
|
||||
if enable_profiling:
|
||||
profiling_data = profiler.get_data()
|
||||
print("\n" + "="*60)
|
||||
print("PROFILING DATA")
|
||||
print("="*60)
|
||||
print(json.dumps(profiling_data, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
def test_tableformer_v2_numpy_input(init: dict):
|
||||
r"""
|
||||
Test that model works with numpy array input (converted to tensor)
|
||||
"""
|
||||
device = "cpu"
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Prepare transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((init["image_size"], init["image_size"])),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
img_fn = init["test_data"]["png_images"][0]
|
||||
table_bbox = init["test_data"]["table_bboxes"][0][0]
|
||||
|
||||
with Image.open(img_fn) as img:
|
||||
img_rgb = img.convert("RGB")
|
||||
x1, y1, x2, y2 = table_bbox
|
||||
table_crop = img_rgb.crop((x1, y1, x2, y2))
|
||||
|
||||
# Convert to numpy array first, then back to PIL for transform
|
||||
np_arr = np.asarray(table_crop)
|
||||
img_from_np = Image.fromarray(np_arr)
|
||||
image_tensor = transform(img_from_np).unsqueeze(0).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
images=image_tensor,
|
||||
tokenizer=tokenizer,
|
||||
max_length=init["max_length"]
|
||||
)
|
||||
|
||||
assert "generated_ids" in output, "Missing generated_ids"
|
||||
assert output["generated_ids"] is not None, "generated_ids is None"
|
||||
|
||||
|
||||
def test_tableformer_v2_batch_inference(init: dict):
|
||||
r"""
|
||||
Test batch inference with multiple table crops.
|
||||
Runs on GPU if available, otherwise CPU.
|
||||
"""
|
||||
# Detect device: use GPU if available, otherwise CPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"\n{'='*60}")
|
||||
print(f"BATCH INFERENCE TEST - Device: {device}")
|
||||
print(f"{'='*60}")
|
||||
if device == "cuda":
|
||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(f"CUDA Version: {torch.version.cuda}")
|
||||
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Prepare transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((init["image_size"], init["image_size"])),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Collect table crops from all pages
|
||||
image_tensors = []
|
||||
for img_fn, table_bboxes in zip(
|
||||
init["test_data"]["png_images"],
|
||||
init["test_data"]["table_bboxes"]
|
||||
):
|
||||
with Image.open(img_fn) as img:
|
||||
img_rgb = img.convert("RGB")
|
||||
# Take first table from each page
|
||||
x1, y1, x2, y2 = table_bboxes[0]
|
||||
table_crop = img_rgb.crop((x1, y1, x2, y2))
|
||||
image_tensors.append(transform(table_crop))
|
||||
|
||||
# Stack into batch
|
||||
batch_tensor = torch.stack(image_tensors).to(device)
|
||||
batch_size = batch_tensor.size(0)
|
||||
print(f"\nBatch size: {batch_size}")
|
||||
print(f"Input tensor shape: {batch_tensor.shape}")
|
||||
|
||||
# Warmup run (especially important for GPU)
|
||||
if device == "cuda":
|
||||
print("\nRunning warmup...")
|
||||
with torch.no_grad():
|
||||
_ = model.generate(
|
||||
images=batch_tensor,
|
||||
tokenizer=tokenizer,
|
||||
max_length=min(init["max_length"], 100) # Shorter for warmup
|
||||
)
|
||||
torch.cuda.synchronize() if device == "cuda" else None
|
||||
|
||||
# Actual inference with timing
|
||||
import time
|
||||
if device == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
images=batch_tensor,
|
||||
tokenizer=tokenizer,
|
||||
max_length=init["max_length"]
|
||||
)
|
||||
|
||||
if device == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
assert "generated_ids" in output, "Missing generated_ids"
|
||||
generated_ids = output["generated_ids"]
|
||||
assert generated_ids.size(0) == batch_size, f"Expected batch size {batch_size}"
|
||||
|
||||
print(f"\nBatch inference results ({batch_size} tables):")
|
||||
for i in range(batch_size):
|
||||
decoded = tokenizer.decode(generated_ids[i], skip_special_tokens=True)
|
||||
seq_len = generated_ids[i].size(0)
|
||||
print(f" Table {i}: seq_len={seq_len}, {decoded[:50]}...")
|
||||
|
||||
print(f"\nInference time: {inference_time:.3f} seconds")
|
||||
print(f"Time per table: {inference_time/batch_size:.3f} seconds")
|
||||
print(f"Throughput: {batch_size/inference_time:.2f} tables/second")
|
||||
|
||||
# Check bounding boxes
|
||||
pred_bboxes = output.get("predicted_bboxes")
|
||||
if pred_bboxes is not None and pred_bboxes.numel() > 0:
|
||||
print(f"\nBounding boxes shape: {pred_bboxes.shape}")
|
||||
if pred_bboxes.dim() == 3:
|
||||
for b in range(batch_size):
|
||||
valid_mask = pred_bboxes[b].sum(dim=-1) > 0
|
||||
num_cells = valid_mask.sum().item()
|
||||
print(f" Batch {b}: {num_cells} cells")
|
||||
|
||||
|
||||
def test_tableformer_v2_unsupported_input(init: dict):
|
||||
r"""
|
||||
Test that model raises appropriate errors for unsupported inputs
|
||||
"""
|
||||
device = "cpu"
|
||||
|
||||
model = TableFormerV2.from_pretrained(
|
||||
init["artifact_path"], revision=init["artifact_revision"]
|
||||
)
|
||||
tokenizer = load_tokenizer(init["artifact_path"], revision=init["artifact_revision"])
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Test with wrong tensor shape
|
||||
is_exception = False
|
||||
try:
|
||||
wrong_shape = torch.randn(1, 1, 224, 224).to(device) # Wrong channels
|
||||
with torch.no_grad():
|
||||
model.generate(images=wrong_shape, tokenizer=tokenizer)
|
||||
except Exception:
|
||||
is_exception = True
|
||||
assert is_exception, "Should raise exception for wrong input shape"
|
||||
@@ -1,7 +1,3 @@
|
||||
#
|
||||
# Copyright IBM Corp. 2024 - 2024
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
Reference in New Issue
Block a user