ASR tech debt cleanup: remove dead code, fix bugs, add benchmark script 28/03/2026 (#460)

## Summary

Systematic cleanup of the ASR module addressing tech debt items from
#457. Net reduction of ~430 lines while fixing real bugs and improving
maintainability.

### Bug fixes
- **`enableFP16` silently ignored** —
`optimizedConfiguration(enableFP16:)` delegated to a shared factory that
hardcoded `allowLowPrecisionAccumulationOnGPU = true`, ignoring the
caller's parameter
- **`MLArrayCache.returnArray` only reset float32 data** — cached arrays
of other types (float16, int32) retained stale data from previous use
- **CTC model auto-detection broken** —
`Repo.parakeetCtc110m.folderName` returned `"parakeet-ctc-110m"` instead
of `"parakeet-ctc-110m-coreml"` because the `folderName` switch fell
through to a `default` case that stripped the `-coreml` suffix. Same for
`parakeetCtc06b`.
- **Duplicate tokens at chunk merge boundary** — `mergeByMidpoint` used
`<=`/`>=` so tokens exactly at the cutoff appeared in both left and
right chunks

### Dead code removal
- Deleted `ANEOptimizer` indirection layer (166 lines) — was a
pass-through wrapping `MLModel` with no optimization
- Deleted `PerformanceMonitor` actor and `AggregatedMetrics` — never
instantiated, component times hardcoded to 0
- Deleted `getFloat16Array` from MLArrayCache — never called
- Deleted `sliceEncoderOutput` from AsrTranscription — never called (30
lines)
- Deleted `loadWithANEOptimization` from AsrModels — never called
- Removed unused `tokenTimings` parameter chain through
`processTranscriptionResult`
- Removed unused `import OSLog` / `import CoreML` across 5 files
- Removed `nonisolated(unsafe)` from SlidingWindowAsrManager (types
already Sendable)

### Duplication elimination
- Extracted `clearCachedCtcData()` helper (replaced 3× triple-nil
assignments)
- Extracted `decoderState(for:)` / `setDecoderState(_:for:)` (replaced
4× switch blocks)
- Extracted `frameAlignedAudio()` (replaced 2× duplicated
frame-alignment blocks)
- Added `ASRConstants.secondsPerEncoderFrame` (replaced 5× magic `0.08`)
- Replaced hardcoded `16_000` with `config.sampleRate` /
`ASRConstants.sampleRate`
- Extracted `MLModelConfigurationUtils.defaultConfiguration()` (replaced
5× copy-pasted config methods)
- Extracted `MLModelConfigurationUtils.defaultModelsDirectory()`
(replaced 3× copy-pasted directory methods)
- Consolidated duplicate `vocabularyFile` / `vocabularyFileArray`
constants

### File organization
- Moved `PerformanceMetrics.swift`, `ProgressEmitter.swift`,
`MLArrayCache.swift` from `ASR/Parakeet/` to `Shared/` (used by multiple
modules)
- Renamed `StreamingAudioSourceFactory` → `AudioSourceFactory`,
`StreamingAudioSampleSource` → `AudioSampleSource` (types used by both
ASR and Diarizer)
- Renamed files to match type names: `SortformerDiarizerPipeline.swift`
→ `SortformerDiarizer.swift`, `LSEENDDiarizerAPI.swift` →
`LSEENDDiarizer.swift`, `NemotronPipeline.swift` →
`NemotronStreamingAsrManager+Pipeline.swift`
- Replaced force unwraps in `RnntDecoder.swift` with `guard let` +
descriptive errors
- Removed stale TODO about decoder state in AsrManager

### Benchmark script
- Added `Scripts/run_parakeet_benchmarks.sh` — runs all 6 benchmarks
(v3, v2, TDT-CTC-110M, CTC earnings, EOU 320ms, Nemotron 1120ms) with
WER comparison against `benchmarks100.md` baselines and regression
detection
- Referenced from `Documentation/ASR/benchmarks100.md`

## Verified — no regressions

```
Model                       Baseline    Current      Delta
Parakeet TDT v3 (0.6B)          2.6%      2.64%     +0.04%
Parakeet TDT v2 (0.6B)          3.8%      3.79%     -0.01%
CTC-TDT 110M                    3.6%      3.56%     -0.04%
CTC Earnings                  16.54%     16.51%     -0.03%
EOU 320ms (120M)               7.11%      7.11%     +0.00%
Nemotron 1120ms (0.6B)         1.99%      1.99%     +0.00%
```

## Test plan
- [x] `swift build` passes
- [x] `swift test` passes (all existing tests, updated for removed dead
code)
- [x] All 6 ASR benchmarks match baselines (100 files each)
- [ ] `swift format lint` passes
This commit is contained in:
Alex
2026-03-28 23:44:10 -04:00
committed by GitHub
parent 7f1e006905
commit d9eef864d2
44 changed files with 1410 additions and 1212 deletions
+2
View File
@@ -102,6 +102,8 @@ Resources/
!Sources/FluidAudio/Resources/
!Sources/FluidAudio/Resources/**
scripts/
!Scripts/parakeet_subset_benchmark.sh
!Scripts/diarizer_subset_benchmark.sh
Documentation/parakeet-tdt/
docs/parakeet-tdt/
+12
View File
@@ -2,6 +2,18 @@
Benchmark comparison between `main` and PR #440 (`standardize-asr-directory-structure`) to verify the directory restructuring introduces no regressions.
## Reproduction
All batch TDT and CTC earnings benchmarks can be reproduced with [`Scripts/parakeet_subset_benchmark.sh`](../../Scripts/parakeet_subset_benchmark.sh):
```bash
# Download models and datasets (requires internet)
./Scripts/parakeet_subset_benchmark.sh --download
# Run all 4 benchmarks offline (100 files each, sleep-prevented)
./Scripts/parakeet_subset_benchmark.sh
```
## Environment
- **Hardware**: MacBook Air M2, 16 GB
@@ -0,0 +1,147 @@
# Diarization Benchmarks
Hardware: 2024 MacBook Pro, 48GB RAM, M4 Pro, macOS Tahoe 26.0
Dataset: AMI SDM (Single Distant Microphone), 4-meeting subset — one session per speaker group for diversity.
All results use collar=0.25s, ignoreOverlap=true.
## Summary
| System | Avg DER | Avg RTFx | Mode |
|---|---|---|---|
| LS-EEND (AMI) | 25.7% | 53.9x | Streaming |
| Offline VBx | 21.8% | 97.5x | Offline |
| Streaming 5s/0.8 | 29.9% | 96.2x | Streaming |
| Sortformer (high-lat) | 34.3% | 120.3x | Streaming |
## Offline VBx
Pyannote segmentation + WeSpeaker embeddings + PLDA scoring + VBx clustering.
Default configuration: step ratio 0.2, minSegmentDurationSeconds 1.0, clustering threshold 0.7.
```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli diarization-benchmark --mode offline \
--dataset ami-sdm --auto-download
```
```text
----------------------------------------------------------------------
Meeting DER % Miss % FA % SE % Speakers RTFx
----------------------------------------------------------------------
ES2004a 14.5 7.6 1.7 5.2 5/4 98.2
IS1009a 17.7 3.6 3.0 11.1 6/4 99.1
TS3003a 21.2 11.7 1.4 8.1 2/4 98.4
EN2002a 33.9 4.5 1.4 28.0 4/4 94.2
----------------------------------------------------------------------
AVERAGE 21.8 6.9 1.9 13.1 - 97.5
======================================================================
```
Full VoxConverse results (232 clips): 15.07% DER, 122x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.
## Streaming (5s chunks, 0.8 threshold)
Pyannote segmentation + WeSpeaker embeddings + online SpeakerManager clustering.
Best streaming configuration: 5s chunks, 0s overlap, 0.8 clustering threshold.
```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli diarization-benchmark --mode streaming \
--dataset ami-sdm --chunk-seconds 5.0 --overlap-seconds 0.0 \
--threshold 0.8 --auto-download
```
```text
----------------------------------------------------------------------
Meeting DER % Miss % FA % SE % Speakers RTFx
----------------------------------------------------------------------
ES2004a 17.0 9.0 1.3 6.7 7/4 99.2
IS1009a 18.1 4.7 2.7 10.8 4/4 101.0
TS3003a 21.0 12.7 1.4 6.8 2/4 104.3
EN2002a 63.4 9.2 1.1 53.0 7/4 80.1
----------------------------------------------------------------------
AVERAGE 29.9 8.9 1.6 19.3 - 96.2
======================================================================
```
Full 7-meeting results: 26.2% DER, 223x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.
EN2002a is a known difficult meeting for the streaming pipeline — aggressive speaker error (53%) due to over-fragmentation.
## Sortformer (NVIDIA High-Latency)
NVIDIA end-to-end Sortformer model, 30.4s chunk config.
Model: [FluidInference/diar-streaming-sortformer-coreml](https://huggingface.co/FluidInference/diar-streaming-sortformer-coreml)
```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli sortformer-benchmark \
--nvidia-high-latency --hf --auto-download
```
```text
----------------------------------------------------------------------
Meeting DER % Miss % FA % SE % Speakers RTFx
----------------------------------------------------------------------
IS1009a 26.5 15.9 1.4 9.3 4/4 122.9
ES2004a 33.4 24.5 0.1 8.8 4/4 117.9
EN2002a 35.7 20.0 0.4 15.2 4/4 121.5
TS3003a 41.8 36.8 0.7 4.3 4/4 119.0
----------------------------------------------------------------------
AVERAGE 34.3 24.3 0.7 9.4 - 120.3
======================================================================
```
Full 16-meeting results: 31.7% DER, 126.7x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.
## LS-EEND (AMI variant)
Linear Streaming End-to-End Neural Diarization from Westlake University.
Model: [GradientDescent2718/ls-eend-coreml](https://huggingface.co/GradientDescent2718/ls-eend-coreml)
```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli lseend-benchmark \
--variant ami --auto-download
```
```text
----------------------------------------------------------------------
Meeting DER % Miss % FA % SE % Speakers RTFx
----------------------------------------------------------------------
TS3003a 19.0 16.6 0.8 1.6 4/4 47.5
IS1009a 23.4 8.0 2.6 12.8 4/4 57.7
EN2002a 24.5 19.7 1.1 3.6 4/4 53.2
ES2004a 35.8 13.3 19.2 3.2 4/4 57.2
----------------------------------------------------------------------
AVERAGE 25.7 14.4 5.9 5.3 - 53.9
======================================================================
```
Full 16-meeting results: 20.7% DER, 74.5x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.
## Reproducing
Run all 4 systems on the default 4-meeting subset:
```bash
./Scripts/diarizer_subset_benchmark.sh
```
Run on all 16 AMI meetings:
```bash
./Scripts/diarizer_subset_benchmark.sh --all
```
Results are saved to `benchmark_results/` with timestamps. The script uses `caffeinate` to prevent sleep during long runs.
+478
View File
@@ -0,0 +1,478 @@
#!/bin/bash
# Run all diarizer model benchmarks on AMI SDM with sleep prevention.
#
# Benchmarks:
# 1. Offline (VBx) — OfflineDiarizerManager, step=0.2, min-seg=1.0
# 2. Streaming (5s) — DiarizerManager, 5s chunks, 0s overlap, threshold=0.8
# 3. Sortformer — SortformerDiarizer, NVIDIA high-latency config
# 4. LS-EEND — LSEENDDiarizer, AMI variant
#
# Usage:
# ./Scripts/diarizer_subset_benchmark.sh # quick run (4 meetings)
# ./Scripts/diarizer_subset_benchmark.sh --all # full run (all 16 meetings)
# ./Scripts/diarizer_subset_benchmark.sh --max-files 8 # custom subset
# ./Scripts/diarizer_subset_benchmark.sh --download # download missing assets, then exit
#
# The script verifies all models and dataset files exist locally before running.
# If anything is missing it will tell you exactly what and exit (unless --download).
# Uses caffeinate to prevent sleep so you can close the lid.
# Results are saved to benchmark_results/ with timestamps.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
RESULTS_DIR="$PROJECT_DIR/benchmark_results"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$RESULTS_DIR/diarizer_benchmark_${TIMESTAMP}.log"
MODELS_DIR="$HOME/Library/Application Support/FluidAudio/Models"
DATASETS_DIR="$HOME/FluidAudioDatasets"
AMI_SDM_DIR="$DATASETS_DIR/ami_official/sdm"
AMI_RTTM_DIR="$DATASETS_DIR/ami_official/rttm"
MAX_FILES=4 # default: quick 4-meeting subset
# AMI SDM has 16 meetings — this is the standard diarization test set.
# Ordered so the first N picks one from each speaker group for maximum diversity.
# Groups: EN2002 (4 speakers), ES2004 (4), IS1009 (4), TS3003 (4)
ALL_AMI_MEETINGS=(
EN2002a ES2004a IS1009a TS3003a
EN2002b ES2004b IS1009b TS3003b
EN2002c ES2004c IS1009c TS3003c
EN2002d ES2004d IS1009d TS3003d
)
# Parse --all / --max-files <N> from arguments
args=("$@")
for ((i=0; i<${#args[@]}; i++)); do
case "${args[$i]}" in
--all) MAX_FILES=${#ALL_AMI_MEETINGS[@]} ;;
--max-files) MAX_FILES="${args[$((i+1))]}" ; i=$((i+1)) ;;
esac
done
# Select the subset of meetings to run
AMI_MEETINGS=("${ALL_AMI_MEETINGS[@]:0:$MAX_FILES}")
mkdir -p "$RESULTS_DIR"
log() {
echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"
}
# ---------------------------------------------------------------------------
# Verify local assets
# ---------------------------------------------------------------------------
verify_assets() {
local missing=0
# --- AMI SDM audio files ---
local wav_count=0
for meeting in "${AMI_MEETINGS[@]}"; do
if [[ -f "$AMI_SDM_DIR/${meeting}.Mix-Headset.wav" ]]; then
wav_count=$((wav_count + 1))
else
log "MISSING AMI SDM: $AMI_SDM_DIR/${meeting}.Mix-Headset.wav"
missing=1
fi
done
if [[ "$wav_count" -eq 0 ]]; then
log "MISSING AMI SDM: no wav files found in $AMI_SDM_DIR"
missing=1
fi
# --- AMI RTTM annotations (downloaded automatically by --auto-download) ---
local rttm_count=0
for meeting in "${ALL_AMI_MEETINGS[@]}"; do
if [[ -f "$AMI_RTTM_DIR/${meeting}.rttm" ]]; then
rttm_count=$((rttm_count + 1))
fi
done
if [[ "$rttm_count" -eq 0 ]]; then
log "NOTE AMI RTTM annotations not found — will be auto-downloaded by CLI"
fi
# --- Offline diarizer models (pyannote segmentation + wespeaker embedding) ---
local diar_dir="$MODELS_DIR/speaker-diarization-coreml"
if [[ ! -d "$diar_dir" ]]; then
log "MISSING Diarizer models: $diar_dir"
missing=1
fi
# --- Sortformer models (folder may or may not have -coreml suffix) ---
if [[ ! -d "$MODELS_DIR/diar-streaming-sortformer-coreml" ]] && [[ ! -d "$MODELS_DIR/diar-streaming-sortformer" ]]; then
log "MISSING Sortformer models: $MODELS_DIR/diar-streaming-sortformer{,-coreml}"
missing=1
fi
# --- LS-EEND models (folder may or may not have -coreml suffix) ---
if [[ ! -d "$MODELS_DIR/ls-eend-coreml" ]] && [[ ! -d "$MODELS_DIR/ls-eend" ]]; then
log "MISSING LS-EEND models: $MODELS_DIR/ls-eend{,-coreml}"
missing=1
fi
return $missing
}
# ---------------------------------------------------------------------------
# Phase 1: --download (verify first, download only what's missing)
# ---------------------------------------------------------------------------
if [[ "${1:-}" == "--download" ]]; then
log "=== Checking local assets ==="
if verify_assets; then
log "All models and datasets already present locally. Nothing to download."
exit 0
fi
log "Some assets are missing — downloading..."
log "Building release binary..."
cd "$PROJECT_DIR" && swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"
log "Downloading AMI SDM dataset + annotations..."
"$CLI" diarization-benchmark --mode offline --auto-download --max-files 1 \
--output "$RESULTS_DIR/warmup_offline.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading Sortformer models..."
"$CLI" sortformer-benchmark --nvidia-high-latency --hf --auto-download --max-files 1 \
--output "$RESULTS_DIR/warmup_sortformer.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading LS-EEND models..."
"$CLI" lseend-benchmark --variant ami --auto-download --max-files 1 \
--output "$RESULTS_DIR/warmup_lseend.json" 2>&1 | tee -a "$LOG_FILE"
rm -f "$RESULTS_DIR"/warmup_*.json
log "=== Downloads complete ==="
exit 0
fi
# ---------------------------------------------------------------------------
# Phase 2: Run benchmarks (offline-safe, sleep-prevented)
# ---------------------------------------------------------------------------
log "=== Verifying local assets before offline run ==="
if ! verify_assets; then
log ""
log "ERROR: Missing assets — cannot run offline."
log "Run with --download first while connected to the internet:"
log " ./Scripts/diarizer_subset_benchmark.sh --download"
exit 1
fi
log "All assets verified locally."
log "=== Diarizer benchmark suite: ${#AMI_MEETINGS[@]}/${#ALL_AMI_MEETINGS[@]} meetings x 4 systems ==="
log "Results directory: $RESULTS_DIR"
cd "$PROJECT_DIR"
# Build release if not already built
if [[ ! -x ".build/release/fluidaudiocli" ]]; then
log "Building release binary..."
swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
fi
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"
# caffeinate -s: prevent sleep even on AC power / lid closed
# caffeinate -i: prevent idle sleep
caffeinate -si -w $$ &
CAFFEINATE_PID=$!
log "caffeinate started (PID $CAFFEINATE_PID) — safe to close the lid"
# ---------------------------------------------------------------------------
# Benchmark runners
# ---------------------------------------------------------------------------
# Run a benchmark for each meeting via --single-file, then merge JSON results.
# This ensures we control exactly which meetings run (not the CLI's internal order).
merge_json_results() {
local output_file="$1"
shift
local tmp_files=("$@")
python3 -c "
import json, sys
results = []
for f in sys.argv[2:]:
try:
with open(f) as fh:
data = json.load(fh)
if isinstance(data, list):
results.extend(data)
else:
results.append(data)
except: pass
with open(sys.argv[1], 'w') as out:
json.dump(results, out, indent=2)
" "$output_file" "${tmp_files[@]}" 2>/dev/null
rm -f "${tmp_files[@]}"
}
run_offline_benchmark() {
local label="offline_vbx"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
local tmp_files=()
log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, offline VBx) ---"
local start_time=$(date +%s)
for meeting in "${AMI_MEETINGS[@]}"; do
local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
tmp_files+=("$tmp")
log " [$label] $meeting"
"$CLI" diarization-benchmark \
--mode offline \
--dataset ami-sdm \
--single-file "$meeting" \
--auto-download \
--output "$tmp" \
2>&1 | tee -a "$LOG_FILE"
done
merge_json_results "$output_file" "${tmp_files[@]}"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_streaming_benchmark() {
local label="streaming_5s"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
local tmp_files=()
log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, 5s chunks, threshold=0.8) ---"
local start_time=$(date +%s)
for meeting in "${AMI_MEETINGS[@]}"; do
local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
tmp_files+=("$tmp")
log " [$label] $meeting"
"$CLI" diarization-benchmark \
--mode streaming \
--dataset ami-sdm \
--single-file "$meeting" \
--chunk-seconds 5.0 \
--overlap-seconds 0.0 \
--threshold 0.8 \
--auto-download \
--output "$tmp" \
2>&1 | tee -a "$LOG_FILE"
done
merge_json_results "$output_file" "${tmp_files[@]}"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_sortformer_benchmark() {
local label="sortformer"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
local tmp_files=()
log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, NVIDIA high-latency) ---"
local start_time=$(date +%s)
for meeting in "${AMI_MEETINGS[@]}"; do
local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
tmp_files+=("$tmp")
log " [$label] $meeting"
"$CLI" sortformer-benchmark \
--nvidia-high-latency \
--hf \
--dataset ami \
--single-file "$meeting" \
--auto-download \
--output "$tmp" \
2>&1 | tee -a "$LOG_FILE"
done
merge_json_results "$output_file" "${tmp_files[@]}"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_lseend_benchmark() {
local label="lseend_ami"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
local tmp_files=()
log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, AMI variant) ---"
local start_time=$(date +%s)
for meeting in "${AMI_MEETINGS[@]}"; do
local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
tmp_files+=("$tmp")
log " [$label] $meeting"
"$CLI" lseend-benchmark \
--variant ami \
--dataset ami \
--single-file "$meeting" \
--auto-download \
--output "$tmp" \
2>&1 | tee -a "$LOG_FILE"
done
merge_json_results "$output_file" "${tmp_files[@]}"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
# ---------------------------------------------------------------------------
# Run all 4 benchmarks
# ---------------------------------------------------------------------------
SUITE_START=$(date +%s)
run_offline_benchmark
run_streaming_benchmark
run_sortformer_benchmark
run_lseend_benchmark
SUITE_END=$(date +%s)
SUITE_ELAPSED=$(( SUITE_END - SUITE_START ))
log "=== All benchmarks complete in ${SUITE_ELAPSED}s ==="
log "Results:"
ls -lh "$RESULTS_DIR"/*_${TIMESTAMP}.json 2>/dev/null | tee -a "$LOG_FILE"
# ---------------------------------------------------------------------------
# Extract DER and RTFx from JSON results
# ---------------------------------------------------------------------------
# Streaming diarization benchmark: JSON is array of per-meeting results with "der" and "rtfx"
extract_streaming_metrics() {
local json_file="$1"
if [[ -f "$json_file" ]]; then
python3 -c "
import json, sys
with open('$json_file') as f:
results = json.load(f)
if not results:
print('N/A N/A')
sys.exit()
avg_der = sum(r['der'] for r in results) / len(results)
avg_rtfx = sum(r['rtfx'] for r in results) / len(results)
print(f'{avg_der:.1f} {avg_rtfx:.1f}')
" 2>/dev/null || echo "N/A N/A"
else
echo "N/A N/A"
fi
}
# Sortformer/LS-EEND: same JSON array format via DiarizationBenchmarkUtils
extract_shared_metrics() {
local json_file="$1"
if [[ -f "$json_file" ]]; then
python3 -c "
import json, sys
with open('$json_file') as f:
results = json.load(f)
if not results:
print('N/A N/A')
sys.exit()
avg_der = sum(r['der'] for r in results) / len(results)
avg_rtfx = sum(r['rtfx'] for r in results) / len(results)
print(f'{avg_der:.1f} {avg_rtfx:.1f}')
" 2>/dev/null || echo "N/A N/A"
else
echo "N/A N/A"
fi
}
# ---------------------------------------------------------------------------
# Compare DER & RTFx against Benchmarks.md baselines
# ---------------------------------------------------------------------------
# Baselines from Documentation/Benchmarks.md (AMI SDM, all 16 meetings)
# Note: when running a subset (--max-files <16), DER will differ from these baselines
# due to per-meeting variance. Baselines are for full 16-meeting runs only.
# Offline: no AMI SDM baseline yet — first --all run establishes it.
# Streaming: 5s/0s/0.8 on AMI SDM (7 meetings) = 26.2% DER, 223.1x RTFx
# Sortformer: NVIDIA high-latency on AMI SDM (16 meetings) = 31.7% DER, 126.7x RTFx
# LS-EEND: AMI variant on AMI SDM (16 meetings) = 20.7% DER, 74.5x RTFx
BASELINE_STREAMING_DER="26.2"
BASELINE_STREAMING_RTFX="223.1"
BASELINE_SORTFORMER_DER="31.7"
BASELINE_SORTFORMER_RTFX="126.7"
BASELINE_LSEEND_DER="20.7"
BASELINE_LSEEND_RTFX="74.5"
OFFLINE_FILE="$RESULTS_DIR/offline_vbx_${TIMESTAMP}.json"
STREAMING_FILE="$RESULTS_DIR/streaming_5s_${TIMESTAMP}.json"
SORTFORMER_FILE="$RESULTS_DIR/sortformer_${TIMESTAMP}.json"
LSEEND_FILE="$RESULTS_DIR/lseend_ami_${TIMESTAMP}.json"
read OFFLINE_DER OFFLINE_RTFX <<< $(extract_streaming_metrics "$OFFLINE_FILE")
read STREAMING_DER STREAMING_RTFX <<< $(extract_streaming_metrics "$STREAMING_FILE")
read SORTFORMER_DER SORTFORMER_RTFX <<< $(extract_shared_metrics "$SORTFORMER_FILE")
read LSEEND_DER LSEEND_RTFX <<< $(extract_shared_metrics "$LSEEND_FILE")
log ""
log "=== DER & RTFx Comparison vs Benchmarks.md baselines (AMI SDM, ${#AMI_MEETINGS[@]} meetings) ==="
log ""
printf "%-25s %12s %12s %12s %12s %12s\n" \
"System" "Base DER" "DER" "Delta" "Base RTFx" "RTFx" | tee -a "$LOG_FILE"
printf "%-25s %12s %12s %12s %12s %12s\n" \
"-------------------------" "------------" "------------" "------------" "------------" "------------" | tee -a "$LOG_FILE"
compare_der_rtfx() {
local label="$1" base_der="$2" current_der="$3" base_rtfx="$4" current_rtfx="$5"
if [[ "$current_der" == "N/A" ]]; then
printf "%-25s %11s%% %12s %12s %11sx %12s\n" \
"$label" "$base_der" "N/A" "—" "$base_rtfx" "N/A" | tee -a "$LOG_FILE"
return
fi
local delta marker=""
delta=$(python3 -c "print(f'{$current_der - $base_der:+.1f}')" 2>/dev/null || echo "?")
local regression
regression=$(python3 -c "print('YES' if $current_der > $base_der + 2.0 else 'NO')" 2>/dev/null || echo "NO")
if [[ "$regression" == "YES" ]]; then
marker=" <- REGRESSION"
fi
printf "%-25s %11s%% %11s%% %11s%% %11sx %11sx%s\n" \
"$label" "$base_der" "$current_der" "$delta" "$base_rtfx" "$current_rtfx" "$marker" | tee -a "$LOG_FILE"
}
# Offline has no AMI SDM baseline yet — show as "new"
if [[ "$OFFLINE_DER" != "N/A" ]]; then
printf "%-25s %12s %11s%% %12s %12s %11sx\n" \
"Offline (VBx)" "—" "$OFFLINE_DER" "(new)" "—" "$OFFLINE_RTFX" | tee -a "$LOG_FILE"
else
printf "%-25s %12s %12s %12s %12s %12s\n" \
"Offline (VBx)" "—" "N/A" "—" "—" "N/A" | tee -a "$LOG_FILE"
fi
compare_der_rtfx "Streaming (5s/0.8)" "$BASELINE_STREAMING_DER" "$STREAMING_DER" "$BASELINE_STREAMING_RTFX" "$STREAMING_RTFX"
compare_der_rtfx "Sortformer (high-lat)" "$BASELINE_SORTFORMER_DER" "$SORTFORMER_DER" "$BASELINE_SORTFORMER_RTFX" "$SORTFORMER_RTFX"
compare_der_rtfx "LS-EEND (AMI)" "$BASELINE_LSEEND_DER" "$LSEEND_DER" "$BASELINE_LSEEND_RTFX" "$LSEEND_RTFX"
log ""
# Check for any DER regressions (>2.0% increase — diarization is noisier than ASR)
ANY_REGRESSION=$(python3 -c "
baselines = [
($BASELINE_STREAMING_DER, '$STREAMING_DER'),
($BASELINE_SORTFORMER_DER, '$SORTFORMER_DER'),
($BASELINE_LSEEND_DER, '$LSEEND_DER'),
]
for b, c in baselines:
if c != 'N/A' and float(c) > b + 2.0:
print('YES'); exit()
print('NO')
" 2>/dev/null || echo "NO")
if [[ "$ANY_REGRESSION" == "YES" ]]; then
log "WARNING: DER REGRESSION DETECTED (>2.0% above baseline) — investigate before merging"
else
log "No DER regressions (all within 2.0% of baseline)"
fi
# caffeinate will exit automatically since the parent process ($$) exits
+398
View File
@@ -0,0 +1,398 @@
#!/bin/bash
# Run all Parakeet model benchmarks (100 files each) with sleep prevention.
#
# Benchmarks:
# 1. ASR v3 — parakeet-tdt-0.6b-v3 on LibriSpeech test-clean
# 2. ASR v2 — parakeet-tdt-0.6b-v2 on LibriSpeech test-clean
# 3. ASR tdt-ctc-110m — parakeet-tdt-ctc-110m on LibriSpeech test-clean
# 4. CTC custom vocab — ctc-earnings-benchmark (v2 TDT + CTC 110m keyword spotting)
# 5. EOU streaming — parakeet-eou 320ms on LibriSpeech test-clean
# 6. Nemotron streaming — nemotron 1120ms on LibriSpeech test-clean
#
# Usage:
# ./Scripts/parakeet_subset_benchmark.sh # verify + run
# ./Scripts/parakeet_subset_benchmark.sh --download # download missing assets, then exit
#
# The script verifies all models and dataset files exist locally before running.
# If anything is missing it will tell you exactly what and exit (unless --download).
# Uses caffeinate to prevent sleep so you can close the lid.
# Results are saved to benchmark_results/ with timestamps.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
RESULTS_DIR="$PROJECT_DIR/benchmark_results"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$RESULTS_DIR/benchmark_${TIMESTAMP}.log"
MAX_FILES=100
SUBSET="test-clean"
MODELS_DIR="$HOME/Library/Application Support/FluidAudio/Models"
DATASETS_DIR="$HOME/Library/Application Support/FluidAudio/Datasets"
EARNINGS_DIR="$HOME/Library/Application Support/FluidAudio/earnings22-kws/test-dataset"
mkdir -p "$RESULTS_DIR"
log() {
echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"
}
# ---------------------------------------------------------------------------
# Verify local assets
# ---------------------------------------------------------------------------
verify_assets() {
local missing=0
# --- Parakeet v3 ---
local v3_dir="$MODELS_DIR/parakeet-tdt-0.6b-v3"
for f in Preprocessor.mlmodelc Encoder.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
if [[ ! -e "$v3_dir/$f" ]]; then
log "MISSING v3: $v3_dir/$f"
missing=1
fi
done
# --- Parakeet v2 (folder may have -coreml suffix) ---
local v2_dir=""
if [[ -d "$MODELS_DIR/parakeet-tdt-0.6b-v2-coreml" ]]; then
v2_dir="$MODELS_DIR/parakeet-tdt-0.6b-v2-coreml"
elif [[ -d "$MODELS_DIR/parakeet-tdt-0.6b-v2" ]]; then
v2_dir="$MODELS_DIR/parakeet-tdt-0.6b-v2"
fi
if [[ -z "$v2_dir" ]]; then
log "MISSING v2: no parakeet-tdt-0.6b-v2* directory found"
missing=1
else
for f in Preprocessor.mlmodelc Encoder.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
if [[ ! -e "$v2_dir/$f" ]]; then
log "MISSING v2: $v2_dir/$f"
missing=1
fi
done
fi
# --- TDT-CTC-110M (fused: no separate Encoder) ---
local tdt_ctc_dir="$MODELS_DIR/parakeet-tdt-ctc-110m"
for f in Preprocessor.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
if [[ ! -e "$tdt_ctc_dir/$f" ]]; then
log "MISSING tdt-ctc-110m: $tdt_ctc_dir/$f"
missing=1
fi
done
# --- CTC 110M model (for custom vocabulary / keyword spotting) ---
local ctc_dir="$MODELS_DIR/parakeet-ctc-110m-coreml"
for f in MelSpectrogram.mlmodelc AudioEncoder.mlmodelc vocab.json; do
if [[ ! -e "$ctc_dir/$f" ]]; then
log "MISSING ctc-110m: $ctc_dir/$f"
missing=1
fi
done
# --- EOU streaming models (320ms chunks) ---
local eou_dir="$MODELS_DIR/parakeet-eou-streaming/320ms"
if [[ ! -d "$eou_dir" ]]; then
log "MISSING eou-320ms: $eou_dir"
missing=1
fi
# --- Nemotron models (uses v3 encoder + nemotron-specific models) ---
# Nemotron reuses the v3 models directory; no separate check needed beyond v3 above.
# --- LibriSpeech test-clean ---
local ls_dir="$DATASETS_DIR/LibriSpeech/$SUBSET"
local trans_count
trans_count=$(find "$ls_dir" -name "*.trans.txt" 2>/dev/null | wc -l | tr -d ' ')
if [[ "$trans_count" -lt 5 ]]; then
log "MISSING LibriSpeech $SUBSET: found $trans_count transcript files (need >= 5)"
missing=1
fi
# --- Earnings22 KWS dataset ---
local earnings_wav_count
earnings_wav_count=$(find "$EARNINGS_DIR" -maxdepth 1 -name "*.wav" 2>/dev/null | wc -l | tr -d ' ')
if [[ "$earnings_wav_count" -lt 10 ]]; then
log "MISSING Earnings22 KWS: found $earnings_wav_count wav files (need >= 10)"
missing=1
fi
return $missing
}
# ---------------------------------------------------------------------------
# Phase 1: --download (verify first, download only what's missing)
# ---------------------------------------------------------------------------
if [[ "${1:-}" == "--download" ]]; then
log "=== Checking local assets ==="
if verify_assets; then
log "All models and datasets already present locally. Nothing to download."
exit 0
fi
log "Some assets are missing — downloading..."
log "Building release binary..."
cd "$PROJECT_DIR" && swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"
log "Downloading LibriSpeech $SUBSET dataset..."
"$CLI" download --dataset "librispeech-$SUBSET" 2>&1 | tee -a "$LOG_FILE"
log "Downloading Earnings22 KWS dataset..."
"$CLI" download --dataset earnings22-kws 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading Parakeet v3 models (triggers download if missing)..."
"$CLI" asr-benchmark --model-version v3 --subset "$SUBSET" --max-files 1 \
--output "$RESULTS_DIR/warmup_v3.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading Parakeet v2 models..."
"$CLI" asr-benchmark --model-version v2 --subset "$SUBSET" --max-files 1 \
--output "$RESULTS_DIR/warmup_v2.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading CTC earnings models..."
"$CLI" ctc-earnings-benchmark --max-files 1 --auto-download \
--output "$RESULTS_DIR/warmup_ctc.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading EOU streaming models..."
"$CLI" parakeet-eou --benchmark --chunk-size 320 --max-files 1 \
--output "$RESULTS_DIR/warmup_eou.json" 2>&1 | tee -a "$LOG_FILE"
log "Pre-loading Nemotron streaming models..."
"$CLI" nemotron-benchmark --max-files 1 2>&1 | tee -a "$LOG_FILE"
rm -f "$RESULTS_DIR"/warmup_*.json /tmp/nemotron_*_benchmark.json
log "=== Downloads complete ==="
exit 0
fi
# ---------------------------------------------------------------------------
# Phase 2: Run benchmarks (offline-safe, sleep-prevented)
# ---------------------------------------------------------------------------
log "=== Verifying local assets before offline run ==="
if ! verify_assets; then
log ""
log "ERROR: Missing assets — cannot run offline."
log "Run with --download first while connected to the internet:"
log " ./Scripts/parakeet_subset_benchmark.sh --download"
exit 1
fi
log "All assets verified locally."
log "=== Parakeet benchmark suite: $MAX_FILES files x 6 benchmarks ==="
log "Results directory: $RESULTS_DIR"
cd "$PROJECT_DIR"
# Build release if not already built
if [[ ! -x ".build/release/fluidaudiocli" ]]; then
log "Building release binary..."
swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
fi
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"
# caffeinate -s: prevent sleep even on AC power / lid closed
# caffeinate -i: prevent idle sleep
# We wrap the entire benchmark suite so caffeinate dies when the script ends.
caffeinate -si -w $$ &
CAFFEINATE_PID=$!
log "caffeinate started (PID $CAFFEINATE_PID) — safe to close the lid"
run_asr_benchmark() {
local model_version="$1"
local label="$2"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
log "--- $label: starting ($MAX_FILES files, $SUBSET) ---"
local start_time=$(date +%s)
"$CLI" asr-benchmark \
--model-version "$model_version" \
--subset "$SUBSET" \
--max-files "$MAX_FILES" \
--no-auto-download \
--output "$output_file" \
2>&1 | tee -a "$LOG_FILE"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_ctc_earnings_benchmark() {
local label="ctc_earnings_vocab"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
log "--- $label: starting ($MAX_FILES files, v2 TDT + CTC keyword spotting) ---"
local start_time=$(date +%s)
# TDT v2 is used for transcription to match benchmarks100.md baseline
"$CLI" ctc-earnings-benchmark \
--ctc-variant 110m \
--max-files "$MAX_FILES" \
--output "$output_file" \
2>&1 | tee -a "$LOG_FILE"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_eou_benchmark() {
local label="eou_320ms"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
log "--- $label: starting ($MAX_FILES files, $SUBSET, 320ms chunks) ---"
local start_time=$(date +%s)
"$CLI" parakeet-eou \
--benchmark \
--chunk-size 320 \
--max-files "$MAX_FILES" \
--use-cache \
--output "$output_file" \
2>&1 | tee -a "$LOG_FILE"
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
run_nemotron_benchmark() {
local label="nemotron_1120ms"
local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
log "--- $label: starting ($MAX_FILES files, $SUBSET, 1120ms chunks) ---"
local start_time=$(date +%s)
"$CLI" nemotron-benchmark \
--max-files "$MAX_FILES" \
2>&1 | tee -a "$LOG_FILE"
# Nemotron writes to /tmp; copy to our results dir
local tmp_file="/tmp/nemotron_1120ms_benchmark.json"
if [[ -f "$tmp_file" ]]; then
cp "$tmp_file" "$output_file"
fi
local end_time=$(date +%s)
local elapsed=$(( end_time - start_time ))
log "--- $label: finished in ${elapsed}s — $output_file ---"
}
SUITE_START=$(date +%s)
run_asr_benchmark "v3" "parakeet_v3"
run_asr_benchmark "v2" "parakeet_v2"
run_asr_benchmark "tdt-ctc-110m" "parakeet_tdt_ctc_110m"
run_ctc_earnings_benchmark
run_eou_benchmark
run_nemotron_benchmark
SUITE_END=$(date +%s)
SUITE_ELAPSED=$(( SUITE_END - SUITE_START ))
log "=== All benchmarks complete in ${SUITE_ELAPSED}s ==="
log "Results:"
ls -lh "$RESULTS_DIR"/*_${TIMESTAMP}.json 2>/dev/null | tee -a "$LOG_FILE"
# ---------------------------------------------------------------------------
# Compare WER against benchmarks100.md baselines
# ---------------------------------------------------------------------------
# Baselines from Documentation/ASR/benchmarks100.md (main column)
BASELINE_V3_WER="2.6"
BASELINE_V2_WER="3.8"
BASELINE_TDT_CTC_WER="3.6"
BASELINE_EARNINGS_WER="16.54"
BASELINE_EOU_WER="7.11"
BASELINE_NEMOTRON_WER="1.99"
extract_wer() {
local json_file="$1"
local field="$2"
if [[ -f "$json_file" ]]; then
python3 -c "import json,sys; d=json.load(open('$json_file')); print(round(d['summary']['$field']*100, 2))" 2>/dev/null || echo "N/A"
else
echo "N/A"
fi
}
# For JSON fields that already store WER as a percentage (not decimal)
extract_wer_pct() {
local json_file="$1"
local section="$2"
local field="$3"
if [[ -f "$json_file" ]]; then
if [[ -n "$section" ]]; then
python3 -c "import json; d=json.load(open('$json_file')); print(round(d['$section']['$field'], 2))" 2>/dev/null || echo "N/A"
else
python3 -c "import json; d=json.load(open('$json_file')); print(round(d['$field'], 2))" 2>/dev/null || echo "N/A"
fi
else
echo "N/A"
fi
}
V3_FILE="$RESULTS_DIR/parakeet_v3_${TIMESTAMP}.json"
V2_FILE="$RESULTS_DIR/parakeet_v2_${TIMESTAMP}.json"
TDT_CTC_FILE="$RESULTS_DIR/parakeet_tdt_ctc_110m_${TIMESTAMP}.json"
EARNINGS_FILE="$RESULTS_DIR/ctc_earnings_vocab_${TIMESTAMP}.json"
EOU_FILE="$RESULTS_DIR/eou_320ms_${TIMESTAMP}.json"
NEMOTRON_FILE="$RESULTS_DIR/nemotron_1120ms_${TIMESTAMP}.json"
V3_WER=$(extract_wer "$V3_FILE" "averageWER")
V2_WER=$(extract_wer "$V2_FILE" "averageWER")
TDT_CTC_WER=$(extract_wer "$TDT_CTC_FILE" "averageWER")
EARNINGS_WER=$(extract_wer_pct "$EARNINGS_FILE" "summary" "avgWer")
EOU_WER=$(extract_wer "$EOU_FILE" "averageWER")
NEMOTRON_WER=$(extract_wer_pct "$NEMOTRON_FILE" "" "wer")
log ""
log "=== WER Comparison vs benchmarks100.md baselines ==="
log ""
printf "%-25s %10s %10s %10s\n" "Model" "Baseline" "Current" "Delta" | tee -a "$LOG_FILE"
printf "%-25s %10s %10s %10s\n" "-------------------------" "----------" "----------" "----------" | tee -a "$LOG_FILE"
compare_wer() {
local label="$1" baseline="$2" current="$3"
if [[ "$current" == "N/A" ]]; then
printf "%-25s %9s%% %10s %10s\n" "$label" "$baseline" "N/A" "—" | tee -a "$LOG_FILE"
return
fi
local delta
delta=$(python3 -c "print(f'{$current - $baseline:+.2f}')" 2>/dev/null || echo "?")
local marker=""
local regression
regression=$(python3 -c "print('YES' if $current > $baseline + 0.3 else 'NO')" 2>/dev/null || echo "NO")
if [[ "$regression" == "YES" ]]; then
marker=" ← REGRESSION"
fi
printf "%-25s %9s%% %9s%% %9s%%%s\n" "$label" "$baseline" "$current" "$delta" "$marker" | tee -a "$LOG_FILE"
}
compare_wer "Parakeet TDT v3 (0.6B)" "$BASELINE_V3_WER" "$V3_WER"
compare_wer "Parakeet TDT v2 (0.6B)" "$BASELINE_V2_WER" "$V2_WER"
compare_wer "CTC-TDT 110M" "$BASELINE_TDT_CTC_WER" "$TDT_CTC_WER"
compare_wer "CTC Earnings" "$BASELINE_EARNINGS_WER" "$EARNINGS_WER"
compare_wer "EOU 320ms (120M)" "$BASELINE_EOU_WER" "$EOU_WER"
compare_wer "Nemotron 1120ms (0.6B)" "$BASELINE_NEMOTRON_WER" "$NEMOTRON_WER"
log ""
# Check for any regressions (>0.3% WER increase)
ANY_REGRESSION=$(python3 -c "
baselines = [($BASELINE_V3_WER, '$V3_WER'), ($BASELINE_V2_WER, '$V2_WER'), ($BASELINE_TDT_CTC_WER, '$TDT_CTC_WER'), ($BASELINE_EARNINGS_WER, '$EARNINGS_WER'), ($BASELINE_EOU_WER, '$EOU_WER'), ($BASELINE_NEMOTRON_WER, '$NEMOTRON_WER')]
for b, c in baselines:
if c != 'N/A' and float(c) > b + 0.3:
print('YES'); exit()
print('NO')
" 2>/dev/null || echo "NO")
if [[ "$ANY_REGRESSION" == "YES" ]]; then
log "⚠ WER REGRESSION DETECTED — investigate before merging"
else
log "✓ No WER regressions (all within 0.3% of baseline)"
fi
# caffeinate will exit automatically since the parent process ($$) exits
@@ -1,166 +0,0 @@
import Accelerate
import CoreML
import Foundation
import Metal
/// Neural Engine optimization utilities for ASR pipeline
public enum ANEOptimizer {
// Use shared ANE constants
public static let aneAlignment = ANEMemoryUtils.aneAlignment
public static let aneTileSize = ANEMemoryUtils.aneTileSize
/// Create ANE-aligned MLMultiArray with optimized memory layout
public static func createANEAlignedArray(
shape: [NSNumber],
dataType: MLMultiArrayDataType
) throws -> MLMultiArray {
do {
return try ANEMemoryUtils.createAlignedArray(
shape: shape,
dataType: dataType,
zeroClear: false // ASR doesn't need zero-cleared memory
)
} catch ANEMemoryUtils.ANEMemoryError.allocationFailed {
throw NSError(
domain: "ANEOptimizer", code: -1,
userInfo: [NSLocalizedDescriptionKey: "Failed to allocate ANE-aligned memory"])
} catch {
throw NSError(
domain: "ANEOptimizer", code: -1,
userInfo: [NSLocalizedDescriptionKey: "ANE memory allocation error: \(error)"])
}
}
/// Calculate optimal strides for ANE tile processing
public static func calculateOptimalStrides(
for shape: [NSNumber],
dataType: MLMultiArrayDataType
) -> [NSNumber] {
return ANEMemoryUtils.calculateOptimalStrides(for: shape)
}
/// Configure optimal compute units for each model type
public static func optimalComputeUnits(for modelType: ModelType) -> MLComputeUnits {
return .cpuAndNeuralEngine
}
/// Create zero-copy memory view between models
public static func createZeroCopyView(
from sourceArray: MLMultiArray,
shape: [NSNumber],
offset: Int = 0
) throws -> MLMultiArray {
// Ensure we have enough data
let sourceElements = sourceArray.shape.map { $0.intValue }.reduce(1, *)
let viewElements = shape.map { $0.intValue }.reduce(1, *)
guard offset + viewElements <= sourceElements else {
throw NSError(
domain: "ANEOptimizer", code: -2,
userInfo: [NSLocalizedDescriptionKey: "View exceeds source array bounds"])
}
// Calculate byte offset
let elementSize = ANEMemoryUtils.getElementSize(for: sourceArray.dataType)
let byteOffset = offset * elementSize
let offsetPointer = sourceArray.dataPointer.advanced(by: byteOffset)
// Create view with same data but new shape
return try MLMultiArray(
dataPointer: offsetPointer,
shape: shape,
dataType: sourceArray.dataType,
strides: calculateOptimalStrides(for: shape, dataType: sourceArray.dataType),
deallocator: nil // No deallocation since it's a view
)
}
/// Prefetch data to Neural Engine
public static func prefetchToNeuralEngine(_ array: MLMultiArray) {
// Trigger ANE prefetch by accessing first and last elements
// This causes the ANE to initiate DMA transfer
if array.count > 0 {
_ = array[0]
_ = array[array.count - 1]
}
}
/// Convert float32 array to float16 for ANE efficiency
public static func convertToFloat16(_ input: MLMultiArray) throws -> MLMultiArray {
guard input.dataType == .float32 else {
throw NSError(
domain: "ANEOptimizer", code: -3,
userInfo: [NSLocalizedDescriptionKey: "Input must be float32"])
}
// Create float16 array with ANE alignment
let float16Array = try createANEAlignedArray(
shape: input.shape,
dataType: .float16
)
// Convert using Accelerate with platform-specific handling
let sourcePtr = input.dataPointer.bindMemory(to: Float.self, capacity: input.count)
var sourceBuffer = vImage_Buffer(
data: sourcePtr,
height: 1,
width: vImagePixelCount(input.count),
rowBytes: input.count * MemoryLayout<Float>.stride
)
// Use UInt16 as storage type for cross-platform compatibility
let destPtr = float16Array.dataPointer.bindMemory(to: UInt16.self, capacity: input.count)
var destBuffer = vImage_Buffer(
data: destPtr,
height: 1,
width: vImagePixelCount(input.count),
rowBytes: input.count * MemoryLayout<UInt16>.stride
)
vImageConvert_PlanarFtoPlanar16F(&sourceBuffer, &destBuffer, 0)
return float16Array
}
/// Model type enumeration for compute unit selection
public enum ModelType {
case encoder
case decoder
case joint
}
}
/// Extension for MLFeatureProvider to enable zero-copy chaining
public class ZeroCopyFeatureProvider: NSObject, MLFeatureProvider {
private let features: [String: MLFeatureValue]
public init(features: [String: MLFeatureValue]) {
self.features = features
super.init()
}
public var featureNames: Set<String> {
Set(features.keys)
}
public func featureValue(for featureName: String) -> MLFeatureValue? {
features[featureName]
}
/// Create a provider that chains output from one model to input of another
public static func chain(
from outputProvider: MLFeatureProvider,
outputName: String,
to inputName: String
) -> ZeroCopyFeatureProvider? {
guard let outputValue = outputProvider.featureValue(for: outputName) else {
return nil
}
return ZeroCopyFeatureProvider(features: [inputName: outputValue])
}
}
+45 -169
View File
@@ -3,11 +3,6 @@ import AVFoundation
import Foundation
import OSLog
public enum AudioSource: Sendable {
case microphone
case system
}
public actor AsrManager {
internal let logger = AppLogger(category: "ASR")
@@ -24,14 +19,12 @@ public actor AsrManager {
internal let progressEmitter = ProgressEmitter()
/// Get the number of decoder layers for the current model.
/// Number of decoder layers for the current model.
/// Returns 2 if models not loaded (v2/v3 default, tdtCtc110m uses 1).
internal func getDecoderLayers() -> Int {
return asrModels?.version.decoderLayers ?? 2
internal var decoderLayerCount: Int {
asrModels?.version.decoderLayers ?? 2
}
/// Token duration optimization model
/// Cached vocabulary loaded once during initialization
internal var vocabulary: [Int: String] = [:]
#if DEBUG
@@ -41,17 +34,25 @@ public actor AsrManager {
}
#endif
// TODO:: the decoder state should be moved higher up in the API interface
// Per-source decoder states are actor-internal; callers reset via resetDecoderState().
internal var microphoneDecoderState: TdtDecoderState
internal var systemDecoderState: TdtDecoderState
// Vocabulary boosting state (configured via configureVocabularyBoosting)
// Internal access required for AsrTranscription extension (separate file)
internal var customVocabulary: CustomVocabularyContext?
internal var ctcSpotter: CtcKeywordSpotter?
internal var vocabularyRescorer: VocabularyRescorer?
internal var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig?
internal var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil }
/// Get decoder state for a given audio source.
internal func decoderState(for source: AudioSource) -> TdtDecoderState {
switch source {
case .microphone: return microphoneDecoderState
case .system: return systemDecoderState
}
}
/// Set decoder state for a given audio source.
internal func setDecoderState(_ state: TdtDecoderState, for source: AudioSource) {
switch source {
case .microphone: microphoneDecoderState = state
case .system: systemDecoderState = state
}
}
// Cached CTC logits from fused Preprocessor (unified custom vocabulary)
internal var cachedCtcLogits: MLMultiArray?
@@ -61,6 +62,13 @@ public actor AsrManager {
/// Whether the Preprocessor outputs CTC logits (unified custom vocabulary model).
public var hasCachedCtcLogits: Bool { cachedCtcLogits != nil }
/// Clear all cached CTC data (logits, frame duration, valid frames).
internal func clearCachedCtcData() {
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
}
/// Get cached CTC raw logits as [[Float]] for external use (e.g. benchmarks).
/// These are raw logits callers must apply `CtcKeywordSpotter.applyLogSoftmax()`
/// to convert to log-probabilities before use in keyword detection.
@@ -120,7 +128,7 @@ public actor AsrManager {
/// Only one session is supported at a time.
public var transcriptionProgressStream: AsyncThrowingStream<Double, Error> {
get async {
await progressEmitter.currentStream()
await progressEmitter.ensureSession()
}
}
@@ -157,55 +165,6 @@ public actor AsrManager {
logger.info("AsrManager initialized successfully with provided models")
}
/// Configure vocabulary boosting for batch transcription.
///
/// When configured, vocabulary terms will be automatically rescored after each `transcribe()` call
/// using CTC-based constrained decoding. The resulting `ASRResult` will have `ctcDetectedTerms`
/// and `ctcAppliedTerms` populated.
///
/// - Parameters:
/// - vocabulary: Custom vocabulary context with terms to detect
/// - ctcModels: Pre-loaded CTC models for keyword spotting
/// - config: Optional rescorer configuration (default: vocabulary-size-aware config)
/// - Throws: Error if rescorer initialization fails
public func configureVocabularyBoosting(
vocabulary: CustomVocabularyContext,
ctcModels: CtcModels,
config: VocabularyRescorer.Config? = nil
) async throws {
self.customVocabulary = vocabulary
let blankId = ctcModels.vocabulary.count
self.ctcSpotter = CtcKeywordSpotter(models: ctcModels, blankId: blankId)
let vocabSize = vocabulary.terms.count
let vocabConfig = ContextBiasingConstants.rescorerConfig(forVocabSize: vocabSize)
self.vocabSizeConfig = vocabConfig
let effectiveConfig = config ?? .default
let ctcModelDir = CtcModels.defaultCacheDirectory(for: ctcModels.variant)
self.vocabularyRescorer = try await VocabularyRescorer.create(
spotter: ctcSpotter!,
vocabulary: vocabulary,
config: effectiveConfig,
ctcModelDirectory: ctcModelDir
)
let isLargeVocab = vocabSize > ContextBiasingConstants.largeVocabThreshold
logger.info(
"Vocabulary boosting configured with \(vocabSize) terms (isLargeVocab: \(isLargeVocab))"
)
}
/// Disable vocabulary boosting and release CTC models.
public func disableVocabularyBoosting() {
customVocabulary = nil
ctcSpotter = nil
vocabularyRescorer = nil
vocabSizeConfig = nil
logger.info("Vocabulary boosting disabled")
}
private func createFeatureProvider(
features: [(name: String, array: MLMultiArray)]
) throws
@@ -275,16 +234,7 @@ public actor AsrManager {
throw ASRError.notInitialized
}
// Get the appropriate decoder state
var state: TdtDecoderState
switch source {
case .microphone:
state = microphoneDecoderState
case .system:
state = systemDecoderState
}
// Reset the existing decoder state to clear all cached values including predictorOutput
var state = decoderState(for: source)
state.reset()
let initDecoderInput = try prepareDecoderInput(
@@ -298,40 +248,7 @@ public actor AsrManager {
)
state.update(from: initDecoderOutput)
// Store back
switch source {
case .microphone:
microphoneDecoderState = state
case .system:
systemDecoderState = state
}
}
private func loadModel(
path: URL,
name: String,
configuration: MLModelConfiguration
) async throws -> MLModel {
do {
let model = try MLModel(contentsOf: path, configuration: configuration)
return model
} catch {
logger.error("Failed to load \(name) model: \(error)")
throw ASRError.modelLoadFailed
}
}
private static func getDefaultModelsDirectory() -> URL {
let applicationSupportURL = FileManager.default.urls(
for: .applicationSupportDirectory, in: .userDomainMask
).first!
let appDirectory = applicationSupportURL.appendingPathComponent(
"FluidAudio", isDirectory: true)
let directory = appDirectory.appendingPathComponent("Models/Parakeet", isDirectory: true)
try? FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true)
return directory.standardizedFileURL
setDecoderState(state, for: source)
}
public func resetState() {
@@ -339,9 +256,7 @@ public actor AsrManager {
let layers = asrModels?.version.decoderLayers ?? 2
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
clearCachedCtcData()
Task { await sharedMLArrayCache.clear() }
}
@@ -356,11 +271,7 @@ public actor AsrManager {
// Reset decoder states using fresh allocations for deterministic behavior
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
// Release vocabulary boosting resources and cached CTC data
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
disableVocabularyBoosting()
clearCachedCtcData()
Task { await sharedMLArrayCache.clear() }
logger.info("AsrManager resources cleaned up")
}
@@ -467,7 +378,7 @@ public actor AsrManager {
let estimatedSamples = Int((Double(audioFile.length) * sampleRateRatio).rounded(.up))
if estimatedSamples > config.streamingThreshold {
return try await transcribeStreaming(url, source: source)
return try await transcribeDiskBacked(url, source: source)
}
}
@@ -476,30 +387,30 @@ public actor AsrManager {
return result
}
/// Transcribe audio from a file URL using streaming mode.
/// Transcribe audio from a file URL using disk-backed chunked processing.
///
/// Memory-efficient transcription that processes audio in chunks, maintaining constant
/// memory usage (~1.2MB) regardless of file size. Ideal for long audio files.
/// Memory-efficient transcription that memory-maps the file and processes audio in chunks,
/// maintaining constant memory usage (~1.2MB) regardless of file size. Ideal for long audio files.
///
/// - Parameters:
/// - url: The URL to the audio file
/// - source: The audio source type (defaults to .system)
/// - Returns: An ASRResult containing the transcribed text and token timings
/// - Throws: ASRError if transcription fails, models are not initialized, or the file cannot be read
public func transcribeStreaming(_ url: URL, source: AudioSource = .system) async throws -> ASRResult {
public func transcribeDiskBacked(_ url: URL, source: AudioSource = .system) async throws -> ASRResult {
guard isAvailable else { throw ASRError.notInitialized }
let startTime = Date()
// Create a disk-backed source for memory-efficient access
let factory = StreamingAudioSourceFactory()
let factory = AudioSourceFactory()
let (sampleSource, _) = try factory.makeDiskBackedSource(
from: url,
targetSampleRate: config.sampleRate
)
let totalSamples = sampleSource.sampleCount
guard totalSamples >= 16_000 else {
guard totalSamples >= config.sampleRate else {
sampleSource.cleanup()
throw ASRError.invalidAudioData
}
@@ -589,53 +500,18 @@ public actor AsrManager {
try await initializeDecoderState(for: source)
}
internal func normalizedTimingToken(_ token: String) -> String {
nonisolated internal func normalizedTimingToken(_ token: String) -> String {
token.replacingOccurrences(of: "", with: " ")
}
internal func convertTokensWithExistingTimings(
_ tokenIds: [Int], timings: [TokenTiming]
) -> (
text: String, timings: [TokenTiming]
) {
guard !tokenIds.isEmpty else { return ("", []) }
/// Decode token IDs to text using SentencePiece conventions.
internal func convertTokensToText(_ tokenIds: [Int]) -> String {
guard !tokenIds.isEmpty else { return "" }
// SentencePiece-compatible decoding algorithm:
// 1. Convert token IDs to token strings
var tokens: [String] = []
var tokenInfos: [(token: String, tokenId: Int, timing: TokenTiming?)] = []
for (index, tokenId) in tokenIds.enumerated() {
if let token = vocabulary[tokenId], !token.isEmpty {
tokens.append(token)
let timing = index < timings.count ? timings[index] : nil
tokenInfos.append((token: token, tokenId: tokenId, timing: timing))
}
}
// 2. Concatenate all tokens (this is how SentencePiece works)
let concatenated = tokens.joined()
// 3. Replace with space (SentencePiece standard)
let text = concatenated.replacingOccurrences(of: "", with: " ")
let tokens = tokenIds.compactMap { vocabulary[$0] }.filter { !$0.isEmpty }
return tokens.joined()
.replacingOccurrences(of: "", with: " ")
.trimmingCharacters(in: .whitespaces)
// 4. For now, return original timings as-is
// Note: Proper timing alignment would require tracking character positions
// through the concatenation and replacement process
let adjustedTimings = tokenInfos.compactMap { info in
info.timing.map { timing in
TokenTiming(
token: normalizedTimingToken(info.token),
tokenId: info.tokenId,
startTime: timing.startTime,
endTime: timing.endTime,
confidence: timing.confidence
)
}
}
return (text, adjustedTimings)
}
nonisolated internal func extractFeatureValue(
@@ -1,6 +1,5 @@
@preconcurrency import CoreML
import Foundation
import OSLog
/// ASR model version enum
public enum AsrModelVersion: Sendable {
@@ -331,8 +330,6 @@ extension AsrModels {
return try await load(from: targetDir, configuration: configuration, progressHandler: progressHandler)
}
/// Load models with ANE-optimized configurations
private static func describeComputeUnits(_ units: MLComputeUnits) -> String {
switch units {
case .cpuOnly:
@@ -348,41 +345,20 @@ extension AsrModels {
}
}
public static func loadWithANEOptimization(
from directory: URL? = nil,
enableFP16: Bool = true
) async throws -> AsrModels {
let targetDir = directory ?? defaultCacheDirectory()
logger.info("Loading ASR models with ANE optimization from: \(targetDir.path)")
// Use the load method that already applies per-model optimizations
return try await load(from: targetDir, configuration: nil)
}
public static func defaultConfiguration() -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = true
// Prefer Neural Engine across platforms for ASR inference to avoid GPU dispatch.
config.computeUnits = .cpuAndNeuralEngine
return config
MLModelConfigurationUtils.defaultConfiguration(computeUnits: .cpuAndNeuralEngine)
}
/// Create optimized configuration for specific model type
/// Create optimized configuration for model inference
public static func optimizedConfiguration(
for modelType: ANEOptimizer.ModelType,
enableFP16: Bool = true
) -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = enableFP16
config.computeUnits = ANEOptimizer.optimalComputeUnits(for: modelType)
// Enable model-specific optimizations
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
if isCI {
config.computeUnits = .cpuOnly
}
let config = MLModelConfigurationUtils.defaultConfiguration(
computeUnits: isCI ? .cpuOnly : .cpuAndNeuralEngine
)
config.allowLowPrecisionAccumulationOnGPU = enableFP16
return config
}
@@ -536,14 +512,7 @@ extension AsrModels {
}
public static func defaultCacheDirectory(for version: AsrModelVersion = .v3) -> URL {
let appSupport = FileManager.default.urls(
for: .applicationSupportDirectory, in: .userDomainMask
).first!
return
appSupport
.appendingPathComponent("FluidAudio", isDirectory: true)
.appendingPathComponent("Models", isDirectory: true)
.appendingPathComponent(version.repo.folderName, isDirectory: true)
MLModelConfigurationUtils.defaultModelsDirectory(for: version.repo)
}
// Legacy method for backward compatibility
@@ -1,6 +1,5 @@
@preconcurrency import CoreML
import Foundation
import OSLog
extension AsrManager {
@@ -8,34 +7,15 @@ extension AsrManager {
_ audioSamples: [Float], source: AudioSource
) async throws -> ASRResult {
guard isAvailable else { throw ASRError.notInitialized }
guard audioSamples.count >= 16_000 else { throw ASRError.invalidAudioData }
guard audioSamples.count >= config.sampleRate else { throw ASRError.invalidAudioData }
let startTime = Date()
// Get the appropriate decoder state
var decoderState: TdtDecoderState
switch source {
case .microphone:
decoderState = microphoneDecoderState
case .system:
decoderState = systemDecoderState
}
var decoderState = decoderState(for: source)
// Route to appropriate processing method based on audio length
if audioSamples.count <= ASRConstants.maxModelSamples {
let originalLength = audioSamples.count
let frameAlignedCandidate =
((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
/ ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
let frameAlignedLength: Int
let alignedSamples: [Float]
if frameAlignedCandidate > originalLength && frameAlignedCandidate <= ASRConstants.maxModelSamples {
frameAlignedLength = frameAlignedCandidate
alignedSamples = audioSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
} else {
frameAlignedLength = originalLength
alignedSamples = audioSamples
}
let (alignedSamples, frameAlignedLength) = frameAlignedAudio(audioSamples)
let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
let (hypothesis, encoderSequenceLength) = try await executeMLInferenceWithTimings(
paddedAudio,
@@ -45,7 +25,7 @@ extension AsrManager {
isLastChunk: true // Single-chunk: always first and last
)
var result = processTranscriptionResult(
let result = processTranscriptionResult(
tokenIds: hypothesis.ySequence,
timestamps: hypothesis.timestamps,
confidences: hypothesis.tokenConfidences,
@@ -55,25 +35,14 @@ extension AsrManager {
processingTime: Date().timeIntervalSince(startTime)
)
// Auto-apply vocabulary rescoring when configured
if vocabBoostingEnabled {
result = await applyVocabularyRescoring(result: result, audioSamples: audioSamples)
}
// Store decoder state back
switch source {
case .microphone:
microphoneDecoderState = decoderState
case .system:
systemDecoderState = decoderState
}
setDecoderState(decoderState, for: source)
return result
}
// ChunkProcessor handles stateless chunked transcription for long audio
let processor = ChunkProcessor(audioSamples: audioSamples)
var result = try await processor.process(
let result = try await processor.process(
using: self,
startTime: startTime,
progressHandler: { [weak self] progress in
@@ -82,18 +51,7 @@ extension AsrManager {
}
)
// Auto-apply vocabulary rescoring when configured
if vocabBoostingEnabled {
result = await applyVocabularyRescoring(result: result, audioSamples: audioSamples)
}
// Store decoder state back (ChunkProcessor uses the stored state directly)
switch source {
case .microphone:
microphoneDecoderState = decoderState
case .system:
systemDecoderState = decoderState
}
setDecoderState(decoderState, for: source)
return result
}
@@ -167,20 +125,14 @@ extension AsrManager {
cachedCtcFrameDuration = 0.04 // 40ms per frame
cachedCtcValidFrames = encoderSequenceLength
} else {
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
clearCachedCtcData()
}
} catch {
logger.warning("CTC head inference failed: \(error.localizedDescription)")
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
clearCachedCtcData()
}
} else {
cachedCtcLogits = nil
cachedCtcFrameDuration = nil
cachedCtcValidFrames = nil
clearCachedCtcData()
}
// Calculate actual audio frames if not provided using shared constants
@@ -250,33 +202,18 @@ extension AsrManager {
return try MLDictionaryFeatureProvider(dictionary: features)
}
/// Streaming-friendly chunk transcription that preserves decoder state and supports start-frame offset.
/// This is used by both sliding window chunking and streaming paths to unify behavior.
public func transcribeStreamingChunk(
/// Chunk transcription that preserves decoder state between calls.
/// Used by SlidingWindowAsrManager for overlapping-window processing with token deduplication.
public func transcribeChunk(
_ chunkSamples: [Float],
source: AudioSource,
previousTokens: [Int] = [],
isLastChunk: Bool = false
) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], encoderSequenceLength: Int) {
// Select and copy decoder state for the source
var state = (source == .microphone) ? microphoneDecoderState : systemDecoderState
var state = decoderState(for: source)
let originalLength = chunkSamples.count
let frameAlignedCandidate =
((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
/ ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
let frameAlignedLength: Int
let alignedSamples: [Float]
if previousTokens.isEmpty
&& frameAlignedCandidate > originalLength
&& frameAlignedCandidate <= ASRConstants.maxModelSamples
{
frameAlignedLength = frameAlignedCandidate
alignedSamples = chunkSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
} else {
frameAlignedLength = originalLength
alignedSamples = chunkSamples
}
let (alignedSamples, frameAlignedLength) = frameAlignedAudio(
chunkSamples, allowAlignment: previousTokens.isEmpty)
let padded = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
let (hypothesis, encLen) = try await executeMLInferenceWithTimings(
padded,
@@ -287,12 +224,7 @@ extension AsrManager {
isLastChunk: isLastChunk
)
// Persist updated state back to the source-specific slot
if source == .microphone {
microphoneDecoderState = state
} else {
systemDecoderState = state
}
setDecoderState(state, for: source)
// Apply token deduplication if previous tokens are provided
if !previousTokens.isEmpty && hypothesis.hasTokens {
@@ -317,23 +249,16 @@ extension AsrManager {
tokenDurations: [Int] = [],
encoderSequenceLength: Int,
audioSamples: [Float],
processingTime: TimeInterval,
tokenTimings: [TokenTiming] = []
processingTime: TimeInterval
) -> ASRResult {
let (text, finalTimings) = convertTokensWithExistingTimings(tokenIds, timings: tokenTimings)
let text = convertTokensToText(tokenIds)
let duration = TimeInterval(audioSamples.count) / TimeInterval(config.sampleRate)
// Convert timestamps to TokenTiming objects if provided
let timingsFromTimestamps = createTokenTimings(
let resultTimings = createTokenTimings(
from: tokenIds, timestamps: timestamps, confidences: confidences, tokenDurations: tokenDurations)
// Use existing timings if provided, otherwise use timings from timestamps
let resultTimings = tokenTimings.isEmpty ? timingsFromTimestamps : finalTimings
// Calculate confidence based on actual model confidence scores from TDT decoder
let confidence = calculateConfidence(
duration: duration,
tokenCount: tokenIds.count,
isEmpty: text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty,
tokenConfidences: confidences
@@ -348,6 +273,27 @@ extension AsrManager {
)
}
/// Align audio samples to encoder frame boundaries by zero-padding to the next frame boundary.
/// Returns the aligned samples and the frame-aligned length.
/// - Parameters:
/// - audioSamples: Raw audio samples
/// - allowAlignment: When false, skip alignment (e.g. when previous context exists)
nonisolated internal func frameAlignedAudio(
_ audioSamples: [Float], allowAlignment: Bool = true
) -> (samples: [Float], frameAlignedLength: Int) {
let originalLength = audioSamples.count
let frameAlignedCandidate =
((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
/ ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
if allowAlignment && frameAlignedCandidate > originalLength
&& frameAlignedCandidate <= ASRConstants.maxModelSamples
{
let aligned = audioSamples + Array(repeating: 0, count: frameAlignedCandidate - originalLength)
return (aligned, frameAlignedCandidate)
}
return (audioSamples, originalLength)
}
nonisolated internal func padAudioIfNeeded(_ audioSamples: [Float], targetLength: Int) -> [Float] {
guard audioSamples.count < targetLength else { return audioSamples }
return audioSamples + Array(repeating: 0, count: targetLength - audioSamples.count)
@@ -356,8 +302,8 @@ extension AsrManager {
/// Calculate confidence score based purely on TDT model token confidence scores
/// Returns the average of token-level softmax probabilities from the decoder
/// Range: 0.1 (empty transcription) to 1.0 (perfect confidence)
private func calculateConfidence(
duration: Double, tokenCount: Int, isEmpty: Bool, tokenConfidences: [Float]
nonisolated private func calculateConfidence(
tokenCount: Int, isEmpty: Bool, tokenConfidences: [Float]
) -> Float {
// Empty transcription gets low confidence
if isEmpty {
@@ -401,27 +347,25 @@ extension AsrManager {
// Sort by timestamp to ensure chronological order
let sortedData = combinedData.sorted { $0.timestamp < $1.timestamp }
let frameDuration = ASRConstants.secondsPerEncoderFrame
for i in 0..<sortedData.count {
let data = sortedData[i]
let tokenId = data.tokenId
let frameIndex = data.timestamp
// Convert encoder frame index to time (80ms per frame)
let startTime = TimeInterval(frameIndex) * 0.08
let startTime = TimeInterval(frameIndex) * frameDuration
// Calculate end time using actual token duration if available
let endTime: TimeInterval
if !tokenDurations.isEmpty && data.duration > 0 {
// Use actual token duration (convert frames to time: duration * 0.08)
let durationInSeconds = TimeInterval(data.duration) * 0.08
endTime = startTime + max(durationInSeconds, 0.08) // Minimum 80ms duration
let durationInSeconds = TimeInterval(data.duration) * frameDuration
endTime = startTime + max(durationInSeconds, frameDuration)
} else if i < sortedData.count - 1 {
// Fallback: Use next token's start time as this token's end time
let nextStartTime = TimeInterval(sortedData[i + 1].timestamp) * 0.08
endTime = max(nextStartTime, startTime + 0.08) // Ensure end > start
let nextStartTime = TimeInterval(sortedData[i + 1].timestamp) * frameDuration
endTime = max(nextStartTime, startTime + frameDuration)
} else {
// Last token: assume minimum duration
endTime = startTime + 0.08
endTime = startTime + frameDuration
}
// Validate that end time is after start time
@@ -447,43 +391,12 @@ extension AsrManager {
return timings
}
/// Slice encoder output to remove left context frames (following NeMo approach)
private func sliceEncoderOutput(
_ encoderOutput: MLMultiArray,
from startFrame: Int,
newLength: Int
) throws -> MLMultiArray {
let shape = encoderOutput.shape
let batchSize = shape[0].intValue
let hiddenSize = shape[2].intValue
// Create new array with sliced dimensions
let slicedArray = try MLMultiArray(
shape: [batchSize, newLength, hiddenSize] as [NSNumber],
dataType: encoderOutput.dataType
)
// Copy data from startFrame onwards
let sourcePtr = encoderOutput.dataPointer.bindMemory(to: Float.self, capacity: encoderOutput.count)
let destPtr = slicedArray.dataPointer.bindMemory(to: Float.self, capacity: slicedArray.count)
for t in 0..<newLength {
for h in 0..<hiddenSize {
let sourceIndex = (startFrame + t) * hiddenSize + h
let destIndex = t * hiddenSize + h
destPtr[destIndex] = sourcePtr[sourceIndex]
}
}
return slicedArray
}
/// Remove duplicate token sequences at the start of the current list that overlap
/// with the tail of the previous accumulated tokens. Returns deduplicated current tokens
/// and the number of removed leading tokens so caller can drop aligned timestamps.
/// Ideally this is not needed. We need to make some more fixes to the TDT decoding logic,
/// this should be a temporary workaround.
internal func removeDuplicateTokenSequence(
nonisolated internal func removeDuplicateTokenSequence(
previous: [Int], current: [Int], maxOverlap: Int = 12
) -> (deduped: [Int], removedCount: Int) {
@@ -551,139 +464,4 @@ extension AsrManager {
return (workingCurrent, removedCount)
}
/// Calculate start frame offset for a sliding window segment (deprecated - now handled by timeJump)
nonisolated internal func calculateStartFrameOffset(segmentIndex: Int, leftContextSeconds: Double) -> Int {
// This method is deprecated as frame tracking is now handled by the decoder's timeJump mechanism
// Kept for test compatibility
return 0
}
// MARK: - Vocabulary Rescoring
/// Apply vocabulary rescoring to an ASRResult using CTC-based constrained decoding.
///
/// Runs CTC inference on the audio samples and applies vocabulary rescoring to correct
/// misrecognized words. Returns an updated ASRResult with rescored text and populated
/// `ctcDetectedTerms`/`ctcAppliedTerms` fields.
///
/// - Parameters:
/// - result: The original ASRResult from transcription
/// - audioSamples: Audio samples used for CTC inference
/// - Returns: An ASRResult with rescored text and CTC metadata, or the original result if rescoring was skipped
internal func applyVocabularyRescoring(
result: ASRResult, audioSamples: [Float]
) async -> ASRResult {
guard let rescorer = vocabularyRescorer,
let vocab = customVocabulary,
let tokenTimings = result.tokenTimings, !tokenTimings.isEmpty
else {
return result
}
do {
// Try to use cached CTC logits from unified Preprocessor first
let logProbs: [[Float]]
let frameDuration: Double
if let cached = cachedCtcLogits, let duration = cachedCtcFrameDuration {
// Convert MLMultiArray to [[Float]]
logProbs = convertCtcLogitsToArray(cached)
frameDuration = duration
logger.debug("Using cached CTC logits from Preprocessor (unified model)")
} else if let spotter = ctcSpotter {
// Fallback: run separate CTC encoder
let spotResult = try await spotter.spotKeywordsWithLogProbs(
audioSamples: audioSamples,
customVocabulary: vocab,
minScore: nil
)
logProbs = spotResult.logProbs
frameDuration = spotResult.frameDuration
logger.debug("Using separate CTC encoder (legacy dual-model approach)")
} else {
logger.warning("Vocabulary rescoring skipped: no CTC logits available")
return result
}
guard !logProbs.isEmpty else {
logger.debug("Vocabulary rescoring skipped: no log probs from CTC")
return result
}
let vocabConfig = vocabSizeConfig ?? ContextBiasingConstants.rescorerConfig(forVocabSize: 0)
// Use the higher of the size-based default and the caller-specified threshold
// so that CustomVocabularyContext.minSimilarity is respected when stricter.
let effectiveMinSimilarity = max(vocabConfig.minSimilarity, vocab.minSimilarity)
let rescoreOutput = rescorer.ctcTokenRescore(
transcript: result.text,
tokenTimings: tokenTimings,
logProbs: logProbs,
frameDuration: frameDuration,
cbw: vocabConfig.cbw,
marginSeconds: 0.5,
minSimilarity: effectiveMinSimilarity
)
guard rescoreOutput.wasModified else {
return result
}
let detected = rescoreOutput.replacements.compactMap { $0.replacementWord }
let applied = rescoreOutput.replacements.filter { $0.shouldReplace }.compactMap {
$0.replacementWord
}
logger.info(
"Vocabulary rescoring applied \(applied.count) replacement(s)"
)
return result.withRescoring(
text: rescoreOutput.text,
detected: detected.isEmpty ? nil : detected,
applied: applied.isEmpty ? nil : applied
)
} catch {
logger.warning("Vocabulary rescoring failed: \(error.localizedDescription)")
return result
}
}
/// Convert CTC logits MLMultiArray to log-probabilities [[Float]] for rescoring.
/// Applies log-softmax with temperature scaling and blank bias to match
/// the processing done in `CtcKeywordSpotter.computeLogProbs`.
private func convertCtcLogitsToArray(_ ctcLogits: MLMultiArray) -> [[Float]] {
// Expected shape: [1, T, V] where T = frames, V = vocab size
let shape = ctcLogits.shape
guard shape.count == 3 else {
logger.warning("Unexpected CTC logits shape: \(shape)")
return []
}
let numFrames = min(shape[1].intValue, cachedCtcValidFrames ?? shape[1].intValue)
let vocabSize = shape[2].intValue
// Extract raw logits
var rawLogits: [[Float]] = []
rawLogits.reserveCapacity(numFrames)
for t in 0..<numFrames {
var frameLogits: [Float] = []
frameLogits.reserveCapacity(vocabSize)
for v in 0..<vocabSize {
let index = [0, t, v] as [NSNumber]
frameLogits.append(ctcLogits[index].floatValue)
}
rawLogits.append(frameLogits)
}
// Apply log-softmax + temperature + blank bias (same as CtcKeywordSpotter.makeLogProbs)
return CtcKeywordSpotter.applyLogSoftmax(
rawLogits: rawLogits,
blankId: ContextBiasingConstants.defaultBlankId
)
}
}
@@ -1,9 +1,7 @@
import CoreML
import Foundation
import OSLog
struct ChunkProcessor {
let sampleSource: StreamingAudioSampleSource
let sampleSource: AudioSampleSource
let totalSamples: Int
private let logger = AppLogger(category: "ChunkProcessor")
@@ -18,7 +16,6 @@ struct ChunkProcessor {
// Stateless chunking aligned with CoreML reference:
// - process ~14.96s of audio per window (frame-aligned) to stay under encoder limit
// - 2.0s overlap (frame-aligned) to give the decoder slack when merging windows
private let sampleRate: Int = 16000
private let overlapSeconds: Double = 2.0
/// Context samples prepended from previous chunk for mel spectrogram stability (80ms = 1 encoder frame).
@@ -36,7 +33,7 @@ struct ChunkProcessor {
return raw / ASRConstants.samplesPerEncoderFrame * ASRConstants.samplesPerEncoderFrame
}
private var overlapSamples: Int {
let requested = Int(overlapSeconds * Double(sampleRate))
let requested = Int(overlapSeconds * Double(ASRConstants.sampleRate))
let capped = min(requested, chunkSamples / 2)
return capped / ASRConstants.samplesPerEncoderFrame * ASRConstants.samplesPerEncoderFrame
}
@@ -46,7 +43,7 @@ struct ChunkProcessor {
}
/// Initialize with a streaming audio sample source for memory-efficient processing.
init(sampleSource: StreamingAudioSampleSource) {
init(sampleSource: AudioSampleSource) {
self.sampleSource = sampleSource
self.totalSamples = sampleSource.sampleCount
}
@@ -66,7 +63,7 @@ struct ChunkProcessor {
var chunkStart = 0
var chunkIndex = 0
var chunkDecoderState = TdtDecoderState.make(
decoderLayers: await manager.getDecoderLayers()
decoderLayers: await manager.decoderLayerCount
)
while chunkStart < totalSamples {
@@ -219,7 +216,7 @@ struct ChunkProcessor {
if left.isEmpty { return right }
if right.isEmpty { return left }
let frameDuration = Double(ASRConstants.samplesPerEncoderFrame) / Double(sampleRate)
let frameDuration = ASRConstants.secondsPerEncoderFrame
let overlapDuration = overlapSeconds
let halfOverlapWindow = overlapDuration / 2
@@ -433,7 +430,7 @@ struct ChunkProcessor {
frameDuration: Double
) -> [TokenWindow] {
let cutoff = (leftEndTime + rightStartTime) / 2
let trimmedLeft = left.filter { Double($0.timestamp) * frameDuration <= cutoff }
let trimmedLeft = left.filter { Double($0.timestamp) * frameDuration < cutoff }
let trimmedRight = right.filter { Double($0.timestamp) * frameDuration >= cutoff }
return trimmedLeft + trimmedRight
}
@@ -32,11 +32,11 @@ struct TdtDecoderState: Sendable {
init(decoderLayers: Int = 2) throws {
// Use ANE-aligned arrays for optimal performance
let decoderHiddenSize = ASRConstants.decoderHiddenSize
hiddenState = try ANEOptimizer.createANEAlignedArray(
hiddenState = try ANEMemoryUtils.createAlignedArray(
shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
dataType: .float32
)
cellState = try ANEOptimizer.createANEAlignedArray(
cellState = try ANEMemoryUtils.createAlignedArray(
shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
dataType: .float32
)
@@ -175,11 +175,11 @@ internal struct TdtDecoderV3 {
// Preallocate joint input tensors and a reusable provider to avoid per-step allocations.
let encoderHidden = expectedEncoderHidden
let decoderHidden = ASRConstants.decoderHiddenSize
let reusableEncoderStep = try ANEOptimizer.createANEAlignedArray(
let reusableEncoderStep = try ANEMemoryUtils.createAlignedArray(
shape: [1, NSNumber(value: encoderHidden), 1],
dataType: .float32
)
let reusableDecoderStep = try ANEOptimizer.createANEAlignedArray(
let reusableDecoderStep = try ANEMemoryUtils.createAlignedArray(
shape: [1, NSNumber(value: decoderHidden), 1],
dataType: .float32
)
@@ -617,8 +617,8 @@ internal struct TdtDecoderV3 {
try encoderFrames.copyFrame(at: timeIndex, into: encoderDestPtr, destinationStride: encoderDestStride)
// Prefetch arrays for ANE
ANEOptimizer.prefetchToNeuralEngine(encoderStep)
ANEOptimizer.prefetchToNeuralEngine(preparedDecoderStep)
encoderStep.prefetchToNeuralEngine()
preparedDecoderStep.prefetchToNeuralEngine()
// Reuse tiny output tensors for joint prediction (provide raw MLMultiArray backings)
predictionOptions.outputBackings = [
@@ -702,7 +702,7 @@ internal struct TdtDecoderV3 {
}
out = destination
} else {
out = try ANEOptimizer.createANEAlignedArray(
out = try ANEMemoryUtils.createAlignedArray(
shape: [1, NSNumber(value: hiddenSize), 1],
dataType: .float32
)
@@ -829,7 +829,7 @@ internal struct TdtDecoderV3 {
encoderOutput: encoderOutput,
validLength: encoderOutput.count,
expectedHiddenSize: config.encoderHiddenSize)
let encoderStep = try ANEOptimizer.createANEAlignedArray(
let encoderStep = try ANEMemoryUtils.createAlignedArray(
shape: [1, NSNumber(value: encoderFrames.hiddenSize), 1],
dataType: .float32)
let encoderPtr = encoderStep.dataPointer.bindMemory(to: Float.self, capacity: encoderFrames.hiddenSize)
@@ -1,155 +0,0 @@
import Foundation
import MachTaskSelfWrapper
import os
/// Performance metrics for ASR processing
public struct ASRPerformanceMetrics: Codable, Sendable {
public let preprocessorTime: TimeInterval
public let encoderTime: TimeInterval
public let decoderTime: TimeInterval
public let totalProcessingTime: TimeInterval
public let rtfx: Float // Real-time factor
public let peakMemoryMB: Float
public let gpuUtilization: Float?
public var summary: String {
"""
Performance Metrics:
- Preprocessor: \(String(format: "%.3f", preprocessorTime))s
- Encoder: \(String(format: "%.3f", encoderTime))s
- Decoder: \(String(format: "%.3f", decoderTime))s
- Total: \(String(format: "%.3f", totalProcessingTime))s
- RTFx: \(String(format: "%.1f", rtfx))x real-time
- Peak Memory: \(String(format: "%.1f", peakMemoryMB)) MB
- GPU Utilization: \(gpuUtilization.map { String(format: "%.1f%%", $0) } ?? "N/A")
"""
}
}
/// Performance monitor for tracking ASR metrics
public actor PerformanceMonitor {
public init() {}
private let logger = AppLogger(category: "Performance")
private var metrics: [ASRPerformanceMetrics] = []
private let signpostLogger = OSSignposter(subsystem: AppLogger.defaultSubsystem, category: "Performance")
/// Track performance for a processing session
public func trackSession<T: Sendable>(
operation: String,
audioLengthSeconds: Float,
block: @escaping () async throws -> T
) async throws -> (result: T, metrics: ASRPerformanceMetrics) {
let sessionID = signpostLogger.makeSignpostID()
let state = signpostLogger.beginInterval("ASR.Operation", id: sessionID)
let startTime = Date()
let startMemory = getCurrentMemoryUsage()
// Track individual components
let preprocessorTime: TimeInterval = 0
let encoderTime: TimeInterval = 0
let decoderTime: TimeInterval = 0
// Execute the operation
let result = try await block()
let totalTime = Date().timeIntervalSince(startTime)
let peakMemory = max(startMemory, getCurrentMemoryUsage())
let rtfx = audioLengthSeconds / Float(totalTime)
signpostLogger.endInterval("ASR.Operation", state)
let metrics = ASRPerformanceMetrics(
preprocessorTime: preprocessorTime,
encoderTime: encoderTime,
decoderTime: decoderTime,
totalProcessingTime: totalTime,
rtfx: rtfx,
peakMemoryMB: peakMemory,
gpuUtilization: nil // Would require Metal performance counters
)
self.metrics.append(metrics)
logger.info("\(operation) completed: \(metrics.summary)")
return (result, metrics)
}
/// Track a specific component's execution time
public func trackComponent<T: Sendable>(
_ component: String,
block: @escaping () async throws -> T
) async throws -> (result: T, time: TimeInterval) {
let componentID = signpostLogger.makeSignpostID()
let state = signpostLogger.beginInterval("ASR.Component", id: componentID)
let startTime = Date()
let result = try await block()
let time = Date().timeIntervalSince(startTime)
signpostLogger.endInterval("ASR.Component", state)
return (result, time)
}
/// Get aggregated metrics
public func getAggregatedMetrics() -> AggregatedMetrics? {
guard !metrics.isEmpty else { return nil }
let avgRTFx = metrics.map { $0.rtfx }.reduce(0, +) / Float(metrics.count)
let avgProcessingTime = metrics.map { $0.totalProcessingTime }.reduce(0, +) / Double(metrics.count)
let maxMemory = metrics.map { $0.peakMemoryMB }.max() ?? 0
return AggregatedMetrics(
averageRTFx: avgRTFx,
averageProcessingTime: avgProcessingTime,
maxMemoryMB: maxMemory,
sampleCount: metrics.count
)
}
/// Clear all stored metrics
public func reset() {
metrics.removeAll()
}
/// Get current memory usage in MB
private func getCurrentMemoryUsage() -> Float {
var info = mach_task_basic_info()
var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4
let result = withUnsafeMutablePointer(to: &info) {
$0.withMemoryRebound(to: integer_t.self, capacity: 1) {
task_info(
get_current_task_port(),
task_flavor_t(MACH_TASK_BASIC_INFO),
$0,
&count)
}
}
if result == KERN_SUCCESS {
return Float(info.resident_size) / 1024.0 / 1024.0
}
return 0
}
}
/// Aggregated performance metrics
public struct AggregatedMetrics: Sendable {
public let averageRTFx: Float
public let averageProcessingTime: TimeInterval
public let maxMemoryMB: Float
public let sampleCount: Int
public var summary: String {
"""
Aggregated Metrics (\(sampleCount) samples):
- Average RTFx: \(String(format: "%.1f", averageRTFx))x real-time
- Average Processing Time: \(String(format: "%.3f", averageProcessingTime))s
- Max Memory Usage: \(String(format: "%.1f", maxMemoryMB)) MB
"""
}
}
@@ -248,10 +248,7 @@ extension CtcModels {
/// Default CoreML configuration for CTC inference.
public static func defaultConfiguration() -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = true
config.computeUnits = .cpuAndNeuralEngine
return config
MLModelConfigurationUtils.defaultConfiguration(computeUnits: .cpuAndNeuralEngine)
}
/// Check whether required CTC model bundles and vocabulary exist at a directory.
@@ -48,11 +48,9 @@ public actor SlidingWindowAsrManager {
// Vocabulary boosting
// These are initialized via configureVocabularyBoosting() before start()
// CtcKeywordSpotter and VocabularyRescorer contain CoreML models which are not Sendable.
// We manage the safety ourselves by only accessing them from within the actor.
private var customVocabulary: CustomVocabularyContext?
nonisolated(unsafe) private var ctcSpotter: CtcKeywordSpotter?
nonisolated(unsafe) private var vocabularyRescorer: VocabularyRescorer?
private var ctcSpotter: CtcKeywordSpotter?
private var vocabularyRescorer: VocabularyRescorer?
private var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig?
private var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil }
@@ -376,7 +374,7 @@ public actor SlidingWindowAsrManager {
// Start frame offset is now handled by decoder's timeJump mechanism
// Call AsrManager directly with deduplication
let (tokens, timestamps, confidences, _) = try await asrManager.transcribeStreamingChunk(
let (tokens, timestamps, confidences, _) = try await asrManager.transcribeChunk(
windowSamples,
source: audioSource,
previousTokens: accumulatedTokens,
@@ -89,7 +89,10 @@ public final class RnntDecoder {
let decoderInput = try prepareDecoderInput(lastToken: lastToken, h: hState, c: cState)
let decoderOutput = try decoderModel.prediction(from: decoderInput)
var decoderStep = decoderOutput.featureValue(for: "decoder")!.multiArrayValue!
guard let decoderArray = decoderOutput.featureValue(for: "decoder")?.multiArrayValue else {
throw RnntDecoderError.missingOutput("decoder")
}
var decoderStep = decoderArray
// Decoder outputs [1, 640, 2] - NeMo uses the LAST frame
if decoderStep.shape.count == 3 && decoderStep.shape[2].intValue > 1 {
// Slice to keep only the last frame [1, 640, 1]
@@ -106,7 +109,9 @@ public final class RnntDecoder {
// 3. Get Token ID
// Output "token_id" is [1, 1, 1] (argmax)
let tokenIdMultiArray = jointOutput.featureValue(for: "token_id")!.multiArrayValue!
guard let tokenIdMultiArray = jointOutput.featureValue(for: "token_id")?.multiArrayValue else {
throw RnntDecoderError.missingOutput("token_id")
}
let tokenId = tokenIdMultiArray[0].int32Value
if tokenId == blankId {
@@ -120,8 +125,12 @@ public final class RnntDecoder {
lastToken = tokenId
// Update State
let newH = decoderOutput.featureValue(for: "h_out")!.multiArrayValue!
let newC = decoderOutput.featureValue(for: "c_out")!.multiArrayValue!
guard let newH = decoderOutput.featureValue(for: "h_out")?.multiArrayValue else {
throw RnntDecoderError.missingOutput("h_out")
}
guard let newC = decoderOutput.featureValue(for: "c_out")?.multiArrayValue else {
throw RnntDecoderError.missingOutput("c_out")
}
hState = newH
cState = newC
@@ -222,3 +231,14 @@ public final class RnntDecoder {
}
}
enum RnntDecoderError: Error, LocalizedError {
case missingOutput(String)
var errorDescription: String? {
switch self {
case .missingOutput(let name):
return "RNNT decoder missing expected output: \(name)"
}
}
}
@@ -98,21 +98,12 @@ extension DiarizerModels {
}
public static func defaultModelsDirectory() -> URL {
let applicationSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
return
applicationSupport
.appendingPathComponent("FluidAudio", isDirectory: true)
.appendingPathComponent("Models", isDirectory: true)
.appendingPathComponent(Repo.diarizer.folderName, isDirectory: true)
MLModelConfigurationUtils.defaultModelsDirectory(for: .diarizer)
}
static func defaultConfiguration() -> MLModelConfiguration {
let config = MLModelConfiguration()
// Enable Float16 optimization for ~2x speedup
config.allowLowPrecisionAccumulationOnGPU = true
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
config.computeUnits = isCI ? .cpuAndNeuralEngine : .all
return config
return MLModelConfigurationUtils.defaultConfiguration(computeUnits: isCI ? .cpuAndNeuralEngine : .all)
}
}
@@ -100,7 +100,7 @@ public final class OfflineDiarizerManager {
/// - Parameter url: Path to the audio file
/// - Returns: Diarization result with speaker segments
public func process(_ url: URL) async throws -> DiarizationResult {
let factory = StreamingAudioSourceFactory()
let factory = AudioSourceFactory()
let (source, loadDuration) = try factory.makeDiskBackedSource(
from: url,
targetSampleRate: config.segmentation.sampleRate
@@ -114,7 +114,7 @@ public final class OfflineDiarizerManager {
}
public func process(
audioSource: StreamingAudioSampleSource,
audioSource: AudioSampleSource,
audioLoadingSeconds: TimeInterval
) async throws -> DiarizationResult {
try config.validate()
@@ -68,18 +68,11 @@ public struct OfflineDiarizerModels: Sendable {
}
public static func defaultModelsDirectory() -> URL {
let base = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
return
base
.appendingPathComponent("FluidAudio", isDirectory: true)
.appendingPathComponent("Models", isDirectory: true)
MLModelConfigurationUtils.defaultModelsDirectory()
}
private static func defaultConfiguration() -> MLModelConfiguration {
let configuration = MLModelConfiguration()
configuration.allowLowPrecisionAccumulationOnGPU = true
configuration.computeUnits = .all
return configuration
MLModelConfigurationUtils.defaultConfiguration(computeUnits: .all)
}
public static func load(
@@ -179,7 +179,7 @@ struct OfflineEmbeddingExtractor {
}
func extractEmbeddings(
audioSource: StreamingAudioSampleSource,
audioSource: AudioSampleSource,
segmentation: SegmentationOutput
) async throws -> [TimedEmbedding] {
let stream = AsyncThrowingStream<SegmentationChunk, Error> { continuation in
@@ -221,7 +221,7 @@ struct OfflineEmbeddingExtractor {
}
func extractEmbeddings<S: AsyncSequence>(
audioSource: StreamingAudioSampleSource,
audioSource: AudioSampleSource,
segmentationStream: S
) async throws -> [TimedEmbedding] where S.Element == SegmentationChunk {
var embeddings: [TimedEmbedding] = []
@@ -42,7 +42,7 @@ struct OfflineSegmentationProcessor {
}
func process(
audioSource: StreamingAudioSampleSource,
audioSource: AudioSampleSource,
segmentationModel: MLModel,
config: OfflineDiarizerConfig,
chunkHandler: SegmentationChunkHandler? = nil
@@ -91,11 +91,8 @@ extension SortformerModels {
/// Default MLModel configuration
public static func defaultConfiguration() -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = true
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
config.computeUnits = isCI ? .cpuAndNeuralEngine : .all
return config
return MLModelConfigurationUtils.defaultConfiguration(computeUnits: isCI ? .cpuAndNeuralEngine : .all)
}
/// Load Sortformer models from HuggingFace.
+8 -9
View File
@@ -129,6 +129,12 @@ public enum Repo: String, CaseIterable {
return "nemotron-streaming/560ms"
case .sortformer:
return "sortformer"
case .parakeetCtc110m:
return "parakeet-ctc-110m-coreml"
case .parakeetCtc06b:
return "parakeet-ctc-0.6b-coreml"
case .parakeetTdtCtc110m:
return "parakeet-tdt-ctc-110m"
default:
return name.replacingOccurrences(of: "-coreml", with: "")
}
@@ -203,9 +209,6 @@ public enum ModelNames {
jointFile,
]
/// Vocabulary filename for the 110m hybrid TDT-CTC model (JSON array format)
public static let vocabularyFileArray = "parakeet_vocab.json"
/// Required models for fused frontend (110m hybrid: preprocessor contains encoder)
public static let requiredModelsFused: Set<String> = [
preprocessorFile,
@@ -215,12 +218,8 @@ public enum ModelNames {
/// Get vocabulary filename for specific model version
public static func vocabulary(for repo: Repo) -> String {
switch repo {
case .parakeetTdtCtc110m:
return vocabularyFileArray
default:
return vocabularyFile
}
// All Parakeet models use the same vocabulary file (format varies: dict for v2/v3, array for 110m)
return vocabularyFile
}
}
@@ -1,3 +1,4 @@
import Accelerate
import CoreML
import Darwin
import Foundation
@@ -143,6 +144,41 @@ public enum ANEMemoryUtils {
)
}
/// Convert a float32 MLMultiArray to float16 with ANE-aligned memory.
public static func convertToFloat16(_ input: MLMultiArray) throws -> MLMultiArray {
guard input.dataType == .float32 else {
throw ANEMemoryError.unsupportedDataType
}
let float16Array = try createAlignedArray(
shape: input.shape,
dataType: .float16,
zeroClear: false
)
let sourcePtr = input.dataPointer.bindMemory(to: Float.self, capacity: input.count)
var sourceBuffer = vImage_Buffer(
data: sourcePtr,
height: 1,
width: vImagePixelCount(input.count),
rowBytes: input.count * MemoryLayout<Float>.stride
)
let destPtr = float16Array.dataPointer.bindMemory(to: UInt16.self, capacity: input.count)
var destBuffer = vImage_Buffer(
data: destPtr,
height: 1,
width: vImagePixelCount(input.count),
rowBytes: input.count * MemoryLayout<UInt16>.stride
)
vImageConvert_PlanarFtoPlanar16F(&sourceBuffer, &destBuffer, 0)
return float16Array
}
/// Stride-aware copy between two MLMultiArrays that may have different stride layouts.
///
/// Copies all logical elements from `source` to `destination` (which must have the same shape
@@ -27,6 +27,9 @@ public enum ASRConstants {
/// Each encoder frame represents ~80ms of audio at 16kHz
public static let samplesPerEncoderFrame: Int = melHopSize * encoderSubsampling // 1280
/// Duration of one encoder frame in seconds (80ms)
public static let secondsPerEncoderFrame: Double = Double(samplesPerEncoderFrame) / Double(sampleRate) // 0.08
/// WER threshold for detailed error analysis in benchmarks
public static let highWERThreshold: Double = 0.15
@@ -1,6 +1,6 @@
import Foundation
public protocol StreamingAudioSampleSource: Sendable {
public protocol AudioSampleSource: Sendable {
var sampleCount: Int { get }
func copySamples(
into destination: UnsafeMutablePointer<Float>,
@@ -9,7 +9,7 @@ public protocol StreamingAudioSampleSource: Sendable {
) throws
}
public struct ArrayAudioSampleSource: StreamingAudioSampleSource {
public struct ArrayAudioSampleSource: AudioSampleSource {
private let samples: [Float]
public init(samples: [Float]) {
@@ -39,7 +39,7 @@ public struct ArrayAudioSampleSource: StreamingAudioSampleSource {
}
}
public struct DiskBackedAudioSampleSource: StreamingAudioSampleSource {
public struct DiskBackedAudioSampleSource: AudioSampleSource {
private let mappedData: Data
private let floatStride = MemoryLayout<Float>.stride
private let fileURL: URL
@@ -0,0 +1,6 @@
import Foundation
public enum AudioSource: Sendable {
case microphone
case system
}
@@ -3,8 +3,8 @@ import Foundation
import OSLog
import os
public struct StreamingAudioSourceFactory {
private let logger = AppLogger(category: "StreamingAudioSourceFactory")
public struct AudioSourceFactory {
private let logger = AppLogger(category: "AudioSourceFactory")
public init() {}
@@ -26,7 +26,7 @@ public struct StreamingAudioSourceFactory {
let tempURL = try makeTemporaryURL()
guard FileManager.default.createFile(atPath: tempURL.path, contents: nil) else {
throw StreamingAudioError.processingFailed("Failed to create temporary audio buffer at \(tempURL.path)")
throw AudioSourceError.processingFailed("Failed to create temporary audio buffer at \(tempURL.path)")
}
let handle = try FileHandle(forWritingTo: tempURL)
@@ -35,7 +35,7 @@ public struct StreamingAudioSourceFactory {
}
guard let converter = AVAudioConverter(from: inputFormat, to: targetFormat) else {
throw StreamingAudioError.processingFailed(
throw AudioSourceError.processingFailed(
"Unsupported audio format \(inputFormat); failed to create converter")
}
@@ -75,11 +75,11 @@ public struct StreamingAudioSourceFactory {
let duration = Date().timeIntervalSince(startTime)
return (source, duration)
} catch let streamingError as StreamingAudioError {
} catch let streamingError as AudioSourceError {
throw streamingError
} catch {
logger.error("Streaming audio source creation failed: \(error.localizedDescription)")
throw StreamingAudioError.processingFailed(
throw AudioSourceError.processingFailed(
"Streaming audio source creation failed: \(error.localizedDescription)"
)
}
@@ -106,7 +106,7 @@ public struct StreamingAudioSourceFactory {
frameCapacity: inputCapacity
)
else {
throw StreamingAudioError.failedToAllocateBuffer("Input", requestedFrames: Int(inputCapacity))
throw AudioSourceError.failedToAllocateBuffer("Input", requestedFrames: Int(inputCapacity))
}
let estimatedOutputFrames = AVAudioFrameCount(
@@ -118,7 +118,7 @@ public struct StreamingAudioSourceFactory {
frameCapacity: max(1024, estimatedOutputFrames)
)
else {
throw StreamingAudioError.failedToAllocateBuffer("Output", requestedFrames: Int(estimatedOutputFrames))
throw AudioSourceError.failedToAllocateBuffer("Output", requestedFrames: Int(estimatedOutputFrames))
}
var totalSamples = 0
@@ -167,13 +167,13 @@ public struct StreamingAudioSourceFactory {
)
if let conversionError {
throw StreamingAudioError.processingFailed(
throw AudioSourceError.processingFailed(
"Audio conversion failed: \(conversionError.localizedDescription)"
)
}
if let error = readError.withLock({ $0 }) {
throw StreamingAudioError.processingFailed(
throw AudioSourceError.processingFailed(
"Failed while reading audio: \(error.localizedDescription)"
)
}
@@ -181,7 +181,7 @@ public struct StreamingAudioSourceFactory {
let producedFrames = Int(outputBuffer.frameLength)
if producedFrames > 0 {
guard let channelData = outputBuffer.floatChannelData?.pointee else {
throw StreamingAudioError.processingFailed("Missing channel data during conversion")
throw AudioSourceError.processingFailed("Missing channel data during conversion")
}
let byteCount = producedFrames * MemoryLayout<Float>.stride
let baseAddress = UnsafeRawPointer(channelData)
@@ -199,7 +199,7 @@ public struct StreamingAudioSourceFactory {
}
}
public enum StreamingAudioError: Error, LocalizedError {
public enum AudioSourceError: Error, LocalizedError {
case processingFailed(String)
public var errorDescription: String? {
@@ -210,8 +210,8 @@ public enum StreamingAudioError: Error, LocalizedError {
}
}
extension StreamingAudioError {
fileprivate static func failedToAllocateBuffer(_ name: String, requestedFrames: Int) -> StreamingAudioError {
extension AudioSourceError {
fileprivate static func failedToAllocateBuffer(_ name: String, requestedFrames: Int) -> AudioSourceError {
.processingFailed("Failed to allocate \(name.lowercased()) buffer (\(requestedFrames) frames)")
}
}
@@ -1,12 +1,10 @@
import CoreML
import Foundation
import os
/// Thread-safe cache for MLMultiArray instances to reduce allocation overhead
actor MLArrayCache {
private var cache: [CacheKey: [MLMultiArray]] = [:]
private let maxCacheSize: Int
private let logger = AppLogger(category: "MLArrayCache")
struct CacheKey: Hashable {
let shape: [Int]
@@ -24,16 +22,13 @@ actor MLArrayCache {
dataType: dataType
)
// Check if we have a cached array
if var arrays = cache[key], !arrays.isEmpty {
// Never return the same buffer twice while it is still in use; keep the trimmed bucket so we only
// hand out arrays that callers have explicitly returned to the cache.
let array = arrays.removeLast()
cache[key] = arrays
return array
}
return try ANEOptimizer.createANEAlignedArray(shape: shape, dataType: dataType)
return try ANEMemoryUtils.createAlignedArray(shape: shape, dataType: dataType)
}
/// Return an array to the cache for reuse
@@ -47,53 +42,35 @@ actor MLArrayCache {
// Limit cache size per key
if arrays.count < maxCacheSize / max(cache.count, 1) {
// Reset the array data before caching
if array.dataType == .float32 {
array.resetData(to: 0)
}
array.resetData(to: 0)
arrays.append(array)
cache[key] = arrays
logger.debug("Returned array to cache for shape: \(array.shape)")
}
}
/// Pre-warm the cache with commonly used shapes
func prewarm(shapes: [(shape: [NSNumber], dataType: MLMultiArrayDataType)]) async {
logger.info("Pre-warming cache with \(shapes.count) shapes")
func prewarm(shapes: [(shape: [NSNumber], dataType: MLMultiArrayDataType)]) {
for (shape, dataType) in shapes {
do {
var arrays: [MLMultiArray] = []
let prewarmCount = min(5, maxCacheSize / max(shapes.count, 1))
for _ in 0..<prewarmCount {
let array = try ANEOptimizer.createANEAlignedArray(shape: shape, dataType: dataType)
let array = try ANEMemoryUtils.createAlignedArray(shape: shape, dataType: dataType)
arrays.append(array)
}
let key = CacheKey(shape: shape.map { $0.intValue }, dataType: dataType)
cache[key] = arrays
} catch {
logger.error("Failed to pre-warm shape \(shape): \(error)")
// Silently skip shapes that fail to allocate during pre-warm
}
}
}
/// Get a Float16 array (converting from Float32 if needed)
func getFloat16Array(shape: [NSNumber], from float32Array: MLMultiArray? = nil) throws -> MLMultiArray {
if let float32Array = float32Array {
// Convert existing array to Float16
return try ANEOptimizer.convertToFloat16(float32Array)
} else {
// Get new Float16 array from cache
return try getArray(shape: shape, dataType: .float16)
}
}
/// Clear the cache
func clear() {
cache.removeAll()
logger.info("Cache cleared")
}
}
@@ -0,0 +1,36 @@
@preconcurrency import CoreML
import Foundation
/// Shared utilities for creating `MLModelConfiguration` instances and resolving model directories.
public enum MLModelConfigurationUtils {
/// Create a default `MLModelConfiguration` with low-precision GPU accumulation enabled.
///
/// - Parameter computeUnits: Compute units to use (default: `.cpuAndNeuralEngine`).
/// - Returns: Configured `MLModelConfiguration`.
public static func defaultConfiguration(
computeUnits: MLComputeUnits = .cpuAndNeuralEngine
) -> MLModelConfiguration {
let config = MLModelConfiguration()
config.allowLowPrecisionAccumulationOnGPU = true
config.computeUnits = computeUnits
return config
}
/// Default models directory under Application Support.
///
/// - Parameter repo: Optional repository whose `folderName` is appended. When `nil`,
/// returns `~/Library/Application Support/FluidAudio/Models/`.
/// - Returns: URL for the models directory.
public static func defaultModelsDirectory(for repo: Repo? = nil) -> URL {
let base = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
var url =
base
.appendingPathComponent("FluidAudio", isDirectory: true)
.appendingPathComponent("Models", isDirectory: true)
if let repo {
url = url.appendingPathComponent(repo.folderName, isDirectory: true)
}
return url
}
}
@@ -0,0 +1,25 @@
import Foundation
/// Performance metrics for ASR processing
public struct ASRPerformanceMetrics: Codable, Sendable {
public let preprocessorTime: TimeInterval
public let encoderTime: TimeInterval
public let decoderTime: TimeInterval
public let totalProcessingTime: TimeInterval
public let rtfx: Float // Real-time factor
public let peakMemoryMB: Float
public let gpuUtilization: Float?
public var summary: String {
"""
Performance Metrics:
- Preprocessor: \(String(format: "%.3f", preprocessorTime))s
- Encoder: \(String(format: "%.3f", encoderTime))s
- Decoder: \(String(format: "%.3f", decoderTime))s
- Total: \(String(format: "%.3f", totalProcessingTime))s
- RTFx: \(String(format: "%.1f", rtfx))x real-time
- Peak Memory: \(String(format: "%.1f", peakMemoryMB)) MB
- GPU Utilization: \(gpuUtilization.map { String(format: "%.1f%%", $0) } ?? "N/A")
"""
}
}
@@ -7,38 +7,33 @@ actor ProgressEmitter {
init() {}
func ensureSession() async -> AsyncThrowingStream<Double, Error> {
func ensureSession() -> AsyncThrowingStream<Double, Error> {
if let stream = streamStorage {
return stream
}
return await startSession()
return startSession()
}
func currentStream() async -> AsyncThrowingStream<Double, Error> {
await ensureSession()
}
func report(progress: Double) async {
func report(progress: Double) {
guard isActive else { return }
let clamped = min(max(progress, 0.0), 1.0)
continuation?.yield(clamped)
}
func finishSession() async {
guard isActive else {
_ = await ensureSession()
return
}
func finishSession() {
guard isActive else { return }
continuation?.yield(1.0)
continuation?.finish()
reset()
}
func failSession(_ error: Error) async {
func failSession(_ error: Error) {
continuation?.finish(throwing: error)
reset()
}
private func startSession() async -> AsyncThrowingStream<Double, Error> {
private func startSession() -> AsyncThrowingStream<Double, Error> {
if let stream = streamStorage {
return stream
}
@@ -48,22 +43,13 @@ actor ProgressEmitter {
self.continuation = continuation
self.isActive = true
continuation.onTermination =
{ [weak self] (_: AsyncThrowingStream<Double, Error>.Continuation.Termination) in
Task { [weak self] in
guard let self else { return }
await self.resetAndPrepareNextSession()
}
}
continuation.yield(0.0)
return stream
}
private func resetAndPrepareNextSession() async {
private func reset() {
continuation = nil
streamStorage = nil
isActive = false
_ = await startSession()
}
}
@@ -0,0 +1,33 @@
import CoreML
import Foundation
/// Zero-copy MLFeatureProvider for chaining model outputs to inputs.
public class ZeroCopyFeatureProvider: NSObject, MLFeatureProvider {
private let features: [String: MLFeatureValue]
public init(features: [String: MLFeatureValue]) {
self.features = features
super.init()
}
public var featureNames: Set<String> {
Set(features.keys)
}
public func featureValue(for featureName: String) -> MLFeatureValue? {
features[featureName]
}
/// Create a provider that chains output from one model to input of another
public static func chain(
from outputProvider: MLFeatureProvider,
outputName: String,
to inputName: String
) -> ZeroCopyFeatureProvider? {
guard let outputValue = outputProvider.featureValue(for: outputName) else {
return nil
}
return ZeroCopyFeatureProvider(features: [inputName: outputValue])
}
}
@@ -166,7 +166,7 @@ enum ProcessCommand {
// Load and process audio file without materializing the full sample buffer.
let audioURL = URL(fileURLWithPath: audioFile)
let factory = StreamingAudioSourceFactory()
let factory = AudioSourceFactory()
let targetSampleRate = offlineConfig.segmentation.sampleRate
let diskSourceResult = try factory.makeDiskBackedSource(
from: audioURL,
@@ -90,54 +90,6 @@ final class AsrManagerExtensionTests: XCTestCase {
XCTAssertEqual(Array(result.suffix(500)), Array(repeating: 0.0, count: 500))
}
// MARK: - calculateStartFrameOffset Tests
func testCalculateStartFrameOffsetFirstSegment() {
let offset = manager.calculateStartFrameOffset(segmentIndex: 0, leftContextSeconds: 2.0)
// Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
XCTAssertEqual(offset, 0)
}
func testCalculateStartFrameOffsetSecondSegment() {
let leftContext = 2.0
let offset = manager.calculateStartFrameOffset(segmentIndex: 1, leftContextSeconds: leftContext)
// Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
XCTAssertEqual(offset, 0)
}
func testCalculateStartFrameOffsetThirdSegment() {
let leftContext = 1.5
let offset = manager.calculateStartFrameOffset(segmentIndex: 2, leftContextSeconds: leftContext)
// Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
XCTAssertEqual(offset, 0)
}
func testCalculateStartFrameOffsetVariousContexts() {
// Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
let testCases: [(leftContext: Double, expected: Int)] = [
(0.0, 0), // No context
(0.08, 0), // Method always returns 0
(0.16, 0), // Method always returns 0
(1.0, 0), // Method always returns 0
(3.2, 0), // Method always returns 0
]
for (leftContext, expected) in testCases {
let offset = manager.calculateStartFrameOffset(segmentIndex: 1, leftContextSeconds: leftContext)
XCTAssertEqual(offset, expected, "Failed for leftContext=\(leftContext)")
}
}
func testCalculateStartFrameOffsetNegativeSegment() {
let offset = manager.calculateStartFrameOffset(segmentIndex: -1, leftContextSeconds: 2.0)
// Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
XCTAssertEqual(offset, 0)
}
// MARK: - Performance Tests
func testPadAudioPerformance() {
@@ -151,11 +103,4 @@ final class AsrManagerExtensionTests: XCTestCase {
}
}
func testCalculateStartFrameOffsetPerformance() {
measure {
for i in 0..<10_000 {
_ = manager.calculateStartFrameOffset(segmentIndex: i % 100, leftContextSeconds: 2.0)
}
}
}
}
@@ -209,34 +209,13 @@ final class AsrModelsTests: XCTestCase {
// In CI environment, all compute units are overridden to .cpuOnly
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
// Test encoder configuration
let melConfig = AsrModels.optimizedConfiguration(for: .encoder)
let config = AsrModels.optimizedConfiguration()
if isCI {
XCTAssertEqual(melConfig.computeUnits, .cpuOnly)
XCTAssertEqual(config.computeUnits, .cpuOnly)
} else {
XCTAssertEqual(melConfig.computeUnits, .cpuAndNeuralEngine)
XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
}
XCTAssertTrue(melConfig.allowLowPrecisionAccumulationOnGPU)
// Test decoder configuration
let decoderConfig = AsrModels.optimizedConfiguration(for: .decoder)
if isCI {
XCTAssertEqual(decoderConfig.computeUnits, .cpuOnly)
} else {
XCTAssertEqual(decoderConfig.computeUnits, .cpuAndNeuralEngine)
}
// Test joint configuration
let jointConfig = AsrModels.optimizedConfiguration(for: .joint)
if isCI {
XCTAssertEqual(jointConfig.computeUnits, .cpuOnly)
} else {
XCTAssertEqual(jointConfig.computeUnits, .cpuAndNeuralEngine)
}
// Test with FP16 disabled
let fp32Config = AsrModels.optimizedConfiguration(for: .encoder, enableFP16: false)
XCTAssertFalse(fp32Config.allowLowPrecisionAccumulationOnGPU)
XCTAssertTrue(config.allowLowPrecisionAccumulationOnGPU)
}
func testOptimizedConfigurationCIEnvironment() {
@@ -251,7 +230,7 @@ final class AsrModelsTests: XCTestCase {
}
}
let config = AsrModels.optimizedConfiguration(for: .encoder)
let config = AsrModels.optimizedConfiguration()
XCTAssertEqual(config.computeUnits, .cpuOnly)
}
@@ -288,22 +267,10 @@ final class AsrModelsTests: XCTestCase {
XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
}
func testOptimalComputeUnitsRespectsPlatform() {
// Test each model type
let modelTypes: [ANEOptimizer.ModelType] = [
.encoder,
.decoder,
.joint,
]
for modelType in modelTypes {
let computeUnits = ANEOptimizer.optimalComputeUnits(for: modelType)
// All models should use CPU+ANE for optimal performance
XCTAssertEqual(
computeUnits, .cpuAndNeuralEngine,
"Model type \(modelType) should use CPU+ANE")
}
func testOptimalComputeUnitsDefault() {
// Default configuration uses CPU+ANE for optimal performance
let config = AsrModels.defaultConfiguration()
XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
}
// MARK: - TDT-CTC-110M Model Version Tests
@@ -384,8 +351,8 @@ final class AsrModelsTests: XCTestCase {
}
func testTdtCtc110mVocabularyFilename() {
// tdtCtc110m uses parakeet_vocab.json (array format)
let vocabFile = ModelNames.ASR.vocabularyFileArray
// tdtCtc110m uses parakeet_vocab.json (array format, same filename as v2/v3)
let vocabFile = ModelNames.ASR.vocabularyFile
XCTAssertEqual(vocabFile, "parakeet_vocab.json")
// Verify it has .json extension
@@ -101,27 +101,25 @@ final class AsrTranscriptionTests: XCTestCase {
XCTAssertTrue(result.tokenTimings?.isEmpty == true) // No timestamps provided, should be empty array
}
func testProcessTranscriptionResultWithTimings() async {
func testProcessTranscriptionResultWithTimestampsAndConfidences() async {
await setupMockVocabulary()
let tokenIds = [10, 20, 30]
let audioSamples = Array(repeating: Float(0), count: 48_000) // 3 seconds
let timings = [
TokenTiming(token: "hello", tokenId: 10, startTime: 0.0, endTime: 1.0, confidence: 0.9),
TokenTiming(token: "world", tokenId: 20, startTime: 1.0, endTime: 2.0, confidence: 0.85),
TokenTiming(token: "test", tokenId: 30, startTime: 2.0, endTime: 3.0, confidence: 0.95),
]
let timestamps = [0, 12, 25]
let confidences: [Float] = [0.9, 0.85, 0.95]
let result = await manager.processTranscriptionResult(
tokenIds: tokenIds,
timestamps: timestamps,
confidences: confidences,
encoderSequenceLength: 150,
audioSamples: audioSamples,
processingTime: 1.2,
tokenTimings: timings
processingTime: 1.2
)
XCTAssertEqual(result.duration, 3.0, accuracy: 0.01)
XCTAssertNotNil(result.tokenTimings)
// Note: Actual timing count may differ due to convertTokensWithExistingTimings filtering
XCTAssertEqual(result.tokenTimings?.count, 3)
}
func testProcessTranscriptionResultWithTimestamps() async {
@@ -133,10 +133,10 @@ final class ModelNamesTests: XCTestCase {
}
func testParakeetTdtCtc110mVocabulary() {
// tdtCtc110m uses array-format vocabulary
// tdtCtc110m uses same vocabulary file (array-format JSON, parsed at load time)
let vocabFile = ModelNames.ASR.vocabulary(for: .parakeetTdtCtc110m)
XCTAssertEqual(vocabFile, "parakeet_vocab.json")
XCTAssertEqual(vocabFile, ModelNames.ASR.vocabularyFileArray)
XCTAssertEqual(vocabFile, ModelNames.ASR.vocabularyFile)
}
func testParakeetTdtCtc110mUsesRequiredModelsFused() {
@@ -42,81 +42,4 @@ final class PerformanceMetricsTests: XCTestCase {
let summary = metrics.summary
XCTAssertTrue(summary.contains("N/A"), "Summary should show N/A for nil GPU utilization")
}
// MARK: - AggregatedMetrics
func testAggregatedMetricsSummaryFormatting() {
let aggregated = AggregatedMetrics(
averageRTFx: 8.5,
averageProcessingTime: 1.234,
maxMemoryMB: 512.0,
sampleCount: 10
)
let summary = aggregated.summary
XCTAssertTrue(summary.contains("10 samples"), "Summary should contain sample count")
XCTAssertTrue(summary.contains("8.5"), "Summary should contain average RTFx")
XCTAssertTrue(summary.contains("1.234"), "Summary should contain average processing time")
XCTAssertTrue(summary.contains("512.0"), "Summary should contain max memory")
}
// MARK: - PerformanceMonitor
func testAggregatedMetricsEmptyReturnsNil() async {
let monitor = PerformanceMonitor()
let result = await monitor.getAggregatedMetrics()
XCTAssertNil(result, "Empty monitor should return nil for aggregated metrics")
}
func testResetClearsMetrics() async throws {
let monitor = PerformanceMonitor()
// Track a session to add metrics
_ = try await monitor.trackSession(operation: "test", audioLengthSeconds: 1.0) {
return 42
}
// Verify metrics exist
let before = await monitor.getAggregatedMetrics()
XCTAssertNotNil(before)
// Reset and verify empty
await monitor.reset()
let after = await monitor.getAggregatedMetrics()
XCTAssertNil(after, "After reset, aggregated metrics should be nil")
}
func testTrackSessionReturnsMetrics() async throws {
let monitor = PerformanceMonitor()
let (result, metrics) = try await monitor.trackSession(
operation: "test",
audioLengthSeconds: 2.0
) {
return "hello"
}
XCTAssertEqual(result, "hello")
XCTAssertGreaterThanOrEqual(metrics.totalProcessingTime, 0)
XCTAssertGreaterThan(metrics.rtfx, 0)
}
func testAggregatedMetricsComputation() async throws {
let monitor = PerformanceMonitor()
for i in 0..<3 {
_ = try await monitor.trackSession(
operation: "test\(i)",
audioLengthSeconds: Float(i + 1)
) {
return i
}
}
let aggregated = await monitor.getAggregatedMetrics()
XCTAssertNotNil(aggregated)
XCTAssertEqual(aggregated?.sampleCount, 3)
XCTAssertGreaterThan(aggregated!.averageRTFx, 0)
XCTAssertGreaterThan(aggregated!.averageProcessingTime, 0)
}
}
@@ -10,7 +10,7 @@ final class ANEOptimizerTests: XCTestCase {
func testCreateANEAlignedArrayFloat32() throws {
let shape: [NSNumber] = [1, 100]
let array = try ANEOptimizer.createANEAlignedArray(
let array = try ANEMemoryUtils.createAlignedArray(
shape: shape,
dataType: .float32
)
@@ -22,7 +22,7 @@ final class ANEOptimizerTests: XCTestCase {
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
if !isCI {
// Verify memory alignment only in non-CI environment
let alignment = ANEOptimizer.aneAlignment
let alignment = ANEMemoryUtils.aneAlignment
let pointerValue = Int(bitPattern: array.dataPointer)
XCTAssertEqual(pointerValue % alignment, 0, "Array should be \(alignment)-byte aligned")
}
@@ -30,7 +30,7 @@ final class ANEOptimizerTests: XCTestCase {
func testCreateANEAlignedArrayFloat16() throws {
let shape: [NSNumber] = [1, 64] // Smaller shape for CI stability
let array = try ANEOptimizer.createANEAlignedArray(
let array = try ANEMemoryUtils.createAlignedArray(
shape: shape,
dataType: .float16
)
@@ -46,7 +46,7 @@ final class ANEOptimizerTests: XCTestCase {
// Verify memory alignment only in non-CI environment
let pointerValue = Int(bitPattern: array.dataPointer)
XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
}
}
@@ -56,9 +56,8 @@ final class ANEOptimizerTests: XCTestCase {
func testCalculateOptimalStridesBasic() {
let shape: [NSNumber] = [1, 3, 224, 224]
let strides = ANEOptimizer.calculateOptimalStrides(
for: shape,
dataType: .float32
let strides = ANEMemoryUtils.calculateOptimalStrides(
for: shape
)
XCTAssertEqual(strides.count, shape.count)
@@ -71,9 +70,8 @@ final class ANEOptimizerTests: XCTestCase {
func testCalculateOptimalStridesWithPadding() {
// Test with dimension that needs padding (not multiple of 16)
let shape: [NSNumber] = [1, 100] // 100 is not multiple of 16
let strides = ANEOptimizer.calculateOptimalStrides(
for: shape,
dataType: .float32
let strides = ANEMemoryUtils.calculateOptimalStrides(
for: shape
)
// The stride for the first dimension should account for padding
@@ -84,22 +82,10 @@ final class ANEOptimizerTests: XCTestCase {
// MARK: - Compute Unit Selection Tests
func testOptimalComputeUnits() {
// All models use CPU+ANE for optimal performance
XCTAssertEqual(
ANEOptimizer.optimalComputeUnits(for: .encoder),
.cpuAndNeuralEngine
)
XCTAssertEqual(
ANEOptimizer.optimalComputeUnits(for: .decoder),
.cpuAndNeuralEngine
)
XCTAssertEqual(
ANEOptimizer.optimalComputeUnits(for: .joint),
.cpuAndNeuralEngine
)
func testDefaultConfigurationComputeUnits() {
// Default configuration uses CPU+ANE
let config = MLModelConfigurationUtils.defaultConfiguration()
XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
}
// MARK: - Zero-Copy View Tests (Removed - causes crashes with memory operations)
@@ -115,7 +101,7 @@ final class ANEOptimizerTests: XCTestCase {
float32Array[i] = NSNumber(value: Float(i) * 0.1)
}
let result = try ANEOptimizer.convertToFloat16(float32Array)
let result = try ANEMemoryUtils.convertToFloat16(float32Array)
XCTAssertEqual(result.shape, float32Array.shape)
@@ -129,7 +115,7 @@ final class ANEOptimizerTests: XCTestCase {
// Verify ANE alignment only in non-CI environment
let pointerValue = Int(bitPattern: result.dataPointer)
XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
}
// Verify data conversion accuracy (regardless of CI)
@@ -145,11 +131,9 @@ final class ANEOptimizerTests: XCTestCase {
let int32Array = try MLMultiArray(shape: [5], dataType: .int32)
XCTAssertThrowsError(
try ANEOptimizer.convertToFloat16(int32Array)
try ANEMemoryUtils.convertToFloat16(int32Array)
) { error in
let nsError = error as NSError
XCTAssertEqual(nsError.domain, "ANEOptimizer")
XCTAssertEqual(nsError.code, -3)
XCTAssertTrue(error is ANEMemoryUtils.ANEMemoryError)
}
}
@@ -26,7 +26,7 @@ final class MLArrayCacheTests: XCTestCase {
if !isCI {
// Verify ANE alignment only in non-CI environment
let pointerValue = Int(bitPattern: array.dataPointer)
XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
}
}
@@ -97,58 +97,6 @@ final class MLArrayCacheTests: XCTestCase {
XCTAssertNotNil(finalArray)
}
// MARK: - Float16 Support
func testGetFloat16ArrayFromScratch() async throws {
let shape: [NSNumber] = [1, 64] // Smaller for CI stability
let fp16Array = try await cache.getFloat16Array(shape: shape)
XCTAssertEqual(fp16Array.shape, shape)
// In CI, we might get Float32 instead of Float16 for stability
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
if isCI {
// In CI, accept either Float16 or Float32
XCTAssertTrue(fp16Array.dataType == .float16 || fp16Array.dataType == .float32)
} else {
XCTAssertEqual(fp16Array.dataType, .float16)
// Verify ANE alignment only in non-CI environment
let pointerValue = Int(bitPattern: fp16Array.dataPointer)
XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
}
}
func testGetFloat16ArrayFromFloat32() async throws {
// Create Float32 array
let shape: [NSNumber] = [50] // Smaller for CI stability
let float32Array = try MLMultiArray(shape: shape, dataType: .float32)
// Fill with test values
for i in 0..<float32Array.count {
float32Array[i] = NSNumber(value: Float(i) * 0.1)
}
// Convert to Float16
let float16Array = try await cache.getFloat16Array(shape: shape, from: float32Array)
XCTAssertEqual(float16Array.shape, shape)
// In CI, we might get Float32 instead of Float16 for stability
let isCI = ProcessInfo.processInfo.environment["CI"] != nil
if isCI {
// In CI, accept either Float16 or Float32
XCTAssertTrue(float16Array.dataType == .float16 || float16Array.dataType == .float32)
} else {
XCTAssertEqual(float16Array.dataType, .float16)
}
// Verify conversion accuracy (regardless of CI)
for i in 0..<min(5, float16Array.count) {
XCTAssertEqual(float16Array[i].floatValue, Float(i) * 0.1, accuracy: 0.01)
}
}
// MARK: - Pre-warming Tests
func testPrewarmCache() async {