docling/tests/test_asr_mlx_whisper.py
geoHeil eb4724ee4c ci: prototype tach-based modular skipping (#3333)
* ci: prototype tach-based modular skipping

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: modularize ubuntu setup and refine gating

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: adopt metaxy-inspired governance helpers

- replace custom aggregate check with re-actors/alls-green

- set FORCE_JAVASCRIPT_ACTIONS_TO_NODE24 on every workflow

- keep PR concurrency alive when the graphite:merge label is present

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: tune checks and pin action versions

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: split CI suites and heavy examples

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: ecaa4777886157d5c2a7b3893c3a820983089dbf
I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: d15416f3ca94ac97af2a8317cd6404208db9d896

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: sharpen tach graph and per-suite path filters

- Split docling.pipeline into per-pipeline tach modules
  (asr, vlm, standard_pdf, threaded_standard_pdf, legacy_standard_pdf,
  extraction_vlm, base, base_extraction, simple) so pytest --tach-base
  impact analysis can attribute changes to a specific pipeline rather
  than the whole package.
- Split the asr- and vlm-specific docling.datamodel option files
  (asr_model_specs, pipeline_options_asr_model, vlm_engine_options,
  vlm_model_specs, pipeline_options_vlm_model, layout_model_specs,
  stage_model_specs, backend_options) into their own tach modules so
  a narrow spec/options change no longer marks the full datamodel as
  impacted.
- Narrow the per-suite pipeline path filters in checks.yml to the
  concrete pipeline files relevant to each suite, so editing
  vlm_pipeline.py only triggers the vlm matrix cell and editing
  asr_pipeline.py only the asr one.
- Rekey the model cache in setup-ubuntu-ci to include runner.os and
  hashFiles(uv.lock, pyproject.toml), with ordered restore-keys
  fallbacks so a lockfile bump no longer silently stales the cache.
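The cache rekeying in the last bullet can be modelled in a few lines. This is a simplified Python sketch, not the actions/cache implementation. The `models` prefix and all helper names are hypothetical. The digest only approximates `hashFiles`, which folds per-file SHA-256 hashes into one. `existing` stands in for the repository's saved cache keys, ordered newest first.

```python
from __future__ import annotations

import hashlib


def cache_key(prefix: str, runner_os: str, lock_files: list[bytes]) -> str:
    """Exact key: prefix, runner OS, and a digest of the lock files."""
    combined = hashlib.sha256()
    for content in lock_files:
        # Hash each file, then fold the per-file hashes into one digest.
        combined.update(hashlib.sha256(content).digest())
    return f"{prefix}-{runner_os}-{combined.hexdigest()}"


def restore_keys(prefix: str, runner_os: str) -> list[str]:
    """Ordered prefix fallbacks: same OS first, then any OS."""
    return [f"{prefix}-{runner_os}-", f"{prefix}-"]


def resolve(key: str, fallbacks: list[str], existing: list[str]) -> str | None:
    """Return the cache to restore; `existing` is ordered newest first."""
    if key in existing:
        return key
    for prefix in fallbacks:
        for candidate in existing:
            if candidate.startswith(prefix):
                return candidate
    return None
```

After a `uv.lock` bump the exact key misses, so the newest same-OS cache is restored as a warm start instead of a stale exact hit. Because the primary key did not hit, actions/cache then saves a fresh entry under the new key.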

Metaxy parity note: layered tach enforcement (layer = "...") is
blocked by existing backend<->datamodel and utils<->stages cycles;
depot runners, nox dynamic matrices, devenv/nix, dprint and ty are
not applicable to docling's stack. All pinned action SHAs are on
their latest release as of this commit.

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: introduce pipeline and orchestration tach layers

Earlier notes claimed layers were blocked. That was only true for the
cyclic core (backend<->datamodel, utils<->stages). The boundary
*above* core is clean:

- No module under docling/backend, docling/datamodel, docling/models,
  docling/utils, docling/exceptions, or docling/chunking imports
  anything from docling.pipeline (verified by grep).
- No module anywhere in docling/ imports from docling.cli,
  docling.document_converter, docling.document_extractor, or
  docling.service_client (also verified).

So we can introduce two real layers on top of the cyclic core:

- "pipeline"      — docling.pipeline and all nine concrete pipelines
                     (base, simple, base_extraction, asr, vlm,
                     extraction_vlm, standard_pdf,
                     threaded_standard_pdf, legacy_standard_pdf).
- "orchestration" — docling.cli, docling.document_converter,
                     docling.document_extractor, and
                     docling.experimental.pipeline.

Unlayered modules stay "below" both layers (tach allows them to be
depended on freely) and continue to carry the declared-but-cyclic
backend<->datamodel and utils<->stages edges.

A VLM-only layer was explored but rejected: only
docling.pipeline.vlm_pipeline and docling.pipeline.extraction_vlm_pipeline
could be cleanly layered as "vlm", because the matching datamodel
options (pipeline_options_vlm_model, vlm_engine_options,
vlm_model_specs) and model stages (vlm_convert, vlm_pipeline_models)
sit inside the datamodel/models cycle and cannot be promoted to a
higher layer without first breaking that cycle. Layering only the
two pipeline files is not worth the extra config.

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: expand tach layers to entrypoints/pipeline/models/core

Follow-up to the two-layer attempt. After verifying via grep that
nothing in datamodel/utils/backend imports from
docling.models.{extraction,factories,plugins,vlm_pipeline_models}
or from the "upper" stages (page_assemble, page_preprocessing,
reading_order, picture_description, vlm_convert), those nine
modules can be promoted out of the cyclic core into a dedicated
"models" layer.

The resulting order (highest first):

- entrypoints — cli, document_converter, document_extractor,
                experimental.pipeline
- pipeline    — docling.pipeline + the nine concrete pipelines
- models      — model factories, extraction, plugins,
                vlm_pipeline_models, and the five "upper" stages
- core        — datamodel*, backend*, utils, exceptions, chunking,
                models (base), models.utils, inference_engines.*,
                the six "core stages" that utils cycles with
                (chart_extraction, code_formula, layout, ocr,
                picture_classifier, table_structure), and the
                experimental.* and service_client modules

Rename the previous "orchestration" layer to "entrypoints" to
match the common docling vocabulary. Every module now carries an
explicit layer tag instead of relying on implicit unlayered
behaviour, so future additions must pick a layer deliberately.

A VLM layer, a stand-alone inference-engines layer, and separating
datamodel from backend all remain blocked by the bidirectional
backend<->datamodel and utils<->core-stages edges; those need a
code-level refactor first.
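The layer rule above can be illustrated with a toy checker. This is a simplified model that assumes the four-layer order listed in this commit and a hypothetical, hand-picked sample of module assignments. It is not tach's implementation, and it ignores tach's per-module dependency declarations, which still govern edges inside a single layer.

```python
# Layer order from this commit, highest first.
LAYERS = ["entrypoints", "pipeline", "models", "core"]

# Hypothetical sample of the module -> layer assignment.
MODULE_LAYER = {
    "docling.cli": "entrypoints",
    "docling.document_converter": "entrypoints",
    "docling.pipeline.asr_pipeline": "pipeline",
    "docling.pipeline.vlm_pipeline": "pipeline",
    "docling.models.factories": "models",
    "docling.datamodel.pipeline_options": "core",
    "docling.backend": "core",
}


def rank(module: str) -> int:
    """0 is the highest layer; larger numbers sit lower in the stack."""
    return LAYERS.index(MODULE_LAYER[module])


def import_allowed(importer: str, imported: str) -> bool:
    """A module may import from its own layer or any layer below it."""
    return rank(imported) >= rank(importer)


def violations(edges: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Return the (importer, imported) edges that point upward."""
    return [(a, b) for a, b in edges if not import_allowed(a, b)]
```

An upward edge such as `docling.backend` importing a pipeline module is exactly what the grep checks in this commit ruled out before the layers were introduced.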

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: refine tach client and foundation layers

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: add optional windows and macos smoke lanes

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: normalize reusable workflow boolean inputs

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: replace external all-green action

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: use org-allowed setup-uv action

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: install compiler toolchain for ML tests

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: bb714afb42cd1b29ab073a7f59cc72874ff2fdcd

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: a1f2761da8f72bfed636bd571ebf77b42c8771b6

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: cc6551b54c5bf4815ae9cd57cf43a98928a74be0

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: b21b0e7ca12b552dbdd54fac1bda113719c286f1

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: simplify ML pytest suite patterns

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: gate heavy examples on label, add job timeouts

- ci-heavy-examples: run only on main push, schedule, workflow_dispatch,
  or when a PR is labeled tests:full / tests:heavy-examples. Drops the
  path-based auto-trigger so that common edits to pyproject.toml,
  uv.lock, or .github/actions do not kick off the 45-60 min matrix on
  every PR push. Collapses the changes job into a job-level if gate and
  adds timeout-minutes: 90.
- checks.yml: add timeout-minutes to every job so stuck runners cannot
  burn the full 6h default.
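The heavy-examples trigger rules in the first bullet reduce to a small predicate. The workflow itself expresses this as a job-level `if:` gate. This Python version is only an illustration: the event names come from GitHub Actions (`push`, `schedule`, `workflow_dispatch`, `pull_request`), and the function name is hypothetical.

```python
HEAVY_LABELS = {"tests:full", "tests:heavy-examples"}


def should_run_heavy_examples(event_name: str, ref: str, labels: set[str]) -> bool:
    """Mirror the trigger rules described for ci-heavy-examples."""
    if event_name == "push":
        # Only pushes to main; feature-branch pushes are skipped.
        return ref == "refs/heads/main"
    if event_name in ("schedule", "workflow_dispatch"):
        return True
    if event_name == "pull_request":
        # Opt-in via label; path changes alone no longer trigger the matrix.
        return bool(labels & HEAVY_LABELS)
    return False
```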

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: tolerate cancelled allowed-skip jobs in check aggregator

Intentional cancellations (manual cancel, concurrency replacement) on
jobs that are already in ALLOWED_SKIPS should not mark the overall
workflow red. Treat `cancelled` the same as `skipped` when the job is
listed as an allowed skip; any unexpected cancellation of a required
job still fails.
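The aggregator rule reduces to a per-job check over each job's result. This is a minimal Python sketch. It assumes the results are the strings GitHub Actions reports for `needs.<job>.result` (`success`, `failure`, `cancelled`, `skipped`). The job names and the function name are hypothetical.

```python
ALLOWED_SKIPS = {"heavy-examples", "macos-smoke"}  # hypothetical job names


def workflow_green(results: dict[str, str], allowed_skips: set[str]) -> bool:
    """Return True only if every job succeeded or was an allowed skip."""
    for job, result in results.items():
        if result == "success":
            continue
        # Intentional cancellations count as skips, but only for allowed jobs.
        if result in ("skipped", "cancelled") and job in allowed_skips:
            continue
        return False
    return True
```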

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* docs: make minimal vlm example portable

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: 2135051da3ed73d4b8a9130f584f40b56155af1a

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: 4f6d1d7960f7418d0cde6425ae61538da84fda40

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: install workspace packages in CI syncs

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: 492fa9883d4de6d98ebcb40fa863eafe2facff3c

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: 3eefae71643f9ca3df0264690c0c6eb1f67f06f1

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* DCO Remediation Commit for Georg Heiler <georg.kf.heiler@gmail.com>

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: fe8c9689a0ee94f36eb826da8e2177ef87404f5e

I, Georg Heiler <georg.kf.heiler@gmail.com>, hereby add my Signed-off-by to this commit: eabdd24a6734ec873cdaac857718aef2473677e7

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: remove unused graphite concurrency exception

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: document test labels and gate cross-platform lanes

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: select ml tests with pytest markers

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: fix marker selector typing

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: simplify ml suite scheduling

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: mark cross-platform smoke tests

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: reuse test trigger for ml matrix

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: tighten full ci aggregation

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

* ci: share required job result check

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>

---------

Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 14:15:35 +02:00

343 lines
12 KiB
Python

"""
Test MLX Whisper integration for Apple Silicon ASR pipeline.
"""

import sys
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import (
    WHISPER_BASE,
    WHISPER_BASE_MLX,
    WHISPER_LARGE,
    WHISPER_LARGE_MLX,
    WHISPER_MEDIUM,
    WHISPER_SMALL,
    WHISPER_TINY,
    WHISPER_TURBO,
)
from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import (
    InferenceAsrFramework,
    InlineAsrMlxWhisperOptions,
)
from docling.pipeline.asr_pipeline import AsrPipeline, _MlxWhisperModel

pytestmark = pytest.mark.ml_asr


class TestMlxWhisperIntegration:
    """Test MLX Whisper model integration."""

    def test_mlx_whisper_options_creation(self):
        """Test that MLX Whisper options are created correctly."""
        options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            language="en",
            task="transcribe",
        )

        assert options.inference_framework == InferenceAsrFramework.MLX
        assert options.repo_id == "mlx-community/whisper-tiny-mlx"
        assert options.language == "en"
        assert options.task == "transcribe"
        assert options.word_timestamps is True
        assert AcceleratorDevice.MPS in options.supported_devices

    def test_whisper_models_auto_select_mlx(self):
        """Test that Whisper models automatically select MLX when MPS and mlx-whisper are available."""
        # This test verifies that the models are correctly configured
        # In a real Apple Silicon environment with mlx-whisper installed,
        # these models would automatically use MLX

        # Check that the models exist and have the correct structure
        assert hasattr(WHISPER_TURBO, "inference_framework")
        assert hasattr(WHISPER_TURBO, "repo_id")

        assert hasattr(WHISPER_BASE, "inference_framework")
        assert hasattr(WHISPER_BASE, "repo_id")

        assert hasattr(WHISPER_SMALL, "inference_framework")
        assert hasattr(WHISPER_SMALL, "repo_id")

    def test_explicit_mlx_models_shape(self):
        """Explicit MLX options should have MLX framework and valid repos."""
        assert WHISPER_BASE_MLX.inference_framework.name == "MLX"
        assert WHISPER_LARGE_MLX.inference_framework.name == "MLX"
        assert WHISPER_BASE_MLX.repo_id.startswith("mlx-community/")

    def test_model_selectors_mlx_and_native_paths(self, monkeypatch):
        """Cover MLX/native selection branches in asr_model_specs getters."""
        from docling.datamodel import asr_model_specs as specs

        # Force MLX path
        class _Mps:
            def is_built(self):
                return True

            def is_available(self):
                return True

        class _Torch:
            class backends:
                mps = _Mps()

        monkeypatch.setitem(sys.modules, "torch", _Torch())
        monkeypatch.setitem(sys.modules, "mlx_whisper", object())

        m_tiny = specs._get_whisper_tiny_model()
        m_small = specs._get_whisper_small_model()
        m_base = specs._get_whisper_base_model()
        m_medium = specs._get_whisper_medium_model()
        m_large = specs._get_whisper_large_model()
        m_turbo = specs._get_whisper_turbo_model()

        assert (
            m_tiny.inference_framework == InferenceAsrFramework.MLX
            and m_tiny.repo_id.startswith("mlx-community/whisper-tiny")
        )
        assert (
            m_small.inference_framework == InferenceAsrFramework.MLX
            and m_small.repo_id.startswith("mlx-community/whisper-small")
        )
        assert (
            m_base.inference_framework == InferenceAsrFramework.MLX
            and m_base.repo_id.startswith("mlx-community/whisper-base")
        )
        assert (
            m_medium.inference_framework == InferenceAsrFramework.MLX
            and "medium" in m_medium.repo_id
        )
        assert (
            m_large.inference_framework == InferenceAsrFramework.MLX
            and "large" in m_large.repo_id
        )
        assert (
            m_turbo.inference_framework == InferenceAsrFramework.MLX
            and m_turbo.repo_id.endswith("whisper-turbo")
        )

        # Force native path (no mlx or no mps). monkeypatch.delitem restores
        # the original sys.modules entry on teardown.
        monkeypatch.delitem(sys.modules, "mlx_whisper", raising=False)

        class _MpsOff:
            def is_built(self):
                return False

            def is_available(self):
                return False

        class _TorchOff:
            class backends:
                mps = _MpsOff()

        monkeypatch.setitem(sys.modules, "torch", _TorchOff())

        n_tiny = specs._get_whisper_tiny_model()
        n_small = specs._get_whisper_small_model()
        n_base = specs._get_whisper_base_model()
        n_medium = specs._get_whisper_medium_model()
        n_large = specs._get_whisper_large_model()
        n_turbo = specs._get_whisper_turbo_model()

        assert (
            n_tiny.inference_framework == InferenceAsrFramework.WHISPER
            and n_tiny.repo_id == "tiny"
        )
        assert (
            n_small.inference_framework == InferenceAsrFramework.WHISPER
            and n_small.repo_id == "small"
        )
        assert (
            n_base.inference_framework == InferenceAsrFramework.WHISPER
            and n_base.repo_id == "base"
        )
        assert (
            n_medium.inference_framework == InferenceAsrFramework.WHISPER
            and n_medium.repo_id == "medium"
        )
        assert (
            n_large.inference_framework == InferenceAsrFramework.WHISPER
            and n_large.repo_id == "large"
        )
        assert (
            n_turbo.inference_framework == InferenceAsrFramework.WHISPER
            and n_turbo.repo_id == "turbo"
        )

    def test_selector_import_errors_force_native(self, monkeypatch):
        """Without MPS support, the selector must return the native Whisper model."""
        from docling.datamodel import asr_model_specs as specs

        # Simulate environment where MPS is unavailable and mlx_whisper missing
        class _MpsOff:
            def is_built(self):
                return False

            def is_available(self):
                return False

        class _TorchOff:
            class backends:
                mps = _MpsOff()

        monkeypatch.setitem(sys.modules, "torch", _TorchOff())
        # Unlike a bare `del`, monkeypatch restores any real module on teardown.
        monkeypatch.delitem(sys.modules, "mlx_whisper", raising=False)

        model = specs._get_whisper_base_model()
        assert model.inference_framework == InferenceAsrFramework.WHISPER

    @patch("builtins.__import__")
    def test_mlx_whisper_model_initialization(self, mock_import):
        """Test MLX Whisper model initialization."""
        # Mock the mlx_whisper import
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        model = _MlxWhisperModel(
            enabled=True,
            artifacts_path=None,
            accelerator_options=accelerator_options,
            asr_options=asr_options,
        )

        assert model.enabled is True
        assert model.model_path == "mlx-community/whisper-tiny-mlx"
        assert model.language == "en"
        assert model.task == "transcribe"
        assert model.word_timestamps is True

    def test_mlx_whisper_model_import_error(self):
        """Test that ImportError is raised when mlx-whisper is not available."""
        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        with patch(
            "builtins.__import__",
            side_effect=ImportError("No module named 'mlx_whisper'"),
        ):
            with pytest.raises(ImportError, match="mlx-whisper is not installed"):
                _MlxWhisperModel(
                    enabled=True,
                    artifacts_path=None,
                    accelerator_options=accelerator_options,
                    asr_options=asr_options,
                )

    @patch("builtins.__import__")
    def test_mlx_whisper_transcribe(self, mock_import):
        """Test MLX Whisper transcription method."""
        # Mock the mlx_whisper module and its transcribe function
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        # Mock the transcribe result
        mock_result = {
            "segments": [
                {
                    "start": 0.0,
                    "end": 2.5,
                    "text": "Hello world",
                    "words": [
                        {"start": 0.0, "end": 0.5, "word": "Hello"},
                        {"start": 0.5, "end": 1.0, "word": "world"},
                    ],
                }
            ]
        }
        mock_mlx_whisper.transcribe.return_value = mock_result

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        model = _MlxWhisperModel(
            enabled=True,
            artifacts_path=None,
            accelerator_options=accelerator_options,
            asr_options=asr_options,
        )

        # Test transcription
        audio_path = Path("test_audio.wav")
        result = model.transcribe(audio_path)

        # Verify the result
        assert len(result) == 1
        assert result[0].start_time == 0.0
        assert result[0].end_time == 2.5
        assert result[0].text == "Hello world"
        assert len(result[0].words) == 2
        assert result[0].words[0].text == "Hello"
        assert result[0].words[1].text == "world"

        # Verify mlx_whisper.transcribe was called with correct parameters
        mock_mlx_whisper.transcribe.assert_called_once_with(
            str(audio_path),
            path_or_hf_repo="mlx-community/whisper-tiny-mlx",
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

    @patch("builtins.__import__")
    def test_asr_pipeline_with_mlx_whisper(self, mock_import):
        """Test that AsrPipeline can be initialized with MLX Whisper options."""
        # Mock the mlx_whisper import
        mock_mlx_whisper = Mock()
        mock_import.return_value = mock_mlx_whisper

        accelerator_options = AcceleratorOptions(device=AcceleratorDevice.MPS)
        asr_options = InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-tiny-mlx",
            inference_framework=InferenceAsrFramework.MLX,
            language="en",
            task="transcribe",
            word_timestamps=True,
            no_speech_threshold=0.6,
            logprob_threshold=-1.0,
            compression_ratio_threshold=2.4,
        )

        pipeline_options = AsrPipelineOptions(
            asr_options=asr_options,
            accelerator_options=accelerator_options,
        )

        pipeline = AsrPipeline(pipeline_options)
        assert isinstance(pipeline._model, _MlxWhisperModel)
        assert pipeline._model.model_path == "mlx-community/whisper-tiny-mlx"
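The selector tests above expect the `_get_whisper_*_model` getters to pick MLX only when torch reports MPS as both built and available and `mlx_whisper` can be imported, and to fall back to the native Whisper models otherwise. Here is a minimal sketch of that decision, assuming this is the whole condition. `select_framework` and its helpers are hypothetical and are not docling's `asr_model_specs` code.

```python
def _mps_available() -> bool:
    """True if torch is importable and reports MPS as built and available."""
    try:
        import torch
    except ImportError:
        return False
    mps = torch.backends.mps
    return bool(mps.is_built() and mps.is_available())


def _mlx_whisper_importable() -> bool:
    """True if `import mlx_whisper` succeeds in this environment."""
    try:
        import mlx_whisper  # noqa: F401
    except ImportError:
        return False
    return True


def select_framework() -> str:
    """Return "MLX" on Apple Silicon with mlx-whisper, else "WHISPER"."""
    if _mps_available() and _mlx_whisper_importable():
        return "MLX"
    return "WHISPER"
```

Both imports resolve through `sys.modules` first, so tests can steer each branch by placing stand-in objects there, as `test_model_selectors_mlx_and_native_paths` does with `monkeypatch.setitem`.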