feat(vlm): upgrade Granite Vision model to 4.1 for table + chart extraction (#3382)

* feat(table-structure): swap VLM model to granite-vision-4.1-4b

Updates GraniteVisionTableStructureModel to use the 4.1 model. The 4.1
weights are pre-merged, so merge_lora_adapters() is now hasattr-guarded.
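
A minimal sketch of the guard pattern described above. The two classes are
hypothetical stand-ins, not the real model objects, and merge_if_supported
is an illustrative helper name that does not exist in the codebase.

```python
class AdapterModel:
    """Hypothetical stand-in for a checkpoint that ships separate LoRA adapters."""

    def __init__(self) -> None:
        self.merged = False

    def merge_lora_adapters(self) -> None:
        self.merged = True


class PreMergedModel:
    """Hypothetical stand-in for a checkpoint whose weights are already merged."""


def merge_if_supported(model: object) -> bool:
    """Call merge_lora_adapters() only when the model exposes it.

    Returns True if a merge was performed and False if the step was skipped.
    """
    if hasattr(model, "merge_lora_adapters"):
        model.merge_lora_adapters()
        return True
    return False
```

With this shape the same loader works for both kinds of checkpoint: the
merge runs when the method exists and is silently skipped otherwise.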

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* feat(chart-extraction): swap V4 VLM model to granite-vision-4.1-4b

Updates ChartExtractionModelGraniteVisionV4 to use the 4.1 model.
hasattr-guards the merge_lora_adapters() call since 4.1 weights are
pre-merged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* docs(example): mention granite-vision-4.1-4b in table-structure example

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* docs(catalog): update Granite Vision entry to 4.1-4b

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* feat(chart-extraction): honor cuda_use_flash_attention2 in V4 loader

Mirrors the table-structure loader so ChartExtractionModelGraniteVisionV4
also passes _attn_implementation based on AcceleratorOptions. Without this
the chart model falls back to the transformers SDPA default, which can
hit cuDNN backend failures on some torch/cuDNN stacks while the table
model (which already passed the flag) runs cleanly.

Stores accelerator_options on the base class so subclasses can read it.
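
The selection logic can be sketched in isolation as below. This is an
illustration only: FakeAcceleratorOptions is a hypothetical stand-in that
models a single field of AcceleratorOptions, and pick_attn_implementation
is an illustrative helper name, not a function in the codebase.

```python
from dataclasses import dataclass


@dataclass
class FakeAcceleratorOptions:
    """Hypothetical stand-in; only the flag used by the loader is modeled."""

    cuda_use_flash_attention2: bool = False


def pick_attn_implementation(device: str, options: FakeAcceleratorOptions) -> str:
    """Use FlashAttention 2 only on CUDA devices with the flag enabled."""
    if device.startswith("cuda") and options.cuda_use_flash_attention2:
        return "flash_attention_2"
    # Every other combination passes "sdpa" explicitly.
    return "sdpa"
```

For example, pick_attn_implementation("cuda:0", FakeAcceleratorOptions(True))
selects "flash_attention_2", while a CPU device gets "sdpa" regardless of
the flag.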

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* fix(model-downloader): update Granite Vision log message to 4.1

The log message in download_models still mentioned "Granite Vision 4.0"
after the model swap. Correct it to match the current model version.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

* fix(chart-extraction): fall back to bare CSV when V4 model omits ```csv``` fence

granite-vision-4.1-4b sometimes emits raw CSV without a ```csv``` code fence
for the <chart2csv> prompt, which caused _extract_csv_to_dataframe to raise
ValueError and drop the chart's tabular_chart metadata. Mirror the tolerant
parsing already used by the v3 class: prefer a fenced block, otherwise strip
any stray backtick prefix/suffix and parse the text as-is.
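
A standalone sketch of the tolerant parsing, assuming the same two-step
logic as the change. extract_csv_text is an illustrative name, and the
sketch stops at the cleaned CSV text instead of calling pandas.read_csv as
the real method does. The patterns use `{3} and `{3,} quantifiers, which
match the same input as the literal backtick runs in the actual regexes.

```python
import re


def extract_csv_text(decoded_text: str) -> str:
    """Return the CSV payload from model output, fenced or bare."""
    # Preferred path: a complete fenced block tagged as csv.
    match = re.search(r"`{3}csv\s*\n(.*?)\n`{3}", decoded_text, re.DOTALL)
    if match:
        return match.group(1).strip()
    # Fallback: strip a stray opening fence (with optional csv tag) and a
    # stray closing fence, then treat whatever remains as CSV.
    text = re.sub(r"^`{3,}(?:csv)?\s*", "", decoded_text.strip())
    return re.sub(r"`{3,}\s*$", "", text).strip()
```

Fenced output, bare output, and output with an unclosed opening fence all
reduce to the same CSV text, so the chart metadata is no longer dropped
when the fence is missing.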

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>

---------

Signed-off-by: Eli Schwartz <eliyahu.schwartz@ibm.com>
Co-authored-by: Eli Schwartz <eliyahu.schwartz@ibm.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
EliSchwartz
2026-05-04 09:36:08 +03:00
committed by GitHub
parent eb4724ee4c
commit 24f2d148d9
5 changed files with 29 additions and 16 deletions
@@ -54,6 +54,7 @@ class _BaseChartExtractionModelGraniteVision(BaseItemAndImageEnrichmentModel):
     ):
         self.enabled = enabled
         self.options = options
+        self.accelerator_options = accelerator_options
         if self.enabled:
             self.device = decide_device(
@@ -338,9 +339,9 @@ class ChartExtractionModelGraniteVision(_BaseChartExtractionModelGraniteVision):
 class ChartExtractionModelGraniteVisionV4(_BaseChartExtractionModelGraniteVision):
-    _model_repo_folder = "ibm-granite--granite-4.0-3b-vision"
-    _model_repo_id = "ibm-granite/granite-4.0-3b-vision"
-    _model_repo_revision = "f0d034897bae1cd438c961c8c170a3a3089ebf01"
+    _model_repo_folder = "ibm-granite--granite-vision-4.1-4b"
+    _model_repo_id = "ibm-granite/granite-vision-4.1-4b"
+    _model_repo_revision = "dd48e97503de471803850df70843cf9eb5da8712"
     def _load_model(self, artifacts_path: Path) -> None:
         with warnings.catch_warnings():
@@ -363,9 +364,16 @@ class ChartExtractionModelGraniteVisionV4(_BaseChartExtractionModelGraniteVision
                 artifacts_path,
                 device_map=self.device,
                 dtype=torch.bfloat16,
+                _attn_implementation=(
+                    "flash_attention_2"
+                    if self.device.startswith("cuda")
+                    and self.accelerator_options.cuda_use_flash_attention2
+                    else "sdpa"
+                ),
                 trust_remote_code=True,
             )
-            cast(Any, self._model).merge_lora_adapters()
+            if hasattr(self._model, "merge_lora_adapters"):
+                cast(Any, self._model).merge_lora_adapters()
             self._model.eval()
     def __call__(
@@ -475,11 +483,15 @@ class ChartExtractionModelGraniteVisionV4(_BaseChartExtractionModelGraniteVision
             yield item
     def _extract_csv_to_dataframe(self, decoded_text: str) -> pd.DataFrame:
-        # decoded_text is already the raw generated output (no conversation wrapper)
+        # decoded_text is already the raw generated output (no conversation wrapper).
+        # Prefer a fenced ```csv ... ``` block, but fall back to bare CSV if the
+        # model omits the fence (observed with granite-vision-4.1-4b).
         csv_match = re.search(r"```csv\s*\n(.*?)\n```", decoded_text, re.DOTALL)
-        if not csv_match:
-            raise ValueError("No ```csv``` block found in model output")
-        csv_content = csv_match.group(1).strip()
+        if csv_match:
+            csv_content = csv_match.group(1).strip()
+        else:
+            csv_content = re.sub(r"^```+(?:csv)?\s*", "", decoded_text.strip())
+            csv_content = re.sub(r"```+\s*$", "", csv_content).strip()
         try:
             return pd.read_csv(StringIO(csv_content), header=None)
         except Exception as e:
@@ -140,11 +140,11 @@ def _parse_otsl_output(
 class GraniteVisionTableStructureModel(BaseTableStructureModel):
-    """Table structure model using ibm-granite/granite-4.0-3b-vision with <tables_otsl>."""
+    """Table structure model using ibm-granite/granite-vision-4.1-4b with <tables_otsl>."""
-    _model_repo_id: ClassVar[str] = "ibm-granite/granite-4.0-3b-vision"
-    _model_repo_folder: ClassVar[str] = "ibm-granite--granite-4.0-3b-vision"
-    _model_repo_revision: ClassVar[str] = "f0d034897bae1cd438c961c8c170a3a3089ebf01"
+    _model_repo_id: ClassVar[str] = "ibm-granite/granite-vision-4.1-4b"
+    _model_repo_folder: ClassVar[str] = "ibm-granite--granite-vision-4.1-4b"
+    _model_repo_revision: ClassVar[str] = "dd48e97503de471803850df70843cf9eb5da8712"
     def __init__(
         self,
@@ -224,7 +224,8 @@ class GraniteVisionTableStructureModel(BaseTableStructureModel):
                 ),
                 trust_remote_code=True,
             )
-            cast(Any, self._model).merge_lora_adapters()
+            if hasattr(self._model, "merge_lora_adapters"):
+                cast(Any, self._model).merge_lora_adapters()
             self._model.eval()
     def predict_tables(
+1 -1
@@ -174,7 +174,7 @@ def download_models(
         )
     if with_granite_chart_extraction_v4:
-        _log.info("Downloading Granite Vision 4.0 Charts Extraction model...")
+        _log.info("Downloading Granite Vision 4.1 Charts Extraction model...")
         ChartExtractionModelGraniteVisionV4.download_models(
             local_dir=output_dir
             / ChartExtractionModelGraniteVisionV4._model_repo_folder,
+1 -1
@@ -17,7 +17,7 @@
 # - Defaults to `tests/data/pdf/2206.01062.pdf`. Change `input_doc_path` as needed.
 #
 # Notes
-# - The Granite Vision model (`ibm-granite/granite-4.0-3b-vision`) is downloaded
+# - The Granite Vision model (`ibm-granite/granite-vision-4.1-4b`) is downloaded
 # automatically from HuggingFace on first run.
 # - The model outputs table structure in OTSL (Open Table Structure Language) format,
 # which Docling parses into structured table cells.
+1 -1
@@ -88,7 +88,7 @@ The following table shows all processing stages in Docling, their model families
 <td rowspan="3">Vision-Language Model<br/>(Granite Vision)</td>
 <td>
 <ul>
-<li><code>granite-4.0-3b-vision</code></li>
+<li><code>granite-vision-4.1-4b</code></li>
 </ul>
 </td>
 </tr>