feat: Improved region extraction and filtering for doclingsdg_builder.py (#217)

* Improved region extraction and filtering for doclingsdg_builder.py

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Added table region selector by label in doclingsdg_builder

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Added skipping switch for 90 degree tables

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

---------

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
Maxim Lysak
2026-05-15 17:40:59 +02:00
committed by GitHub
parent 102bb119ea
commit 6efbb3f214
2 changed files with 667 additions and 84 deletions
@@ -21,7 +21,7 @@ from docling_core.types.doc import (
TableItem,
)
from docling_core.types.io import DocumentStream
from PIL import Image
from PIL import Image, ImageDraw
from pydantic import ValidationError
from tqdm import tqdm
@@ -45,6 +45,14 @@ _BBOX_OVERLAP_EPSILON = 1e-3
_ROTATION_ASPECT_THRESHOLD = 1.15
_COORD_TOLERANCE = 1.0
_ROW_LEFT_DUPLICATE_TOLERANCE = 0.5
_MIN_GRID_ROWS = 1
_MAX_GRID_ROWS = 40
_MIN_GRID_COLS = 1
_MAX_GRID_COLS = 20
_MIN_PAGE_IMAGE_DIM = 32
_MAX_PAGE_IMAGE_DIM = 4096
_REGION_OVERLAP_IOU_THRESHOLD = 0.99
_SKIP_ROTATED_90_TABLES = True
_TABLE_REGION_CATEGORY_IDS: Dict[str, int] = {
"table": 1,
@@ -57,6 +65,21 @@ _TABLE_REGION_CATEGORY_IDS: Dict[str, int] = {
"cell_merged": 8,
}
_TABLE_REGION_EXPORT_LABELS: Tuple[str, ...] = (
"table",
"row",
"column",
"cell_merged",
)
_TABLE_REGION_EXPORT_LABELS_SET = set(_TABLE_REGION_EXPORT_LABELS)
_TABLE_REGIONS_VIZ_SOURCE_DOCLING = "docling"
_TABLE_REGIONS_VIZ_SOURCE_REGIONS = "regions"
_TABLE_REGIONS_VIZ_SOURCES = {
_TABLE_REGIONS_VIZ_SOURCE_DOCLING,
_TABLE_REGIONS_VIZ_SOURCE_REGIONS,
}
class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
"""
@@ -76,6 +99,7 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
split: str = "test",
begin_index: int = 0,
end_index: int = -1,
table_regions_visualization_source: str = _TABLE_REGIONS_VIZ_SOURCE_REGIONS,
):
try:
parsed_modality = EvaluationModality(modality)
@@ -109,6 +133,31 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
self.modality = parsed_modality
self.must_retrieve = False
if len(_TABLE_REGION_EXPORT_LABELS_SET) != len(_TABLE_REGION_EXPORT_LABELS):
raise ValueError(
"Duplicate labels detected in _TABLE_REGION_EXPORT_LABELS."
)
unknown_export_labels = sorted(
label
for label in _TABLE_REGION_EXPORT_LABELS
if label not in _TABLE_REGION_CATEGORY_IDS
)
if unknown_export_labels:
raise ValueError(
"Unknown labels in _TABLE_REGION_EXPORT_LABELS: "
f"{unknown_export_labels}. Expected subset of "
f"{sorted(_TABLE_REGION_CATEGORY_IDS.keys())}"
)
source = str(table_regions_visualization_source).strip().lower()
if source not in _TABLE_REGIONS_VIZ_SOURCES:
raise ValueError(
"Unsupported table regions visualization source "
f"'{table_regions_visualization_source}'. Expected one of: "
f"{sorted(_TABLE_REGIONS_VIZ_SOURCES)}"
)
self.table_regions_visualization_source = source
@staticmethod
def _sort_by_page_suffix(path: Path) -> tuple[int, str]:
match = _PAGE_SUFFIX_PATTERN.search(path.stem)
@@ -321,6 +370,10 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
def _table_region_category_id(label: str) -> int:
return _TABLE_REGION_CATEGORY_IDS[label]
@staticmethod
def _should_export_table_region_label(label: str) -> bool:
    """Tell whether boxes carrying ``label`` belong in the exported output.

    A label qualifies only when it is a member of the module-level
    ``_TABLE_REGION_EXPORT_LABELS_SET`` allow-list.
    """
    is_exported = label in _TABLE_REGION_EXPORT_LABELS_SET
    return is_exported
@staticmethod
def _safe_page_height(page: Optional[PageItem], page_image: Image.Image) -> float:
if (
@@ -422,6 +475,55 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
inter_h = min(rect_a[3], rect_b[3]) - max(rect_a[1], rect_b[1])
return inter_w > _BBOX_OVERLAP_EPSILON and inter_h > _BBOX_OVERLAP_EPSILON
@staticmethod
def _rect_iou(
    rect_a: Tuple[float, float, float, float],
    rect_b: Tuple[float, float, float, float],
) -> float:
    """Return the intersection-over-union of two corner-form rectangles.

    Each rectangle is read as ``(x0, y0, x1, y1)``. The result is ``0.0``
    when the rectangles do not share a region of positive width and height,
    or when the computed union area is not positive.
    """
    a_x0, a_y0, a_x1, a_y1 = rect_a
    b_x0, b_y0, b_x1, b_y1 = rect_b

    overlap_w = min(a_x1, b_x1) - max(a_x0, b_x0)
    overlap_h = min(a_y1, b_y1) - max(a_y0, b_y0)
    # Touching edges (zero overlap) count as no intersection.
    if overlap_w <= 0.0 or overlap_h <= 0.0:
        return 0.0

    shared_area = overlap_w * overlap_h
    # Clamp each side so an inverted rectangle contributes zero area.
    area_a = max(0.0, a_x1 - a_x0) * max(0.0, a_y1 - a_y0)
    area_b = max(0.0, b_x1 - b_x0) * max(0.0, b_y1 - b_y0)
    combined_area = area_a + area_b - shared_area
    return 0.0 if combined_area <= 0.0 else shared_area / combined_area
@classmethod
def _find_region_overlap_iou_issue(
    cls,
    *,
    region_name: str,
    regions: List[Tuple[int, Tuple[float, float, float, float]]],
    threshold: float = _REGION_OVERLAP_IOU_THRESHOLD,
) -> Optional[str]:
    """Report the first pair of regions whose IoU exceeds ``threshold``.

    Args:
        region_name: Word inserted into the message (callers pass "row" or
            "column").
        regions: ``(region_id, rect)`` pairs; every unordered pair is
            compared once, in list order.
        threshold: IoU value that a pair must strictly exceed to be reported.

    Returns:
        A message naming the two region ids and their IoU, or ``None`` when
        no pair exceeds the threshold.
    """
    for position, (first_id, first_rect) in enumerate(regions):
        # Only look at later entries so each pair is tested a single time.
        for second_id, second_rect in regions[position + 1 :]:
            iou = cls._rect_iou(first_rect, second_rect)
            if iou > threshold:
                return (
                    f"Overlapping {region_name} regions with IoU>{threshold:.2f} "
                    f"(ids={first_id},{second_id}, iou={iou:.4f})."
                )
    return None
@staticmethod
def _overlap_region_type_from_reason(reason: str) -> Optional[str]:
    """Map an overlap skip reason back to the region type it mentions.

    A single leading ``"fallback::"`` marker is ignored. Returns ``"row"``
    or ``"column"`` when the remaining text starts with the matching
    overlap message prefix, otherwise ``None``.
    """
    message = reason.removeprefix("fallback::")
    for region_type in ("row", "column"):
        if message.startswith(f"Overlapping {region_type} regions with IoU>"):
            return region_type
    return None
@staticmethod
def _is_table_rotated_90(
row_rects: List[Tuple[float, float, float, float]],
@@ -453,12 +555,12 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
row_span: int,
col_span: int,
) -> str:
if is_column_header:
return "cell_column_header"
if is_row_header:
return "cell_row_header"
if is_row_section:
return "cell_section_header"
# if is_column_header:
# return "cell_column_header"
# if is_row_header:
# return "cell_row_header"
# if is_row_section:
# return "cell_section_header"
if row_span > 1 or col_span > 1:
return "cell_merged"
return "cell_single"
@@ -921,12 +1023,6 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
if not table_cells:
return {}, has_rotated_90, "Table has no cells."
row_rects_map: Dict[int, List[Tuple[float, float, float, float]]] = (
defaultdict(list)
)
col_rects_map: Dict[int, List[Tuple[float, float, float, float]]] = (
defaultdict(list)
)
occupancy: Dict[Tuple[int, int], int] = {}
cell_entries: List[Dict[str, Any]] = []
@@ -967,11 +1063,6 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
"Invalid cell bbox (out of image bounds).",
)
for row_id in range(row_start, row_end):
row_rects_map[row_id].append(cell_rect)
for col_id in range(col_start, col_end):
col_rects_map[col_id].append(cell_rect)
for row_id in range(row_start, row_end):
for col_id in range(col_start, col_end):
slot = (row_id, col_id)
@@ -1007,9 +1098,61 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
}
)
if not row_rects_map or not col_rects_map:
try:
row_bboxes = table_data.get_row_bounding_boxes(minimal=False)
col_bboxes = table_data.get_column_bounding_boxes(minimal=False)
except Exception as exc: # noqa: BLE001
return (
{},
has_rotated_90,
f"Failed to compute row/column regions from table data: {exc}",
)
row_regions: List[Tuple[int, Tuple[float, float, float, float]]] = []
for row_id, row_bbox in sorted(row_bboxes.items()):
row_rect = self._normalize_bbox_to_page_image(
row_bbox,
page_height=page_height,
page_image=page_image,
reject_out_of_bounds=True,
)
if row_rect is None:
return {}, has_rotated_90, "Invalid row bbox (out of image bounds)."
row_regions.append((int(row_id), row_rect))
col_regions: List[Tuple[int, Tuple[float, float, float, float]]] = []
for col_id, col_bbox in sorted(col_bboxes.items()):
col_rect = self._normalize_bbox_to_page_image(
col_bbox,
page_height=page_height,
page_image=page_image,
reject_out_of_bounds=True,
)
if col_rect is None:
return (
{},
has_rotated_90,
"Invalid column bbox (out of image bounds).",
)
col_regions.append((int(col_id), col_rect))
if not row_regions or not col_regions:
return {}, has_rotated_90, "Missing row/column regions."
row_overlap_issue = self._find_region_overlap_iou_issue(
region_name="row",
regions=row_regions,
)
if row_overlap_issue is not None:
return {}, has_rotated_90, row_overlap_issue
col_overlap_issue = self._find_region_overlap_iou_issue(
region_name="column",
regions=col_regions,
)
if col_overlap_issue is not None:
return {}, has_rotated_90, col_overlap_issue
active_cells: List[Dict[str, Any]] = []
for entry in sorted(cell_entries, key=lambda e: e["rect"][0]):
left = float(entry["rect"][0])
@@ -1049,39 +1192,40 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
active_cells.append(entry)
page_boxes = bboxes_by_page.setdefault(page_index, [])
page_boxes.append(
self._bbox_payload(
label="table",
category_id=self._table_region_category_id("table"),
rect=table_rect,
if self._should_export_table_region_label("table"):
page_boxes.append(
self._bbox_payload(
label="table",
category_id=self._table_region_category_id("table"),
rect=table_rect,
)
)
)
row_rects: List[Tuple[float, float, float, float]] = []
for row_id in sorted(row_rects_map):
row_rect = self._union_rectangles(row_rects_map[row_id])
for row_id, row_rect in row_regions:
row_rects.append(row_rect)
page_boxes.append(
self._bbox_payload(
label="row",
category_id=self._table_region_category_id("row"),
rect=row_rect,
extras={"row_id": row_id},
if self._should_export_table_region_label("row"):
page_boxes.append(
self._bbox_payload(
label="row",
category_id=self._table_region_category_id("row"),
rect=row_rect,
extras={"row_id": row_id},
)
)
)
col_rects: List[Tuple[float, float, float, float]] = []
for col_id in sorted(col_rects_map):
col_rect = self._union_rectangles(col_rects_map[col_id])
for col_id, col_rect in col_regions:
col_rects.append(col_rect)
page_boxes.append(
self._bbox_payload(
label="column",
category_id=self._table_region_category_id("column"),
rect=col_rect,
extras={"col_id": col_id},
if self._should_export_table_region_label("column"):
page_boxes.append(
self._bbox_payload(
label="column",
category_id=self._table_region_category_id("column"),
rect=col_rect,
extras={"col_id": col_id},
)
)
)
if self._is_table_rotated_90(row_rects=row_rects, col_rects=col_rects):
has_rotated_90 = True
@@ -1091,10 +1235,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
return {}, has_rotated_90, row_left_issue
for entry in cell_entries:
cell_label = str(entry["label"])
if not self._should_export_table_region_label(cell_label):
continue
page_boxes.append(
self._bbox_payload(
label=str(entry["label"]),
category_id=self._table_region_category_id(str(entry["label"])),
label=cell_label,
category_id=self._table_region_category_id(cell_label),
rect=entry["rect"],
extras={
"text": entry["text"] if entry["text"] is not None else "",
@@ -1165,6 +1312,146 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
with open(filename, "w", encoding="utf-8") as file_handle:
file_handle.write("\n".join(html_parts))
@staticmethod
def _region_viz_style(label: str) -> Tuple[Tuple[int, int, int, int], int]:
    """Pick the drawing style for a region label.

    Returns an ``(rgba, width)`` pair. Labels listed in the table below get
    their own entry; any other label starting with ``"cell_"`` gets the
    generic cell style, and everything else gets the neutral fallback.

    NOTE(review): "cell_merged" shares its style with "cell_row_header", and
    "cell_single" with "cell_column_header", so those pairs cannot be told
    apart in a rendering. Confirm this is intended.
    """
    styles_by_label = {
        "table": ((220, 38, 38, 255), 3),
        "row": ((34, 197, 94, 220), 2),
        "column": ((37, 99, 235, 220), 2),
        "cell_column_header": ((0, 0, 255, 180), 2),
        "cell_row_header": ((255, 0, 0, 180), 2),
        "cell_section_header": ((255, 0, 255, 180), 2),
        "cell_merged": ((255, 0, 0, 180), 2),
        "cell_single": ((0, 0, 255, 180), 2),
    }
    exact_style = styles_by_label.get(label)
    if exact_style is not None:
        return exact_style
    if label.startswith("cell_"):
        return (249, 115, 22, 180), 1
    return (107, 114, 128, 180), 1
@staticmethod
def _rect_from_bbox_entry(
    box: Dict[str, Any],
) -> Optional[Tuple[float, float, float, float]]:
    """Extract a ``(left, top, right, bottom)`` rectangle from a box payload.

    The ``"ltrb"`` entry is preferred. When it is missing, malformed, or
    describes an empty rectangle, the ``"bbox"`` entry (read as
    ``left, top, width, height``) is tried instead.

    Args:
        box: Box payload; only the ``"ltrb"`` and ``"bbox"`` keys are read.

    Returns:
        The rectangle as four floats, or ``None`` when neither entry yields
        a rectangle with positive width and height.
    """

    def _four_floats(raw: Any) -> Optional[Tuple[float, float, float, float]]:
        # Accept tuples as well as lists: both are plain four-number
        # sequences, and a payload that has not been through JSON/parquet
        # serialization may carry either.
        if not isinstance(raw, (list, tuple)) or len(raw) != 4:
            return None
        try:
            first, second, third, fourth = (float(value) for value in raw)
        except (TypeError, ValueError):
            return None
        return first, second, third, fourth

    corners = _four_floats(box.get("ltrb"))
    if corners is not None:
        left, top, right, bottom = corners
        if right > left and bottom > top:
            return left, top, right, bottom

    # Fall back to the origin-plus-size form.
    origin_and_size = _four_floats(box.get("bbox"))
    if origin_and_size is not None:
        left, top, width, height = origin_and_size
        if width > 0 and height > 0:
            return left, top, left + width, top + height

    return None
def _table_regions_visualization_from_regions(
    self, record: DatasetRecord
) -> Dict[Optional[int], Image.Image]:
    """Draw the record's extracted region boxes onto copies of its page images.

    Returns a mapping from page index to the annotated RGB image. When the
    record has no page images, the mapping holds a single ``None`` key with
    the placeholder from ``get_missing_pageimg()``.

    Raises:
        ValueError: If ``record`` is not a ``DatasetRecordWithBBox``.
    """
    if not isinstance(record, DatasetRecordWithBBox):
        raise ValueError(
            "Region-based visualization requires DatasetRecordWithBBox."
        )

    # Sort key that places table, row and column boxes ahead of the rest.
    layer_rank = {"table": 0, "row": 1, "column": 2}
    boxes_per_page = record.ground_truth_bbox_on_page_images
    rendered_pages: Dict[Optional[int], Image.Image] = {}

    for page_no, source_image in enumerate(record.ground_truth_page_images):
        canvas = source_image.convert("RGB").copy()
        painter = ImageDraw.Draw(canvas, mode="RGBA")

        ordered_boxes = sorted(
            boxes_per_page.get(page_no, []),
            key=lambda item: layer_rank.get(str(item.get("label", "")), 3),
        )
        for item in ordered_boxes:
            label = str(item.get("label", ""))
            # NOTE(review): only labels starting with "cell_" are drawn, so
            # table/row/column boxes are skipped and the row/column fill
            # below can never apply. This looks like a leftover from
            # experimenting with label filters; confirm the intended filter.
            if not label.startswith("cell_"):
                continue

            rect = self._rect_from_bbox_entry(item)
            if rect is None:
                continue

            color, width = self._region_viz_style(label)
            fill_alpha = 35 if label in ("row", "column") else 0
            fill = (color[0], color[1], color[2], fill_alpha) if fill_alpha > 0 else None
            painter.rectangle(
                [(rect[0], rect[1]), (rect[2], rect[3])],
                outline=color,
                fill=fill,
                width=width,
            )

        rendered_pages[page_no] = canvas

    if not rendered_pages:
        rendered_pages[None] = get_missing_pageimg()
    return rendered_pages
def _table_regions_visualization_from_docling(
    self, record: DatasetRecord
) -> Dict[Optional[int], Image.Image]:
    """Render the record's tables with ``TableVisualizer``.

    Works on a deep copy of the ground-truth document, with the record's
    pictures and page images attached through ``insert_images_from_pil``,
    and returns whatever ``TableVisualizer.get_visualization`` produces for
    that copy. Cells, rows and columns are all shown, with minimal row and
    column boxes enabled.
    """
    doc_with_images = insert_images_from_pil(
        document=copy.deepcopy(record.ground_truth_doc),
        pictures=record.ground_truth_pictures,
        page_images=record.ground_truth_page_images,
    )
    viz_params = TableVisualizer.Params(
        show_cells=True,
        show_rows=True,
        show_cols=True,
        minimal_row_bboxes=True,
        minimal_col_bboxes=True,
    )
    return TableVisualizer(params=viz_params).get_visualization(
        doc=doc_with_images
    )
@staticmethod
def _count_row_column_regions(
    bboxes_by_page: Dict[int, List[Dict[str, Any]]],
) -> Tuple[int, int]:
    """Count the boxes labelled "row" and "column" across all pages.

    Returns:
        ``(row_count, col_count)``. Boxes without a ``"label"`` key, or with
        any other label, are not counted.
    """
    labels = [
        str(box.get("label", ""))
        for page_boxes in bboxes_by_page.values()
        for box in page_boxes
    ]
    return labels.count("row"), labels.count("column")
@staticmethod
def _page_images_within_limits(page_images: List[Image.Image]) -> bool:
    """Check that every page image fits the configured size window.

    Returns ``True`` only when both the width and the height of each image
    lie between ``_MIN_PAGE_IMAGE_DIM`` and ``_MAX_PAGE_IMAGE_DIM``
    (inclusive). An empty list yields ``True``.
    """

    def _side_ok(side) -> bool:
        return _MIN_PAGE_IMAGE_DIM <= int(side) <= _MAX_PAGE_IMAGE_DIM

    return all(
        _side_ok(width) and _side_ok(height)
        for width, height in (image.size for image in page_images)
    )
def save_ground_truth_visualization(
self,
record: DatasetRecord,
@@ -1174,26 +1461,25 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
super().save_ground_truth_visualization(record, viz_path_split)
return
# Use TableVisualizer output for table_regions instead of layout visualizer.
table_doc = insert_images_from_pil(
document=copy.deepcopy(record.ground_truth_doc),
pictures=record.ground_truth_pictures,
page_images=record.ground_truth_page_images,
)
table_visualizer = TableVisualizer(
params=TableVisualizer.Params(
show_cells=True,
show_rows=False,
show_cols=False,
minimal_row_bboxes=False,
minimal_col_bboxes=False,
)
)
# For table_regions, allow visualization from either extracted regions
# or the Docling document table visualizer.
try:
page_visualizations = table_visualizer.get_visualization(doc=table_doc)
if (
self.table_regions_visualization_source
== _TABLE_REGIONS_VIZ_SOURCE_REGIONS
):
page_visualizations = self._table_regions_visualization_from_regions(
record
)
else:
page_visualizations = self._table_regions_visualization_from_docling(
record
)
except Exception as exc: # noqa: BLE001
_log.warning(
"TableVisualizer failed for %s: %s. Falling back to default visualization.",
"Table regions visualization (%s) failed for %s: %s. "
"Falling back to default visualization.",
self.table_regions_visualization_source,
record.doc_id,
exc,
)
@@ -1231,6 +1517,10 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
skipped_png_errors = 0
skipped_malformed = 0
skipped_no_table = 0
skipped_filtered = 0
skipped_row_overlap = 0
skipped_col_overlap = 0
filter_reason_counts: Dict[str, int] = defaultdict(int)
malformed_reason_counts: Dict[str, int] = defaultdict(int)
for json_path in tqdm(
@@ -1343,6 +1633,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
malformed_reason,
)
skipped_malformed += 1
overlap_region_type = self._overlap_region_type_from_reason(
malformed_reason
)
if overlap_region_type == "row":
skipped_row_overlap += 1
elif overlap_region_type == "column":
skipped_col_overlap += 1
if malformed_reason == "Document has no table items.":
skipped_no_table += 1
malformed_reason_counts[malformed_reason] = (
@@ -1418,6 +1715,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
fallback_reason,
)
skipped_malformed += 1
overlap_region_type = self._overlap_region_type_from_reason(
fallback_reason
)
if overlap_region_type == "row":
skipped_row_overlap += 1
elif overlap_region_type == "column":
skipped_col_overlap += 1
malformed_reason_counts[f"fallback::{fallback_reason}"] = (
malformed_reason_counts.get(
f"fallback::{fallback_reason}", 0
@@ -1433,6 +1737,16 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
)
if has_rotated_90:
if _SKIP_ROTATED_90_TABLES:
skipped_filtered += 1
filter_reason_counts["rotated_90_table"] = (
filter_reason_counts.get("rotated_90_table", 0) + 1
)
_log.warning(
"Skipping table sample %s: 90-degree rotated table detected.",
doc_id,
)
continue
tags.append("90_degree")
else:
ground_truth_bboxes = self._extract_top_level_bboxes(
@@ -1441,6 +1755,52 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
page_images=extracted_page_images,
)
if self.modality == EvaluationModality.TABLE_REGIONS:
row_count, col_count = self._count_row_column_regions(
ground_truth_bboxes
)
if not (_MIN_GRID_ROWS <= row_count <= _MAX_GRID_ROWS):
skipped_filtered += 1
filter_reason_counts[f"rows_out_of_range::{row_count}"] = (
filter_reason_counts.get(f"rows_out_of_range::{row_count}", 0)
+ 1
)
_log.warning(
"Skipping table sample %s: row region count %d is outside [%d, %d].",
doc_id,
row_count,
_MIN_GRID_ROWS,
_MAX_GRID_ROWS,
)
continue
if not (_MIN_GRID_COLS <= col_count <= _MAX_GRID_COLS):
skipped_filtered += 1
filter_reason_counts[f"cols_out_of_range::{col_count}"] = (
filter_reason_counts.get(f"cols_out_of_range::{col_count}", 0)
+ 1
)
_log.warning(
"Skipping table sample %s: column region count %d is outside [%d, %d].",
doc_id,
col_count,
_MIN_GRID_COLS,
_MAX_GRID_COLS,
)
continue
if not self._page_images_within_limits(extracted_page_images):
skipped_filtered += 1
filter_reason_counts["page_image_size_out_of_range"] = (
filter_reason_counts.get("page_image_size_out_of_range", 0) + 1
)
_log.warning(
"Skipping sample %s: page image size outside [%d, %d] px.",
doc_id,
_MIN_PAGE_IMAGE_DIM,
_MAX_PAGE_IMAGE_DIM,
)
continue
if len(png_files) == 1:
original_bytes = get_binary(png_files[0])
original_stream = DocumentStream(
@@ -1476,7 +1836,7 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
(
"DoclingSDG processing summary (modality=%s): "
"processed=%d, exported=%d, skipped=%d "
"(load_errors=%d, missing_png=%d, png_read_errors=%d, malformed=%d, no_table=%d)"
"(load_errors=%d, missing_png=%d, png_read_errors=%d, malformed=%d, no_table=%d, filtered=%d)"
),
self.modality.value,
processed_docs,
@@ -1487,9 +1847,22 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
skipped_png_errors,
skipped_malformed,
skipped_no_table,
skipped_filtered,
)
if malformed_reason_counts:
_log.info(
"DoclingSDG malformed reason counts: %s",
dict(sorted(malformed_reason_counts.items())),
)
if filter_reason_counts:
_log.info(
"DoclingSDG filtered reason counts: %s",
dict(sorted(filter_reason_counts.items())),
)
skipped_region_overlap = skipped_row_overlap + skipped_col_overlap
_log.info(
"DoclingSDG row/column-overlap skips: total=%d (row=%d, column=%d)",
skipped_region_overlap,
skipped_row_overlap,
skipped_col_overlap,
)
+231 -21
View File
@@ -9,7 +9,10 @@ from PIL import Image
from docling_eval.datamodels.dataset_record import DatasetRecordWithBBox
from docling_eval.datamodels.types import BenchMarkNames
from docling_eval.dataset_builders.doclingsdg_builder import DoclingSDGDatasetBuilder
from docling_eval.dataset_builders.doclingsdg_builder import (
_TABLE_REGION_EXPORT_LABELS,
DoclingSDGDatasetBuilder,
)
def _copy_json_png_pair(source_json: Path, target_dir: Path) -> None:
@@ -183,35 +186,48 @@ def test_doclingsdg_builder_table_regions_bbox_labels(tmp_path: Path):
restored = DatasetRecordWithBBox.model_validate(row)
labels = {box["label"] for box in restored.ground_truth_bbox_on_page_images[0]}
assert labels.issubset(set(_TABLE_REGION_EXPORT_LABELS))
assert "table" in labels
assert "row" in labels
assert "column" in labels
assert "cell_single" in labels
assert "cell_merged" in labels
if "cell_merged" in _TABLE_REGION_EXPORT_LABELS:
assert "cell_merged" in labels
def test_doclingsdg_builder_table_regions_adds_90_degree_tag(tmp_path: Path):
dataset_source = tmp_path / "doclingsdg_rot_source"
target = tmp_path / "doclingsdg_rot_target"
dataset_source.mkdir(parents=True)
rotated_source = Path("EXAMPLE_DOCLING_SDG_TABLES/rotated_90_deg")
sample_json = sorted(rotated_source.glob("*.json"))[0]
_copy_json_png_pair(sample_json, dataset_source)
rotated_jsons = sorted(rotated_source.glob("*.json"))
assert rotated_jsons
builder = DoclingSDGDatasetBuilder(
dataset_source=dataset_source,
target=target,
modality="table_regions",
)
builder.save_to_disk(chunk_size=4)
for idx, sample_json in enumerate(rotated_jsons):
dataset_source = tmp_path / f"doclingsdg_rot_source_{idx}"
target = tmp_path / f"doclingsdg_rot_target_{idx}"
dataset_source.mkdir(parents=True)
ds = load_dataset(
"parquet",
data_files={"test": str(target / "test" / "*.parquet")},
)
assert len(ds["test"]) == 1
assert "90_degree" in ds["test"][0]["tags"]
try:
_copy_json_png_pair(sample_json, dataset_source)
except AssertionError:
continue
builder = DoclingSDGDatasetBuilder(
dataset_source=dataset_source,
target=target,
modality="table_regions",
)
builder.save_to_disk(chunk_size=4)
parquet_files = list((target / "test").glob("*.parquet"))
if not parquet_files:
continue
ds = load_dataset(
"parquet",
data_files={"test": str(target / "test" / "*.parquet")},
)
if len(ds["test"]) == 1 and "90_degree" in ds["test"][0]["tags"]:
return
pytest.skip("No rotated_90_deg fixture passed current table-region export filters.")
def test_doclingsdg_builder_table_regions_skips_malformed_sample(tmp_path: Path):
@@ -276,6 +292,110 @@ def test_doclingsdg_builder_table_regions_skips_out_of_bounds_table_bbox(
assert not list((target / "test").glob("*.parquet"))
def test_doclingsdg_builder_table_regions_skips_row_region_high_iou_overlap(
    tmp_path: Path,
):
    """A table whose rows 0 and 1 share one vertical band must not be exported.

    Every row-1 cell is given the top/bottom of a row-0 cell, then the
    builder is run and the test checks that no parquet file was written.
    """
    dataset_source = tmp_path / "doclingsdg_row_overlap_source"
    target = tmp_path / "doclingsdg_row_overlap_target"
    dataset_source.mkdir(parents=True)

    # Pin a specific fixture instead of taking the first glob hit. Glob order
    # is filesystem-dependent, and the directory also holds a sample that the
    # column-count filter rejects on its own, which would let this test pass
    # without ever reaching the overlap check.
    sample_json = Path(
        "tests/data/test_doclingsdg_docs/"
        "data_none__seed_teds_0.940_table_table_dataset_20260310_"
        "tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
    )
    assert sample_json.exists()
    sample_png = sample_json.with_suffix(".png")
    assert sample_png.exists()
    shutil.copy2(sample_png, dataset_source / sample_png.name)

    with sample_json.open("r", encoding="utf-8") as file_handle:
        payload = json.load(file_handle)

    cells = payload["tables"][0]["data"]["table_cells"]
    row0_cell = next(
        (
            c
            for c in cells
            if int(c.get("start_row_offset_idx", -1)) == 0
            and c.get("bbox") is not None
        ),
        None,
    )
    assert row0_cell is not None
    ref_t = float(row0_cell["bbox"]["t"])
    ref_b = float(row0_cell["bbox"]["b"])

    row1_cells = [
        c
        for c in cells
        if int(c.get("start_row_offset_idx", -1)) == 1 and c.get("bbox")
    ]
    assert row1_cells
    # Collapse row 1 onto row 0's vertical extent to force the overlap.
    for cell in row1_cells:
        cell["bbox"]["t"] = ref_t
        cell["bbox"]["b"] = ref_b

    broken_json = dataset_source / sample_json.name
    with broken_json.open("w", encoding="utf-8") as file_handle:
        json.dump(payload, file_handle)

    builder = DoclingSDGDatasetBuilder(
        dataset_source=dataset_source,
        target=target,
        modality="table_regions",
    )
    builder.save_to_disk(chunk_size=4)

    assert not list((target / "test").glob("*.parquet"))
def test_doclingsdg_builder_table_regions_skips_column_region_high_iou_overlap(
    tmp_path: Path,
):
    """A table whose columns 0 and 1 share one horizontal band must not be exported.

    Every column-1 cell is given the left/right of a column-0 cell, then the
    builder is run and the test checks that no parquet file was written.
    """
    dataset_source = tmp_path / "doclingsdg_col_overlap_source"
    target = tmp_path / "doclingsdg_col_overlap_target"
    dataset_source.mkdir(parents=True)

    # Pin a specific fixture instead of taking the first glob hit. Glob order
    # is filesystem-dependent, and the directory also holds a sample that the
    # column-count filter rejects on its own, which would let this test pass
    # without ever reaching the overlap check.
    sample_json = Path(
        "tests/data/test_doclingsdg_docs/"
        "data_none__seed_teds_0.940_table_table_dataset_20260310_"
        "tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
    )
    assert sample_json.exists()
    sample_png = sample_json.with_suffix(".png")
    assert sample_png.exists()
    shutil.copy2(sample_png, dataset_source / sample_png.name)

    with sample_json.open("r", encoding="utf-8") as file_handle:
        payload = json.load(file_handle)

    cells = payload["tables"][0]["data"]["table_cells"]
    col0_cell = next(
        (
            c
            for c in cells
            if int(c.get("start_col_offset_idx", -1)) == 0
            and c.get("bbox") is not None
        ),
        None,
    )
    assert col0_cell is not None
    ref_l = float(col0_cell["bbox"]["l"])
    ref_r = float(col0_cell["bbox"]["r"])

    col1_cells = [
        c
        for c in cells
        if int(c.get("start_col_offset_idx", -1)) == 1 and c.get("bbox")
    ]
    assert col1_cells
    # Collapse column 1 onto column 0's horizontal extent to force the overlap.
    for cell in col1_cells:
        cell["bbox"]["l"] = ref_l
        cell["bbox"]["r"] = ref_r

    broken_json = dataset_source / sample_json.name
    with broken_json.open("w", encoding="utf-8") as file_handle:
        json.dump(payload, file_handle)

    builder = DoclingSDGDatasetBuilder(
        dataset_source=dataset_source,
        target=target,
        modality="table_regions",
    )
    builder.save_to_disk(chunk_size=4)

    assert not list((target / "test").glob("*.parquet"))
def test_doclingsdg_builder_table_regions_skips_duplicate_left_on_same_row_band(
tmp_path: Path,
):
@@ -332,3 +452,93 @@ def test_doclingsdg_builder_table_regions_uses_table_visualizer_for_viz(
assert viz_file.exists()
content = viz_file.read_text(encoding="utf-8")
assert "Table Regions Visualization" in content
def test_doclingsdg_builder_table_regions_uses_regions_for_viz(tmp_path: Path):
    """The "regions" visualization source writes the table-regions HTML page."""
    fixture_json = Path(
        "tests/data/test_doclingsdg_docs/"
        "data_none__seed_teds_0.940_table_table_dataset_20260310_"
        "tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
    )
    source_dir = tmp_path / "doclingsdg_viz_regions_source"
    output_dir = tmp_path / "doclingsdg_viz_regions_target"
    source_dir.mkdir(parents=True)
    _copy_json_png_pair(fixture_json, source_dir)

    DoclingSDGDatasetBuilder(
        dataset_source=source_dir,
        target=output_dir,
        modality="table_regions",
        table_regions_visualization_source="regions",
    ).save_to_disk(chunk_size=4, do_visualization=True)

    html_path = output_dir / "visualizations" / f"{fixture_json.stem}_layout.html"
    assert html_path.exists()
    html_text = html_path.read_text(encoding="utf-8")
    assert "Table Regions Visualization" in html_text
def test_doclingsdg_builder_table_regions_filters_col_count_out_of_range(
    tmp_path: Path,
):
    """Of two input samples, only the one within the column limit is exported."""
    fixtures_root = "tests/data/test_doclingsdg_docs/"
    valid_sample = Path(
        fixtures_root
        + "data_none__seed_teds_0.940_table_table_dataset_20260310_"
        + "tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
    )
    out_of_range_col_sample = Path(
        fixtures_root
        + "data_none__seed_teds_0.944_table_table_dataset_20260310_"
        + "tight_margin_wiki_tables_otsl_en_margin_doc53881_t000_row4_col34__0040.json"
    )

    source_dir = tmp_path / "doclingsdg_col_filter_source"
    output_dir = tmp_path / "doclingsdg_col_filter_target"
    source_dir.mkdir(parents=True)
    for fixture in (valid_sample, out_of_range_col_sample):
        _copy_json_png_pair(fixture, source_dir)

    DoclingSDGDatasetBuilder(
        dataset_source=source_dir,
        target=output_dir,
        modality="table_regions",
    ).save_to_disk(chunk_size=4)

    exported = load_dataset(
        "parquet",
        data_files={"test": str(output_dir / "test" / "*.parquet")},
    )["test"]
    assert len(exported) == 1
    assert exported[0]["document_id"] == valid_sample.stem
def test_doclingsdg_builder_filters_image_dimensions_out_of_range(tmp_path: Path):
    """A sample paired with a 4097x4097 page image produces no parquet output."""
    fixture_json = Path(
        "tests/data/test_doclingsdg_docs/"
        "data_none__seed_teds_0.940_table_table_dataset_20260310_"
        "tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
    )
    source_dir = tmp_path / "doclingsdg_img_filter_source"
    output_dir = tmp_path / "doclingsdg_img_filter_target"
    source_dir.mkdir(parents=True)

    shutil.copy2(fixture_json, source_dir / fixture_json.name)
    # Keep geometry in bounds while forcing the max-size filter to trigger.
    oversized_page = Image.new("RGB", (4097, 4097), "white")
    oversized_page.save(source_dir / f"{fixture_json.stem}.png")

    DoclingSDGDatasetBuilder(
        dataset_source=source_dir,
        target=output_dir,
        modality="table_regions",
    ).save_to_disk(chunk_size=4)

    written = list((output_dir / "test").glob("*.parquet"))
    assert not written