mirror of
https://github.com/docling-project/docling-eval.git
synced 2026-05-17 13:10:47 +00:00
feat: Improved region extraction and filtering for doclingsdg_builder.py (#217)
* Improved region extraction and filtering for doclingsdg_builder.py Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Added table region selector by label in doclingsdg_builder Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Added skipping switch for 90 degree tables Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> --------- Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from docling_core.types.doc import (
|
||||
TableItem,
|
||||
)
|
||||
from docling_core.types.io import DocumentStream
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageDraw
|
||||
from pydantic import ValidationError
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -45,6 +45,14 @@ _BBOX_OVERLAP_EPSILON = 1e-3
|
||||
_ROTATION_ASPECT_THRESHOLD = 1.15
|
||||
_COORD_TOLERANCE = 1.0
|
||||
_ROW_LEFT_DUPLICATE_TOLERANCE = 0.5
|
||||
_MIN_GRID_ROWS = 1
|
||||
_MAX_GRID_ROWS = 40
|
||||
_MIN_GRID_COLS = 1
|
||||
_MAX_GRID_COLS = 20
|
||||
_MIN_PAGE_IMAGE_DIM = 32
|
||||
_MAX_PAGE_IMAGE_DIM = 4096
|
||||
_REGION_OVERLAP_IOU_THRESHOLD = 0.99
|
||||
_SKIP_ROTATED_90_TABLES = True
|
||||
|
||||
_TABLE_REGION_CATEGORY_IDS: Dict[str, int] = {
|
||||
"table": 1,
|
||||
@@ -57,6 +65,21 @@ _TABLE_REGION_CATEGORY_IDS: Dict[str, int] = {
|
||||
"cell_merged": 8,
|
||||
}
|
||||
|
||||
_TABLE_REGION_EXPORT_LABELS: Tuple[str, ...] = (
|
||||
"table",
|
||||
"row",
|
||||
"column",
|
||||
"cell_merged",
|
||||
)
|
||||
_TABLE_REGION_EXPORT_LABELS_SET = set(_TABLE_REGION_EXPORT_LABELS)
|
||||
|
||||
_TABLE_REGIONS_VIZ_SOURCE_DOCLING = "docling"
|
||||
_TABLE_REGIONS_VIZ_SOURCE_REGIONS = "regions"
|
||||
_TABLE_REGIONS_VIZ_SOURCES = {
|
||||
_TABLE_REGIONS_VIZ_SOURCE_DOCLING,
|
||||
_TABLE_REGIONS_VIZ_SOURCE_REGIONS,
|
||||
}
|
||||
|
||||
|
||||
class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
"""
|
||||
@@ -76,6 +99,7 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
split: str = "test",
|
||||
begin_index: int = 0,
|
||||
end_index: int = -1,
|
||||
table_regions_visualization_source: str = _TABLE_REGIONS_VIZ_SOURCE_REGIONS,
|
||||
):
|
||||
try:
|
||||
parsed_modality = EvaluationModality(modality)
|
||||
@@ -109,6 +133,31 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
self.modality = parsed_modality
|
||||
self.must_retrieve = False
|
||||
|
||||
if len(_TABLE_REGION_EXPORT_LABELS_SET) != len(_TABLE_REGION_EXPORT_LABELS):
|
||||
raise ValueError(
|
||||
"Duplicate labels detected in _TABLE_REGION_EXPORT_LABELS."
|
||||
)
|
||||
unknown_export_labels = sorted(
|
||||
label
|
||||
for label in _TABLE_REGION_EXPORT_LABELS
|
||||
if label not in _TABLE_REGION_CATEGORY_IDS
|
||||
)
|
||||
if unknown_export_labels:
|
||||
raise ValueError(
|
||||
"Unknown labels in _TABLE_REGION_EXPORT_LABELS: "
|
||||
f"{unknown_export_labels}. Expected subset of "
|
||||
f"{sorted(_TABLE_REGION_CATEGORY_IDS.keys())}"
|
||||
)
|
||||
|
||||
source = str(table_regions_visualization_source).strip().lower()
|
||||
if source not in _TABLE_REGIONS_VIZ_SOURCES:
|
||||
raise ValueError(
|
||||
"Unsupported table regions visualization source "
|
||||
f"'{table_regions_visualization_source}'. Expected one of: "
|
||||
f"{sorted(_TABLE_REGIONS_VIZ_SOURCES)}"
|
||||
)
|
||||
self.table_regions_visualization_source = source
|
||||
|
||||
@staticmethod
|
||||
def _sort_by_page_suffix(path: Path) -> tuple[int, str]:
|
||||
match = _PAGE_SUFFIX_PATTERN.search(path.stem)
|
||||
@@ -321,6 +370,10 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
def _table_region_category_id(label: str) -> int:
|
||||
return _TABLE_REGION_CATEGORY_IDS[label]
|
||||
|
||||
@staticmethod
|
||||
def _should_export_table_region_label(label: str) -> bool:
|
||||
return label in _TABLE_REGION_EXPORT_LABELS_SET
|
||||
|
||||
@staticmethod
|
||||
def _safe_page_height(page: Optional[PageItem], page_image: Image.Image) -> float:
|
||||
if (
|
||||
@@ -422,6 +475,55 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
inter_h = min(rect_a[3], rect_b[3]) - max(rect_a[1], rect_b[1])
|
||||
return inter_w > _BBOX_OVERLAP_EPSILON and inter_h > _BBOX_OVERLAP_EPSILON
|
||||
|
||||
@staticmethod
|
||||
def _rect_iou(
|
||||
rect_a: Tuple[float, float, float, float],
|
||||
rect_b: Tuple[float, float, float, float],
|
||||
) -> float:
|
||||
inter_w = min(rect_a[2], rect_b[2]) - max(rect_a[0], rect_b[0])
|
||||
inter_h = min(rect_a[3], rect_b[3]) - max(rect_a[1], rect_b[1])
|
||||
if inter_w <= 0.0 or inter_h <= 0.0:
|
||||
return 0.0
|
||||
|
||||
inter_area = inter_w * inter_h
|
||||
area_a = max(0.0, rect_a[2] - rect_a[0]) * max(0.0, rect_a[3] - rect_a[1])
|
||||
area_b = max(0.0, rect_b[2] - rect_b[0]) * max(0.0, rect_b[3] - rect_b[1])
|
||||
union = area_a + area_b - inter_area
|
||||
if union <= 0.0:
|
||||
return 0.0
|
||||
return inter_area / union
|
||||
|
||||
@classmethod
|
||||
def _find_region_overlap_iou_issue(
|
||||
cls,
|
||||
*,
|
||||
region_name: str,
|
||||
regions: List[Tuple[int, Tuple[float, float, float, float]]],
|
||||
threshold: float = _REGION_OVERLAP_IOU_THRESHOLD,
|
||||
) -> Optional[str]:
|
||||
for i in range(len(regions)):
|
||||
region_id_a, rect_a = regions[i]
|
||||
for j in range(i + 1, len(regions)):
|
||||
region_id_b, rect_b = regions[j]
|
||||
iou = cls._rect_iou(rect_a, rect_b)
|
||||
if iou > threshold:
|
||||
return (
|
||||
f"Overlapping {region_name} regions with IoU>{threshold:.2f} "
|
||||
f"(ids={region_id_a},{region_id_b}, iou={iou:.4f})."
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _overlap_region_type_from_reason(reason: str) -> Optional[str]:
|
||||
normalized = reason
|
||||
if normalized.startswith("fallback::"):
|
||||
normalized = normalized[len("fallback::") :]
|
||||
if normalized.startswith("Overlapping row regions with IoU>"):
|
||||
return "row"
|
||||
if normalized.startswith("Overlapping column regions with IoU>"):
|
||||
return "column"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_table_rotated_90(
|
||||
row_rects: List[Tuple[float, float, float, float]],
|
||||
@@ -453,12 +555,12 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
row_span: int,
|
||||
col_span: int,
|
||||
) -> str:
|
||||
if is_column_header:
|
||||
return "cell_column_header"
|
||||
if is_row_header:
|
||||
return "cell_row_header"
|
||||
if is_row_section:
|
||||
return "cell_section_header"
|
||||
# if is_column_header:
|
||||
# return "cell_column_header"
|
||||
# if is_row_header:
|
||||
# return "cell_row_header"
|
||||
# if is_row_section:
|
||||
# return "cell_section_header"
|
||||
if row_span > 1 or col_span > 1:
|
||||
return "cell_merged"
|
||||
return "cell_single"
|
||||
@@ -921,12 +1023,6 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
if not table_cells:
|
||||
return {}, has_rotated_90, "Table has no cells."
|
||||
|
||||
row_rects_map: Dict[int, List[Tuple[float, float, float, float]]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
col_rects_map: Dict[int, List[Tuple[float, float, float, float]]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
occupancy: Dict[Tuple[int, int], int] = {}
|
||||
cell_entries: List[Dict[str, Any]] = []
|
||||
|
||||
@@ -967,11 +1063,6 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
"Invalid cell bbox (out of image bounds).",
|
||||
)
|
||||
|
||||
for row_id in range(row_start, row_end):
|
||||
row_rects_map[row_id].append(cell_rect)
|
||||
for col_id in range(col_start, col_end):
|
||||
col_rects_map[col_id].append(cell_rect)
|
||||
|
||||
for row_id in range(row_start, row_end):
|
||||
for col_id in range(col_start, col_end):
|
||||
slot = (row_id, col_id)
|
||||
@@ -1007,9 +1098,61 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
}
|
||||
)
|
||||
|
||||
if not row_rects_map or not col_rects_map:
|
||||
try:
|
||||
row_bboxes = table_data.get_row_bounding_boxes(minimal=False)
|
||||
col_bboxes = table_data.get_column_bounding_boxes(minimal=False)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return (
|
||||
{},
|
||||
has_rotated_90,
|
||||
f"Failed to compute row/column regions from table data: {exc}",
|
||||
)
|
||||
|
||||
row_regions: List[Tuple[int, Tuple[float, float, float, float]]] = []
|
||||
for row_id, row_bbox in sorted(row_bboxes.items()):
|
||||
row_rect = self._normalize_bbox_to_page_image(
|
||||
row_bbox,
|
||||
page_height=page_height,
|
||||
page_image=page_image,
|
||||
reject_out_of_bounds=True,
|
||||
)
|
||||
if row_rect is None:
|
||||
return {}, has_rotated_90, "Invalid row bbox (out of image bounds)."
|
||||
row_regions.append((int(row_id), row_rect))
|
||||
|
||||
col_regions: List[Tuple[int, Tuple[float, float, float, float]]] = []
|
||||
for col_id, col_bbox in sorted(col_bboxes.items()):
|
||||
col_rect = self._normalize_bbox_to_page_image(
|
||||
col_bbox,
|
||||
page_height=page_height,
|
||||
page_image=page_image,
|
||||
reject_out_of_bounds=True,
|
||||
)
|
||||
if col_rect is None:
|
||||
return (
|
||||
{},
|
||||
has_rotated_90,
|
||||
"Invalid column bbox (out of image bounds).",
|
||||
)
|
||||
col_regions.append((int(col_id), col_rect))
|
||||
|
||||
if not row_regions or not col_regions:
|
||||
return {}, has_rotated_90, "Missing row/column regions."
|
||||
|
||||
row_overlap_issue = self._find_region_overlap_iou_issue(
|
||||
region_name="row",
|
||||
regions=row_regions,
|
||||
)
|
||||
if row_overlap_issue is not None:
|
||||
return {}, has_rotated_90, row_overlap_issue
|
||||
|
||||
col_overlap_issue = self._find_region_overlap_iou_issue(
|
||||
region_name="column",
|
||||
regions=col_regions,
|
||||
)
|
||||
if col_overlap_issue is not None:
|
||||
return {}, has_rotated_90, col_overlap_issue
|
||||
|
||||
active_cells: List[Dict[str, Any]] = []
|
||||
for entry in sorted(cell_entries, key=lambda e: e["rect"][0]):
|
||||
left = float(entry["rect"][0])
|
||||
@@ -1049,39 +1192,40 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
active_cells.append(entry)
|
||||
|
||||
page_boxes = bboxes_by_page.setdefault(page_index, [])
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="table",
|
||||
category_id=self._table_region_category_id("table"),
|
||||
rect=table_rect,
|
||||
if self._should_export_table_region_label("table"):
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="table",
|
||||
category_id=self._table_region_category_id("table"),
|
||||
rect=table_rect,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
row_rects: List[Tuple[float, float, float, float]] = []
|
||||
for row_id in sorted(row_rects_map):
|
||||
row_rect = self._union_rectangles(row_rects_map[row_id])
|
||||
for row_id, row_rect in row_regions:
|
||||
row_rects.append(row_rect)
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="row",
|
||||
category_id=self._table_region_category_id("row"),
|
||||
rect=row_rect,
|
||||
extras={"row_id": row_id},
|
||||
if self._should_export_table_region_label("row"):
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="row",
|
||||
category_id=self._table_region_category_id("row"),
|
||||
rect=row_rect,
|
||||
extras={"row_id": row_id},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
col_rects: List[Tuple[float, float, float, float]] = []
|
||||
for col_id in sorted(col_rects_map):
|
||||
col_rect = self._union_rectangles(col_rects_map[col_id])
|
||||
for col_id, col_rect in col_regions:
|
||||
col_rects.append(col_rect)
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="column",
|
||||
category_id=self._table_region_category_id("column"),
|
||||
rect=col_rect,
|
||||
extras={"col_id": col_id},
|
||||
if self._should_export_table_region_label("column"):
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label="column",
|
||||
category_id=self._table_region_category_id("column"),
|
||||
rect=col_rect,
|
||||
extras={"col_id": col_id},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self._is_table_rotated_90(row_rects=row_rects, col_rects=col_rects):
|
||||
has_rotated_90 = True
|
||||
@@ -1091,10 +1235,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
return {}, has_rotated_90, row_left_issue
|
||||
|
||||
for entry in cell_entries:
|
||||
cell_label = str(entry["label"])
|
||||
if not self._should_export_table_region_label(cell_label):
|
||||
continue
|
||||
page_boxes.append(
|
||||
self._bbox_payload(
|
||||
label=str(entry["label"]),
|
||||
category_id=self._table_region_category_id(str(entry["label"])),
|
||||
label=cell_label,
|
||||
category_id=self._table_region_category_id(cell_label),
|
||||
rect=entry["rect"],
|
||||
extras={
|
||||
"text": entry["text"] if entry["text"] is not None else "",
|
||||
@@ -1165,6 +1312,146 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
with open(filename, "w", encoding="utf-8") as file_handle:
|
||||
file_handle.write("\n".join(html_parts))
|
||||
|
||||
@staticmethod
|
||||
def _region_viz_style(label: str) -> Tuple[Tuple[int, int, int, int], int]:
|
||||
if label == "table":
|
||||
return (220, 38, 38, 255), 3
|
||||
if label == "row":
|
||||
return (34, 197, 94, 220), 2
|
||||
if label == "column":
|
||||
return (37, 99, 235, 220), 2
|
||||
if label == "cell_column_header":
|
||||
return (0, 0, 255, 180), 2
|
||||
if label == "cell_row_header":
|
||||
return (255, 0, 0, 180), 2
|
||||
if label == "cell_section_header":
|
||||
return (255, 0, 255, 180), 2
|
||||
if label == "cell_merged":
|
||||
return (255, 0, 0, 180), 2
|
||||
if label == "cell_single":
|
||||
return (0, 0, 255, 180), 2
|
||||
elif label.startswith("cell_"):
|
||||
return (249, 115, 22, 180), 1
|
||||
return (107, 114, 128, 180), 1
|
||||
|
||||
@staticmethod
|
||||
def _rect_from_bbox_entry(
|
||||
box: Dict[str, Any],
|
||||
) -> Optional[Tuple[float, float, float, float]]:
|
||||
ltrb_raw = box.get("ltrb")
|
||||
if isinstance(ltrb_raw, list) and len(ltrb_raw) == 4:
|
||||
try:
|
||||
left, top, right, bottom = [float(v) for v in ltrb_raw]
|
||||
if right > left and bottom > top:
|
||||
return left, top, right, bottom
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
bbox_raw = box.get("bbox")
|
||||
if isinstance(bbox_raw, list) and len(bbox_raw) == 4:
|
||||
try:
|
||||
left, top, width, height = [float(v) for v in bbox_raw]
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
if width > 0 and height > 0:
|
||||
return left, top, right, bottom
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _table_regions_visualization_from_regions(
|
||||
self, record: DatasetRecord
|
||||
) -> Dict[Optional[int], Image.Image]:
|
||||
if not isinstance(record, DatasetRecordWithBBox):
|
||||
raise ValueError(
|
||||
"Region-based visualization requires DatasetRecordWithBBox."
|
||||
)
|
||||
|
||||
page_visualizations: Dict[Optional[int], Image.Image] = {}
|
||||
boxes_by_page = record.ground_truth_bbox_on_page_images
|
||||
|
||||
draw_order = {"table": 0, "row": 1, "column": 2}
|
||||
for page_index, page_image in enumerate(record.ground_truth_page_images):
|
||||
page_canvas = page_image.convert("RGB").copy()
|
||||
draw = ImageDraw.Draw(page_canvas, mode="RGBA")
|
||||
|
||||
page_boxes = list(boxes_by_page.get(page_index, []))
|
||||
page_boxes.sort(
|
||||
key=lambda box: draw_order.get(str(box.get("label", "")), 3)
|
||||
)
|
||||
|
||||
for box in page_boxes:
|
||||
label = str(box.get("label", ""))
|
||||
# if label not in ["row", "column"]:
|
||||
# if label == "cell_section_header":
|
||||
# if label == "column":
|
||||
if label.startswith("cell_"):
|
||||
rect = self._rect_from_bbox_entry(box)
|
||||
if rect is None:
|
||||
continue
|
||||
color, width = self._region_viz_style(label)
|
||||
fill_alpha = 35 if label in ("row", "column") else 0
|
||||
fill = (color[0], color[1], color[2], fill_alpha)
|
||||
draw.rectangle(
|
||||
[(rect[0], rect[1]), (rect[2], rect[3])],
|
||||
outline=color,
|
||||
fill=fill if fill_alpha > 0 else None,
|
||||
width=width,
|
||||
)
|
||||
|
||||
page_visualizations[page_index] = page_canvas
|
||||
|
||||
if not page_visualizations:
|
||||
page_visualizations[None] = get_missing_pageimg()
|
||||
|
||||
return page_visualizations
|
||||
|
||||
def _table_regions_visualization_from_docling(
|
||||
self, record: DatasetRecord
|
||||
) -> Dict[Optional[int], Image.Image]:
|
||||
table_doc = insert_images_from_pil(
|
||||
document=copy.deepcopy(record.ground_truth_doc),
|
||||
pictures=record.ground_truth_pictures,
|
||||
page_images=record.ground_truth_page_images,
|
||||
)
|
||||
table_visualizer = TableVisualizer(
|
||||
params=TableVisualizer.Params(
|
||||
show_cells=True,
|
||||
show_rows=True,
|
||||
show_cols=True,
|
||||
minimal_row_bboxes=True,
|
||||
minimal_col_bboxes=True,
|
||||
)
|
||||
)
|
||||
return table_visualizer.get_visualization(doc=table_doc)
|
||||
|
||||
@staticmethod
|
||||
def _count_row_column_regions(
|
||||
bboxes_by_page: Dict[int, List[Dict[str, Any]]],
|
||||
) -> Tuple[int, int]:
|
||||
row_count = 0
|
||||
col_count = 0
|
||||
for page_boxes in bboxes_by_page.values():
|
||||
for box in page_boxes:
|
||||
label = str(box.get("label", ""))
|
||||
if label == "row":
|
||||
row_count += 1
|
||||
elif label == "column":
|
||||
col_count += 1
|
||||
return row_count, col_count
|
||||
|
||||
@staticmethod
|
||||
def _page_images_within_limits(page_images: List[Image.Image]) -> bool:
|
||||
for image in page_images:
|
||||
width, height = image.size
|
||||
if not (
|
||||
_MIN_PAGE_IMAGE_DIM <= int(width) <= _MAX_PAGE_IMAGE_DIM
|
||||
and _MIN_PAGE_IMAGE_DIM <= int(height) <= _MAX_PAGE_IMAGE_DIM
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def save_ground_truth_visualization(
|
||||
self,
|
||||
record: DatasetRecord,
|
||||
@@ -1174,26 +1461,25 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
super().save_ground_truth_visualization(record, viz_path_split)
|
||||
return
|
||||
|
||||
# Use TableVisualizer output for table_regions instead of layout visualizer.
|
||||
table_doc = insert_images_from_pil(
|
||||
document=copy.deepcopy(record.ground_truth_doc),
|
||||
pictures=record.ground_truth_pictures,
|
||||
page_images=record.ground_truth_page_images,
|
||||
)
|
||||
table_visualizer = TableVisualizer(
|
||||
params=TableVisualizer.Params(
|
||||
show_cells=True,
|
||||
show_rows=False,
|
||||
show_cols=False,
|
||||
minimal_row_bboxes=False,
|
||||
minimal_col_bboxes=False,
|
||||
)
|
||||
)
|
||||
# For table_regions, allow visualization from either extracted regions
|
||||
# or the Docling document table visualizer.
|
||||
try:
|
||||
page_visualizations = table_visualizer.get_visualization(doc=table_doc)
|
||||
if (
|
||||
self.table_regions_visualization_source
|
||||
== _TABLE_REGIONS_VIZ_SOURCE_REGIONS
|
||||
):
|
||||
page_visualizations = self._table_regions_visualization_from_regions(
|
||||
record
|
||||
)
|
||||
else:
|
||||
page_visualizations = self._table_regions_visualization_from_docling(
|
||||
record
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
_log.warning(
|
||||
"TableVisualizer failed for %s: %s. Falling back to default visualization.",
|
||||
"Table regions visualization (%s) failed for %s: %s. "
|
||||
"Falling back to default visualization.",
|
||||
self.table_regions_visualization_source,
|
||||
record.doc_id,
|
||||
exc,
|
||||
)
|
||||
@@ -1231,6 +1517,10 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
skipped_png_errors = 0
|
||||
skipped_malformed = 0
|
||||
skipped_no_table = 0
|
||||
skipped_filtered = 0
|
||||
skipped_row_overlap = 0
|
||||
skipped_col_overlap = 0
|
||||
filter_reason_counts: Dict[str, int] = defaultdict(int)
|
||||
malformed_reason_counts: Dict[str, int] = defaultdict(int)
|
||||
|
||||
for json_path in tqdm(
|
||||
@@ -1343,6 +1633,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
malformed_reason,
|
||||
)
|
||||
skipped_malformed += 1
|
||||
overlap_region_type = self._overlap_region_type_from_reason(
|
||||
malformed_reason
|
||||
)
|
||||
if overlap_region_type == "row":
|
||||
skipped_row_overlap += 1
|
||||
elif overlap_region_type == "column":
|
||||
skipped_col_overlap += 1
|
||||
if malformed_reason == "Document has no table items.":
|
||||
skipped_no_table += 1
|
||||
malformed_reason_counts[malformed_reason] = (
|
||||
@@ -1418,6 +1715,13 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
fallback_reason,
|
||||
)
|
||||
skipped_malformed += 1
|
||||
overlap_region_type = self._overlap_region_type_from_reason(
|
||||
fallback_reason
|
||||
)
|
||||
if overlap_region_type == "row":
|
||||
skipped_row_overlap += 1
|
||||
elif overlap_region_type == "column":
|
||||
skipped_col_overlap += 1
|
||||
malformed_reason_counts[f"fallback::{fallback_reason}"] = (
|
||||
malformed_reason_counts.get(
|
||||
f"fallback::{fallback_reason}", 0
|
||||
@@ -1433,6 +1737,16 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
)
|
||||
|
||||
if has_rotated_90:
|
||||
if _SKIP_ROTATED_90_TABLES:
|
||||
skipped_filtered += 1
|
||||
filter_reason_counts["rotated_90_table"] = (
|
||||
filter_reason_counts.get("rotated_90_table", 0) + 1
|
||||
)
|
||||
_log.warning(
|
||||
"Skipping table sample %s: 90-degree rotated table detected.",
|
||||
doc_id,
|
||||
)
|
||||
continue
|
||||
tags.append("90_degree")
|
||||
else:
|
||||
ground_truth_bboxes = self._extract_top_level_bboxes(
|
||||
@@ -1441,6 +1755,52 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
page_images=extracted_page_images,
|
||||
)
|
||||
|
||||
if self.modality == EvaluationModality.TABLE_REGIONS:
|
||||
row_count, col_count = self._count_row_column_regions(
|
||||
ground_truth_bboxes
|
||||
)
|
||||
if not (_MIN_GRID_ROWS <= row_count <= _MAX_GRID_ROWS):
|
||||
skipped_filtered += 1
|
||||
filter_reason_counts[f"rows_out_of_range::{row_count}"] = (
|
||||
filter_reason_counts.get(f"rows_out_of_range::{row_count}", 0)
|
||||
+ 1
|
||||
)
|
||||
_log.warning(
|
||||
"Skipping table sample %s: row region count %d is outside [%d, %d].",
|
||||
doc_id,
|
||||
row_count,
|
||||
_MIN_GRID_ROWS,
|
||||
_MAX_GRID_ROWS,
|
||||
)
|
||||
continue
|
||||
if not (_MIN_GRID_COLS <= col_count <= _MAX_GRID_COLS):
|
||||
skipped_filtered += 1
|
||||
filter_reason_counts[f"cols_out_of_range::{col_count}"] = (
|
||||
filter_reason_counts.get(f"cols_out_of_range::{col_count}", 0)
|
||||
+ 1
|
||||
)
|
||||
_log.warning(
|
||||
"Skipping table sample %s: column region count %d is outside [%d, %d].",
|
||||
doc_id,
|
||||
col_count,
|
||||
_MIN_GRID_COLS,
|
||||
_MAX_GRID_COLS,
|
||||
)
|
||||
continue
|
||||
|
||||
if not self._page_images_within_limits(extracted_page_images):
|
||||
skipped_filtered += 1
|
||||
filter_reason_counts["page_image_size_out_of_range"] = (
|
||||
filter_reason_counts.get("page_image_size_out_of_range", 0) + 1
|
||||
)
|
||||
_log.warning(
|
||||
"Skipping sample %s: page image size outside [%d, %d] px.",
|
||||
doc_id,
|
||||
_MIN_PAGE_IMAGE_DIM,
|
||||
_MAX_PAGE_IMAGE_DIM,
|
||||
)
|
||||
continue
|
||||
|
||||
if len(png_files) == 1:
|
||||
original_bytes = get_binary(png_files[0])
|
||||
original_stream = DocumentStream(
|
||||
@@ -1476,7 +1836,7 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
(
|
||||
"DoclingSDG processing summary (modality=%s): "
|
||||
"processed=%d, exported=%d, skipped=%d "
|
||||
"(load_errors=%d, missing_png=%d, png_read_errors=%d, malformed=%d, no_table=%d)"
|
||||
"(load_errors=%d, missing_png=%d, png_read_errors=%d, malformed=%d, no_table=%d, filtered=%d)"
|
||||
),
|
||||
self.modality.value,
|
||||
processed_docs,
|
||||
@@ -1487,9 +1847,22 @@ class DoclingSDGDatasetBuilder(BaseEvaluationDatasetBuilder):
|
||||
skipped_png_errors,
|
||||
skipped_malformed,
|
||||
skipped_no_table,
|
||||
skipped_filtered,
|
||||
)
|
||||
if malformed_reason_counts:
|
||||
_log.info(
|
||||
"DoclingSDG malformed reason counts: %s",
|
||||
dict(sorted(malformed_reason_counts.items())),
|
||||
)
|
||||
if filter_reason_counts:
|
||||
_log.info(
|
||||
"DoclingSDG filtered reason counts: %s",
|
||||
dict(sorted(filter_reason_counts.items())),
|
||||
)
|
||||
skipped_region_overlap = skipped_row_overlap + skipped_col_overlap
|
||||
_log.info(
|
||||
"DoclingSDG row/column-overlap skips: total=%d (row=%d, column=%d)",
|
||||
skipped_region_overlap,
|
||||
skipped_row_overlap,
|
||||
skipped_col_overlap,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,10 @@ from PIL import Image
|
||||
|
||||
from docling_eval.datamodels.dataset_record import DatasetRecordWithBBox
|
||||
from docling_eval.datamodels.types import BenchMarkNames
|
||||
from docling_eval.dataset_builders.doclingsdg_builder import DoclingSDGDatasetBuilder
|
||||
from docling_eval.dataset_builders.doclingsdg_builder import (
|
||||
_TABLE_REGION_EXPORT_LABELS,
|
||||
DoclingSDGDatasetBuilder,
|
||||
)
|
||||
|
||||
|
||||
def _copy_json_png_pair(source_json: Path, target_dir: Path) -> None:
|
||||
@@ -183,35 +186,48 @@ def test_doclingsdg_builder_table_regions_bbox_labels(tmp_path: Path):
|
||||
|
||||
restored = DatasetRecordWithBBox.model_validate(row)
|
||||
labels = {box["label"] for box in restored.ground_truth_bbox_on_page_images[0]}
|
||||
assert labels.issubset(set(_TABLE_REGION_EXPORT_LABELS))
|
||||
assert "table" in labels
|
||||
assert "row" in labels
|
||||
assert "column" in labels
|
||||
assert "cell_single" in labels
|
||||
assert "cell_merged" in labels
|
||||
if "cell_merged" in _TABLE_REGION_EXPORT_LABELS:
|
||||
assert "cell_merged" in labels
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_adds_90_degree_tag(tmp_path: Path):
|
||||
dataset_source = tmp_path / "doclingsdg_rot_source"
|
||||
target = tmp_path / "doclingsdg_rot_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
rotated_source = Path("EXAMPLE_DOCLING_SDG_TABLES/rotated_90_deg")
|
||||
sample_json = sorted(rotated_source.glob("*.json"))[0]
|
||||
_copy_json_png_pair(sample_json, dataset_source)
|
||||
rotated_jsons = sorted(rotated_source.glob("*.json"))
|
||||
assert rotated_jsons
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
for idx, sample_json in enumerate(rotated_jsons):
|
||||
dataset_source = tmp_path / f"doclingsdg_rot_source_{idx}"
|
||||
target = tmp_path / f"doclingsdg_rot_target_{idx}"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files={"test": str(target / "test" / "*.parquet")},
|
||||
)
|
||||
assert len(ds["test"]) == 1
|
||||
assert "90_degree" in ds["test"][0]["tags"]
|
||||
try:
|
||||
_copy_json_png_pair(sample_json, dataset_source)
|
||||
except AssertionError:
|
||||
continue
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
|
||||
parquet_files = list((target / "test").glob("*.parquet"))
|
||||
if not parquet_files:
|
||||
continue
|
||||
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files={"test": str(target / "test" / "*.parquet")},
|
||||
)
|
||||
if len(ds["test"]) == 1 and "90_degree" in ds["test"][0]["tags"]:
|
||||
return
|
||||
|
||||
pytest.skip("No rotated_90_deg fixture passed current table-region export filters.")
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_skips_malformed_sample(tmp_path: Path):
|
||||
@@ -276,6 +292,110 @@ def test_doclingsdg_builder_table_regions_skips_out_of_bounds_table_bbox(
|
||||
assert not list((target / "test").glob("*.parquet"))
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_skips_row_region_high_iou_overlap(
|
||||
tmp_path: Path,
|
||||
):
|
||||
dataset_source = tmp_path / "doclingsdg_row_overlap_source"
|
||||
target = tmp_path / "doclingsdg_row_overlap_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
sample_json = next(Path("tests/data/test_doclingsdg_docs").glob("*.json"))
|
||||
sample_png = sample_json.with_suffix(".png")
|
||||
assert sample_png.exists()
|
||||
shutil.copy2(sample_png, dataset_source / sample_png.name)
|
||||
|
||||
with sample_json.open("r", encoding="utf-8") as file_handle:
|
||||
payload = json.load(file_handle)
|
||||
|
||||
cells = payload["tables"][0]["data"]["table_cells"]
|
||||
row0_cell = next(
|
||||
(
|
||||
c
|
||||
for c in cells
|
||||
if int(c.get("start_row_offset_idx", -1)) == 0 and c.get("bbox") is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert row0_cell is not None
|
||||
ref_t = float(row0_cell["bbox"]["t"])
|
||||
ref_b = float(row0_cell["bbox"]["b"])
|
||||
|
||||
row1_cells = [
|
||||
c
|
||||
for c in cells
|
||||
if int(c.get("start_row_offset_idx", -1)) == 1 and c.get("bbox")
|
||||
]
|
||||
assert row1_cells
|
||||
for cell in row1_cells:
|
||||
cell["bbox"]["t"] = ref_t
|
||||
cell["bbox"]["b"] = ref_b
|
||||
|
||||
broken_json = dataset_source / sample_json.name
|
||||
with broken_json.open("w", encoding="utf-8") as file_handle:
|
||||
json.dump(payload, file_handle)
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
|
||||
assert not list((target / "test").glob("*.parquet"))
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_skips_column_region_high_iou_overlap(
|
||||
tmp_path: Path,
|
||||
):
|
||||
dataset_source = tmp_path / "doclingsdg_col_overlap_source"
|
||||
target = tmp_path / "doclingsdg_col_overlap_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
sample_json = next(Path("tests/data/test_doclingsdg_docs").glob("*.json"))
|
||||
sample_png = sample_json.with_suffix(".png")
|
||||
assert sample_png.exists()
|
||||
shutil.copy2(sample_png, dataset_source / sample_png.name)
|
||||
|
||||
with sample_json.open("r", encoding="utf-8") as file_handle:
|
||||
payload = json.load(file_handle)
|
||||
|
||||
cells = payload["tables"][0]["data"]["table_cells"]
|
||||
col0_cell = next(
|
||||
(
|
||||
c
|
||||
for c in cells
|
||||
if int(c.get("start_col_offset_idx", -1)) == 0 and c.get("bbox") is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert col0_cell is not None
|
||||
ref_l = float(col0_cell["bbox"]["l"])
|
||||
ref_r = float(col0_cell["bbox"]["r"])
|
||||
|
||||
col1_cells = [
|
||||
c
|
||||
for c in cells
|
||||
if int(c.get("start_col_offset_idx", -1)) == 1 and c.get("bbox")
|
||||
]
|
||||
assert col1_cells
|
||||
for cell in col1_cells:
|
||||
cell["bbox"]["l"] = ref_l
|
||||
cell["bbox"]["r"] = ref_r
|
||||
|
||||
broken_json = dataset_source / sample_json.name
|
||||
with broken_json.open("w", encoding="utf-8") as file_handle:
|
||||
json.dump(payload, file_handle)
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
|
||||
assert not list((target / "test").glob("*.parquet"))
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_skips_duplicate_left_on_same_row_band(
|
||||
tmp_path: Path,
|
||||
):
|
||||
@@ -332,3 +452,93 @@ def test_doclingsdg_builder_table_regions_uses_table_visualizer_for_viz(
|
||||
assert viz_file.exists()
|
||||
content = viz_file.read_text(encoding="utf-8")
|
||||
assert "Table Regions Visualization" in content
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_uses_regions_for_viz(tmp_path: Path):
|
||||
dataset_source = tmp_path / "doclingsdg_viz_regions_source"
|
||||
target = tmp_path / "doclingsdg_viz_regions_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
sample_json = Path(
|
||||
"tests/data/test_doclingsdg_docs/"
|
||||
"data_none__seed_teds_0.940_table_table_dataset_20260310_"
|
||||
"tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
|
||||
)
|
||||
_copy_json_png_pair(sample_json, dataset_source)
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
table_regions_visualization_source="regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4, do_visualization=True)
|
||||
|
||||
viz_file = target / "visualizations" / f"{sample_json.stem}_layout.html"
|
||||
assert viz_file.exists()
|
||||
content = viz_file.read_text(encoding="utf-8")
|
||||
assert "Table Regions Visualization" in content
|
||||
|
||||
|
||||
def test_doclingsdg_builder_table_regions_filters_col_count_out_of_range(
|
||||
tmp_path: Path,
|
||||
):
|
||||
dataset_source = tmp_path / "doclingsdg_col_filter_source"
|
||||
target = tmp_path / "doclingsdg_col_filter_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
valid_sample = Path(
|
||||
"tests/data/test_doclingsdg_docs/"
|
||||
"data_none__seed_teds_0.940_table_table_dataset_20260310_"
|
||||
"tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
|
||||
)
|
||||
out_of_range_col_sample = Path(
|
||||
"tests/data/test_doclingsdg_docs/"
|
||||
"data_none__seed_teds_0.944_table_table_dataset_20260310_"
|
||||
"tight_margin_wiki_tables_otsl_en_margin_doc53881_t000_row4_col34__0040.json"
|
||||
)
|
||||
|
||||
_copy_json_png_pair(valid_sample, dataset_source)
|
||||
_copy_json_png_pair(out_of_range_col_sample, dataset_source)
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files={"test": str(target / "test" / "*.parquet")},
|
||||
)
|
||||
assert len(ds["test"]) == 1
|
||||
assert ds["test"][0]["document_id"] == valid_sample.stem
|
||||
|
||||
|
||||
def test_doclingsdg_builder_filters_image_dimensions_out_of_range(tmp_path: Path):
|
||||
dataset_source = tmp_path / "doclingsdg_img_filter_source"
|
||||
target = tmp_path / "doclingsdg_img_filter_target"
|
||||
dataset_source.mkdir(parents=True)
|
||||
|
||||
sample_json = Path(
|
||||
"tests/data/test_doclingsdg_docs/"
|
||||
"data_none__seed_teds_0.940_table_table_dataset_20260310_"
|
||||
"tight_margin_ftn_margin_doc20169_t000_row35_col8__0736.json"
|
||||
)
|
||||
json_dst = dataset_source / sample_json.name
|
||||
shutil.copy2(sample_json, json_dst)
|
||||
|
||||
# Keep geometry in bounds while forcing the max-size filter to trigger.
|
||||
Image.new("RGB", (4097, 4097), "white").save(
|
||||
dataset_source / f"{sample_json.stem}.png"
|
||||
)
|
||||
|
||||
builder = DoclingSDGDatasetBuilder(
|
||||
dataset_source=dataset_source,
|
||||
target=target,
|
||||
modality="table_regions",
|
||||
)
|
||||
builder.save_to_disk(chunk_size=4)
|
||||
|
||||
assert not list((target / "test").glob("*.parquet"))
|
||||
|
||||
Reference in New Issue
Block a user