mirror of
https://github.com/FluidInference/FluidAudio.git
synced 2026-05-12 20:20:36 +00:00
ASR tech debt cleanup: remove dead code, fix bugs, add benchmark script 28/03/2026 (#460)
## Summary

Systematic cleanup of the ASR module addressing tech debt items from #457. Net reduction of ~430 lines while fixing real bugs and improving maintainability.

### Bug fixes

- **`enableFP16` silently ignored** — `optimizedConfiguration(enableFP16:)` delegated to a shared factory that hardcoded `allowLowPrecisionAccumulationOnGPU = true`, ignoring the caller's parameter
- **`MLArrayCache.returnArray` only reset float32 data** — cached arrays of other types (float16, int32) retained stale data from previous use
- **CTC model auto-detection broken** — `Repo.parakeetCtc110m.folderName` returned `"parakeet-ctc-110m"` instead of `"parakeet-ctc-110m-coreml"` because the `folderName` switch fell through to a `default` case that stripped the `-coreml` suffix. Same for `parakeetCtc06b`.
- **Duplicate tokens at chunk merge boundary** — `mergeByMidpoint` used `<=`/`>=` so tokens exactly at the cutoff appeared in both left and right chunks

### Dead code removal

- Deleted `ANEOptimizer` indirection layer (166 lines) — was a pass-through wrapping `MLModel` with no optimization
- Deleted `PerformanceMonitor` actor and `AggregatedMetrics` — never instantiated, component times hardcoded to 0
- Deleted `getFloat16Array` from MLArrayCache — never called
- Deleted `sliceEncoderOutput` from AsrTranscription — never called (30 lines)
- Deleted `loadWithANEOptimization` from AsrModels — never called
- Removed unused `tokenTimings` parameter chain through `processTranscriptionResult`
- Removed unused `import OSLog` / `import CoreML` across 5 files
- Removed `nonisolated(unsafe)` from SlidingWindowAsrManager (types already Sendable)

### Duplication elimination

- Extracted `clearCachedCtcData()` helper (replaced 3× triple-nil assignments)
- Extracted `decoderState(for:)` / `setDecoderState(_:for:)` (replaced 4× switch blocks)
- Extracted `frameAlignedAudio()` (replaced 2× duplicated frame-alignment blocks)
- Added `ASRConstants.secondsPerEncoderFrame` (replaced 5× magic `0.08`)
- Replaced hardcoded `16_000` with `config.sampleRate` / `ASRConstants.sampleRate`
- Extracted `MLModelConfigurationUtils.defaultConfiguration()` (replaced 5× copy-pasted config methods)
- Extracted `MLModelConfigurationUtils.defaultModelsDirectory()` (replaced 3× copy-pasted directory methods)
- Consolidated duplicate `vocabularyFile` / `vocabularyFileArray` constants

### File organization

- Moved `PerformanceMetrics.swift`, `ProgressEmitter.swift`, `MLArrayCache.swift` from `ASR/Parakeet/` to `Shared/` (used by multiple modules)
- Renamed `StreamingAudioSourceFactory` → `AudioSourceFactory`, `StreamingAudioSampleSource` → `AudioSampleSource` (types used by both ASR and Diarizer)
- Renamed files to match type names: `SortformerDiarizerPipeline.swift` → `SortformerDiarizer.swift`, `LSEENDDiarizerAPI.swift` → `LSEENDDiarizer.swift`, `NemotronPipeline.swift` → `NemotronStreamingAsrManager+Pipeline.swift`
- Replaced force unwraps in `RnntDecoder.swift` with `guard let` + descriptive errors
- Removed stale TODO about decoder state in AsrManager

### Benchmark script

- Added `Scripts/run_parakeet_benchmarks.sh` — runs all 6 benchmarks (v3, v2, TDT-CTC-110M, CTC earnings, EOU 320ms, Nemotron 1120ms) with WER comparison against `benchmarks100.md` baselines and regression detection
- Referenced from `Documentation/ASR/benchmarks100.md`

## Verified — no regressions

```
Model                    Baseline   Current   Delta
Parakeet TDT v3 (0.6B)   2.6%       2.64%     +0.04%
Parakeet TDT v2 (0.6B)   3.8%       3.79%     -0.01%
CTC-TDT 110M             3.6%       3.56%     -0.04%
CTC Earnings             16.54%     16.51%    -0.03%
EOU 320ms (120M)         7.11%      7.11%     +0.00%
Nemotron 1120ms (0.6B)   1.99%      1.99%     +0.00%
```

## Test plan

- [x] `swift build` passes
- [x] `swift test` passes (all existing tests, updated for removed dead code)
- [x] All 6 ASR benchmarks match baselines (100 files each)
- [ ] `swift format lint` passes
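
The chunk-boundary fix is easiest to see in isolation. The sketch below is a simplified Python illustration, not the Swift implementation: the function names, the `(time, text)` tuple layout, and the sample tokens are all hypothetical. It assumes the fix amounts to making one of the two comparisons strict, so the cutoff belongs to exactly one side; which side the real code chose is not stated here.

```python
# Hypothetical, simplified model of the chunk-merge boundary bug described above.
# Tokens are (time_seconds, text) tuples; the real code is Swift and more involved.

def merge_inclusive(left, right, cutoff):
    """Buggy variant: both comparisons are inclusive, so a token that sits
    exactly at the cutoff is kept from both chunks."""
    kept_left = [tok for tok in left if tok[0] <= cutoff]
    kept_right = [tok for tok in right if tok[0] >= cutoff]
    return kept_left + kept_right


def merge_half_open(left, right, cutoff):
    """Assumed fix: the left chunk keeps t < cutoff, the right keeps t >= cutoff,
    so every timestamp belongs to exactly one side."""
    kept_left = [tok for tok in left if tok[0] < cutoff]
    kept_right = [tok for tok in right if tok[0] >= cutoff]
    return kept_left + kept_right


left_chunk = [(1.0, "hello"), (2.0, "world")]
right_chunk = [(2.0, "world"), (3.0, "again")]

buggy = [text for _, text in merge_inclusive(left_chunk, right_chunk, cutoff=2.0)]
fixed = [text for _, text in merge_half_open(left_chunk, right_chunk, cutoff=2.0)]
print(buggy)  # ['hello', 'world', 'world', 'again']
print(fixed)  # ['hello', 'world', 'again']
```

A half-open split is the usual way to partition a timeline without gaps or duplicates, which is why only tokens exactly at the cutoff were affected.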
@@ -102,6 +102,8 @@ Resources/
!Sources/FluidAudio/Resources/
!Sources/FluidAudio/Resources/**
scripts/
!Scripts/parakeet_subset_benchmark.sh
!Scripts/diarizer_subset_benchmark.sh
Documentation/parakeet-tdt/
docs/parakeet-tdt/
@@ -2,6 +2,18 @@

Benchmark comparison between `main` and PR #440 (`standardize-asr-directory-structure`) to verify the directory restructuring introduces no regressions.

## Reproduction

All batch TDT and CTC earnings benchmarks can be reproduced with [`Scripts/parakeet_subset_benchmark.sh`](../../Scripts/parakeet_subset_benchmark.sh):

```bash
# Download models and datasets (requires internet)
./Scripts/parakeet_subset_benchmark.sh --download

# Run all 4 benchmarks offline (100 files each, sleep-prevented)
./Scripts/parakeet_subset_benchmark.sh
```

## Environment

- **Hardware**: MacBook Air M2, 16 GB
@@ -0,0 +1,147 @@
# Diarization Benchmarks

Hardware: 2024 MacBook Pro, 48GB RAM, M4 Pro, macOS Tahoe 26.0

Dataset: AMI SDM (Single Distant Microphone), 4-meeting subset — one session per speaker group for diversity.

All results use collar=0.25s, ignoreOverlap=true.
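
To make the `collar=0.25s` setting concrete, here is a minimal sketch of how a collar is conventionally applied in diarization scoring: a no-score zone of ±collar seconds around every reference segment boundary. This assumes the common md-eval style definition; the function name and sample segments are hypothetical, and the exact behavior of the FluidAudio benchmark code is not shown on this page. `ignoreOverlap=true` presumably also excludes overlapped-speech regions, which this sketch does not model.

```python
# Sketch of collar handling, assuming the conventional definition:
# exclude +/- `collar` seconds around each reference segment boundary from scoring.

def no_score_zones(reference_segments, collar=0.25):
    """Return merged (start, end) intervals that would be excluded from scoring."""
    zones = []
    for start, end in reference_segments:
        for boundary in (start, end):
            zones.append((max(0.0, boundary - collar), boundary + collar))
    zones.sort()

    merged = []
    for lo, hi in zones:
        if merged and lo <= merged[-1][1]:
            # Overlapping or touching zones collapse into one interval.
            merged[-1] = (merged[-1][0], max(merged[-1][1], hi))
        else:
            merged.append((lo, hi))
    return merged


# Two hypothetical reference segments with a 0.2 s gap between them.
zones = no_score_zones([(1.0, 3.0), (3.2, 5.0)])
for lo, hi in zones:
    print(f"{lo:.2f}-{hi:.2f}")
```

With a 0.25 s collar, the zones around the 3.0 s and 3.2 s boundaries merge into one, so short gaps between turns drop out of scoring entirely.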

## Summary

| System | Avg DER | Avg RTFx | Mode |
|---|---|---|---|
| LS-EEND (AMI) | 25.7% | 53.9x | Streaming |
| Offline VBx | 21.8% | 97.5x | Offline |
| Streaming 5s/0.8 | 29.9% | 96.2x | Streaming |
| Sortformer (high-lat) | 34.3% | 120.3x | Streaming |
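
RTFx is not defined in this document. The sketch below assumes the usual convention: seconds of audio processed per second of wall-clock compute, so higher is faster. The function name and the numbers are hypothetical, chosen only for illustration.

```python
# Sketch of the RTFx metric, assuming the conventional definition
# (audio duration divided by processing time).

def rtfx(audio_seconds, processing_seconds):
    """Return how many seconds of audio are processed per second of compute."""
    if processing_seconds <= 0:
        raise ValueError("processing_seconds must be positive")
    return audio_seconds / processing_seconds


# Hypothetical example: a 20-minute meeting processed in 12.5 seconds.
example = rtfx(20 * 60, 12.5)
print(f"{example:.1f}x")  # 96.0x
```

Under that reading, an RTFx near 100x means roughly one hour of audio per 36 seconds of compute.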

## Offline VBx

Pyannote segmentation + WeSpeaker embeddings + PLDA scoring + VBx clustering.

Default configuration: step ratio 0.2, minSegmentDurationSeconds 1.0, clustering threshold 0.7.

```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli diarization-benchmark --mode offline \
    --dataset ami-sdm --auto-download
```

```text
----------------------------------------------------------------------
Meeting         DER %   Miss %     FA %     SE %   Speakers      RTFx
----------------------------------------------------------------------
ES2004a          14.5      7.6      1.7      5.2        5/4      98.2
IS1009a          17.7      3.6      3.0     11.1        6/4      99.1
TS3003a          21.2     11.7      1.4      8.1        2/4      98.4
EN2002a          33.9      4.5      1.4     28.0        4/4      94.2
----------------------------------------------------------------------
AVERAGE          21.8      6.9      1.9     13.1          -      97.5
======================================================================
```
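
The per-meeting rows above fit two simple relationships: each DER equals Miss + FA + SE, and the AVERAGE row matches an unweighted mean over meetings (the benchmark script's metric extraction also averages with `sum(...) / len(...)`). The check below copies the Offline VBx rows from the table; only the variable names are invented.

```python
# Rows copied from the Offline VBx table: (meeting, miss %, fa %, se %, rtfx).
rows = [
    ("ES2004a", 7.6, 1.7, 5.2, 98.2),
    ("IS1009a", 3.6, 3.0, 11.1, 99.1),
    ("TS3003a", 11.7, 1.4, 8.1, 98.4),
    ("EN2002a", 4.5, 1.4, 28.0, 94.2),
]

# DER is the sum of missed speech, false alarm, and speaker error.
ders = [round(miss + fa + se, 1) for _, miss, fa, se, _ in rows]

# Unweighted mean over meetings: long and short meetings count equally.
avg_der = sum(ders) / len(ders)
avg_rtfx = sum(r[4] for r in rows) / len(rows)

print(ders)                                  # [14.5, 17.7, 21.2, 33.9]
print(f"{avg_der:.1f} {avg_rtfx:.1f}")       # 21.8 97.5
```

Because the mean is unweighted, one hard meeting such as EN2002a moves the 4-meeting average noticeably.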

Full VoxConverse results (232 clips): 15.07% DER, 122x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.

## Streaming (5s chunks, 0.8 threshold)

Pyannote segmentation + WeSpeaker embeddings + online SpeakerManager clustering.

Best streaming configuration: 5s chunks, 0s overlap, 0.8 clustering threshold.

```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli diarization-benchmark --mode streaming \
    --dataset ami-sdm --chunk-seconds 5.0 --overlap-seconds 0.0 \
    --threshold 0.8 --auto-download
```
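
As an illustration of what `--chunk-seconds 5.0 --overlap-seconds 0.0` implies, the sketch below computes chunk windows over a recording. It is a hypothetical helper, not the CLI's code. In particular it truncates the final short chunk, whereas the real pipeline may pad or handle it differently.

```python
# Sketch of fixed-size chunking with optional overlap (hypothetical helper).

def chunk_bounds(duration, chunk_seconds=5.0, overlap_seconds=0.0):
    """Yield (start, end) windows in seconds covering [0, duration)."""
    if not 0.0 <= overlap_seconds < chunk_seconds:
        raise ValueError("overlap must be non-negative and smaller than the chunk")
    step = chunk_seconds - overlap_seconds
    start = 0.0
    while start < duration:
        # The last window is clipped to the end of the recording.
        yield (start, min(start + chunk_seconds, duration))
        start += step


windows = list(chunk_bounds(12.0))
print(windows)  # [(0.0, 5.0), (5.0, 10.0), (10.0, 12.0)]
```

With zero overlap each sample is seen once, so speaker identity has to be carried across chunks by the online clustering step rather than by shared audio.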

```text
----------------------------------------------------------------------
Meeting         DER %   Miss %     FA %     SE %   Speakers      RTFx
----------------------------------------------------------------------
ES2004a          17.0      9.0      1.3      6.7        7/4      99.2
IS1009a          18.1      4.7      2.7     10.8        4/4     101.0
TS3003a          21.0     12.7      1.4      6.8        2/4     104.3
EN2002a          63.4      9.2      1.1     53.0        7/4      80.1
----------------------------------------------------------------------
AVERAGE          29.9      8.9      1.6     19.3          -      96.2
======================================================================
```

Full 7-meeting results: 26.2% DER, 223x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.

EN2002a is a known difficult meeting for the streaming pipeline: speaker error accounts for 53.0 points of its 63.4% DER, driven by over-fragmentation (the Speakers column reads 7/4).

## Sortformer (NVIDIA High-Latency)

NVIDIA end-to-end Sortformer model, 30.4s chunk config.

Model: [FluidInference/diar-streaming-sortformer-coreml](https://huggingface.co/FluidInference/diar-streaming-sortformer-coreml)

```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli sortformer-benchmark \
    --nvidia-high-latency --hf --auto-download
```

```text
----------------------------------------------------------------------
Meeting         DER %   Miss %     FA %     SE %   Speakers      RTFx
----------------------------------------------------------------------
IS1009a          26.5     15.9      1.4      9.3        4/4     122.9
ES2004a          33.4     24.5      0.1      8.8        4/4     117.9
EN2002a          35.7     20.0      0.4     15.2        4/4     121.5
TS3003a          41.8     36.8      0.7      4.3        4/4     119.0
----------------------------------------------------------------------
AVERAGE          34.3     24.3      0.7      9.4          -     120.3
======================================================================
```

Full 16-meeting results: 31.7% DER, 126.7x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.

## LS-EEND (AMI variant)

Linear Streaming End-to-End Neural Diarization from Westlake University.

Model: [GradientDescent2718/ls-eend-coreml](https://huggingface.co/GradientDescent2718/ls-eend-coreml)

```bash
Scripts/diarizer_subset_benchmark.sh
# or manually:
swift run -c release fluidaudiocli lseend-benchmark \
    --variant ami --auto-download
```

```text
----------------------------------------------------------------------
Meeting         DER %   Miss %     FA %     SE %   Speakers      RTFx
----------------------------------------------------------------------
TS3003a          19.0     16.6      0.8      1.6        4/4      47.5
IS1009a          23.4      8.0      2.6     12.8        4/4      57.7
EN2002a          24.5     19.7      1.1      3.6        4/4      53.2
ES2004a          35.8     13.3     19.2      3.2        4/4      57.2
----------------------------------------------------------------------
AVERAGE          25.7     14.4      5.9      5.3          -      53.9
======================================================================
```

Full 16-meeting results: 20.7% DER, 74.5x RTFx. See [Benchmarks.md](../Benchmarks.md) for details.

## Reproducing

Run all 4 systems on the default 4-meeting subset:

```bash
./Scripts/diarizer_subset_benchmark.sh
```

Run on all 16 AMI meetings:

```bash
./Scripts/diarizer_subset_benchmark.sh --all
```

Results are saved to `benchmark_results/` with timestamps. The script uses `caffeinate` to prevent sleep during long runs.
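
For ad-hoc inspection of a saved results file, the sketch below mirrors what the benchmark script does after a run: average the per-meeting `der` and `rtfx` fields, then flag a regression when DER exceeds the baseline by more than 2.0 points. The field names, the 2.0 tolerance, and the 20.7 LS-EEND baseline come from the script; the function names and the temporary demo file are hypothetical, and the two sample entries are borrowed from the LS-EEND table.

```python
import json
import os
import tempfile


def average_metrics(json_path):
    """Return (avg_der, avg_rtfx) for a merged results file, or None if it is empty."""
    with open(json_path) as fh:
        results = json.load(fh)
    if not results:
        return None
    n = len(results)
    return (sum(r["der"] for r in results) / n,
            sum(r["rtfx"] for r in results) / n)


def is_regression(current_der, baseline_der, tolerance=2.0):
    """True when DER is more than `tolerance` points above the baseline."""
    return current_der > baseline_der + tolerance


# Demo file with two per-meeting entries (TS3003a and IS1009a from the LS-EEND table).
sample = [{"der": 19.0, "rtfx": 47.5}, {"der": 23.4, "rtfx": 57.7}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "lseend_ami_demo.json")
    with open(path, "w") as fh:
        json.dump(sample, fh)
    avg_der, avg_rtfx = average_metrics(path)

print(f"{avg_der:.1f} {avg_rtfx:.1f}")  # 21.2 52.6
print(is_regression(avg_der, 20.7))     # False
```

The baselines were measured on larger meeting sets than the default subset, so a subset result can trip or miss the threshold through per-meeting variance alone.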
@@ -0,0 +1,478 @@
#!/bin/bash
# Run all diarizer model benchmarks on AMI SDM with sleep prevention.
#
# Benchmarks:
#   1. Offline (VBx) — OfflineDiarizerManager, step=0.2, min-seg=1.0
#   2. Streaming (5s) — DiarizerManager, 5s chunks, 0s overlap, threshold=0.8
#   3. Sortformer — SortformerDiarizer, NVIDIA high-latency config
#   4. LS-EEND — LSEENDDiarizer, AMI variant
#
# Usage:
#   ./Scripts/diarizer_subset_benchmark.sh                # quick run (4 meetings)
#   ./Scripts/diarizer_subset_benchmark.sh --all          # full run (all 16 meetings)
#   ./Scripts/diarizer_subset_benchmark.sh --max-files 8  # custom subset
#   ./Scripts/diarizer_subset_benchmark.sh --download     # download missing assets, then exit
#
# The script verifies all models and dataset files exist locally before running.
# If anything is missing it will tell you exactly what and exit (unless --download).
# Uses caffeinate to prevent sleep so you can close the lid.
# Results are saved to benchmark_results/ with timestamps.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
RESULTS_DIR="$PROJECT_DIR/benchmark_results"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$RESULTS_DIR/diarizer_benchmark_${TIMESTAMP}.log"

MODELS_DIR="$HOME/Library/Application Support/FluidAudio/Models"
DATASETS_DIR="$HOME/FluidAudioDatasets"
AMI_SDM_DIR="$DATASETS_DIR/ami_official/sdm"
AMI_RTTM_DIR="$DATASETS_DIR/ami_official/rttm"
MAX_FILES=4  # default: quick 4-meeting subset

# AMI SDM has 16 meetings — this is the standard diarization test set.
# Ordered so the first N picks one from each speaker group for maximum diversity.
# Groups: EN2002 (4 speakers), ES2004 (4), IS1009 (4), TS3003 (4)
ALL_AMI_MEETINGS=(
    EN2002a ES2004a IS1009a TS3003a
    EN2002b ES2004b IS1009b TS3003b
    EN2002c ES2004c IS1009c TS3003c
    EN2002d ES2004d IS1009d TS3003d
)

# Parse --all / --max-files <N> from arguments
args=("$@")
for ((i=0; i<${#args[@]}; i++)); do
    case "${args[$i]}" in
        --all) MAX_FILES=${#ALL_AMI_MEETINGS[@]} ;;
        --max-files) MAX_FILES="${args[$((i+1))]}" ; i=$((i+1)) ;;
    esac
done

# Select the subset of meetings to run
AMI_MEETINGS=("${ALL_AMI_MEETINGS[@]:0:$MAX_FILES}")

mkdir -p "$RESULTS_DIR"

log() {
    echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"
}

# ---------------------------------------------------------------------------
# Verify local assets
# ---------------------------------------------------------------------------
verify_assets() {
    local missing=0

    # --- AMI SDM audio files ---
    local wav_count=0
    for meeting in "${AMI_MEETINGS[@]}"; do
        if [[ -f "$AMI_SDM_DIR/${meeting}.Mix-Headset.wav" ]]; then
            wav_count=$((wav_count + 1))
        else
            log "MISSING AMI SDM: $AMI_SDM_DIR/${meeting}.Mix-Headset.wav"
            missing=1
        fi
    done
    if [[ "$wav_count" -eq 0 ]]; then
        log "MISSING AMI SDM: no wav files found in $AMI_SDM_DIR"
        missing=1
    fi

    # --- AMI RTTM annotations (downloaded automatically by --auto-download) ---
    local rttm_count=0
    for meeting in "${ALL_AMI_MEETINGS[@]}"; do
        if [[ -f "$AMI_RTTM_DIR/${meeting}.rttm" ]]; then
            rttm_count=$((rttm_count + 1))
        fi
    done
    if [[ "$rttm_count" -eq 0 ]]; then
        log "NOTE AMI RTTM annotations not found — will be auto-downloaded by CLI"
    fi

    # --- Offline diarizer models (pyannote segmentation + wespeaker embedding) ---
    local diar_dir="$MODELS_DIR/speaker-diarization-coreml"
    if [[ ! -d "$diar_dir" ]]; then
        log "MISSING Diarizer models: $diar_dir"
        missing=1
    fi

    # --- Sortformer models (folder may or may not have -coreml suffix) ---
    if [[ ! -d "$MODELS_DIR/diar-streaming-sortformer-coreml" ]] && [[ ! -d "$MODELS_DIR/diar-streaming-sortformer" ]]; then
        log "MISSING Sortformer models: $MODELS_DIR/diar-streaming-sortformer{,-coreml}"
        missing=1
    fi

    # --- LS-EEND models (folder may or may not have -coreml suffix) ---
    if [[ ! -d "$MODELS_DIR/ls-eend-coreml" ]] && [[ ! -d "$MODELS_DIR/ls-eend" ]]; then
        log "MISSING LS-EEND models: $MODELS_DIR/ls-eend{,-coreml}"
        missing=1
    fi

    return $missing
}

# ---------------------------------------------------------------------------
# Phase 1: --download (verify first, download only what's missing)
# ---------------------------------------------------------------------------
if [[ "${1:-}" == "--download" ]]; then
    log "=== Checking local assets ==="

    if verify_assets; then
        log "All models and datasets already present locally. Nothing to download."
        exit 0
    fi

    log "Some assets are missing — downloading..."

    log "Building release binary..."
    cd "$PROJECT_DIR" && swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
    CLI="$PROJECT_DIR/.build/release/fluidaudiocli"

    log "Downloading AMI SDM dataset + annotations..."
    "$CLI" diarization-benchmark --mode offline --auto-download --max-files 1 \
        --output "$RESULTS_DIR/warmup_offline.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading Sortformer models..."
    "$CLI" sortformer-benchmark --nvidia-high-latency --hf --auto-download --max-files 1 \
        --output "$RESULTS_DIR/warmup_sortformer.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading LS-EEND models..."
    "$CLI" lseend-benchmark --variant ami --auto-download --max-files 1 \
        --output "$RESULTS_DIR/warmup_lseend.json" 2>&1 | tee -a "$LOG_FILE"

    rm -f "$RESULTS_DIR"/warmup_*.json
    log "=== Downloads complete ==="
    exit 0
fi

# ---------------------------------------------------------------------------
# Phase 2: Run benchmarks (offline-safe, sleep-prevented)
# ---------------------------------------------------------------------------
log "=== Verifying local assets before offline run ==="
if ! verify_assets; then
    log ""
    log "ERROR: Missing assets — cannot run offline."
    log "Run with --download first while connected to the internet:"
    log " ./Scripts/diarizer_subset_benchmark.sh --download"
    exit 1
fi
log "All assets verified locally."

log "=== Diarizer benchmark suite: ${#AMI_MEETINGS[@]}/${#ALL_AMI_MEETINGS[@]} meetings x 4 systems ==="
log "Results directory: $RESULTS_DIR"

cd "$PROJECT_DIR"

# Build release if not already built
if [[ ! -x ".build/release/fluidaudiocli" ]]; then
    log "Building release binary..."
    swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
fi
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"

# caffeinate -s: prevent system sleep (this assertion only applies on AC power)
# caffeinate -i: prevent idle sleep
# caffeinate -w $$: hold the assertions until this script's process exits
caffeinate -si -w $$ &
CAFFEINATE_PID=$!
log "caffeinate started (PID $CAFFEINATE_PID) — safe to close the lid"

# ---------------------------------------------------------------------------
# Benchmark runners
# ---------------------------------------------------------------------------

# Run a benchmark for each meeting via --single-file, then merge JSON results.
# This ensures we control exactly which meetings run (not the CLI's internal order).
merge_json_results() {
    local output_file="$1"
    shift
    local tmp_files=("$@")
    python3 -c "
import json, sys
results = []
for f in sys.argv[2:]:
    try:
        with open(f) as fh:
            data = json.load(fh)
        if isinstance(data, list):
            results.extend(data)
        else:
            results.append(data)
    except (OSError, ValueError):
        pass  # skip per-meeting files that are missing or not valid JSON
with open(sys.argv[1], 'w') as out:
    json.dump(results, out, indent=2)
" "$output_file" "${tmp_files[@]}" 2>/dev/null
    rm -f "${tmp_files[@]}"
}

run_offline_benchmark() {
    local label="offline_vbx"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
    local tmp_files=()

    log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, offline VBx) ---"
    local start_time=$(date +%s)

    for meeting in "${AMI_MEETINGS[@]}"; do
        local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
        tmp_files+=("$tmp")
        log " [$label] $meeting"
        "$CLI" diarization-benchmark \
            --mode offline \
            --dataset ami-sdm \
            --single-file "$meeting" \
            --auto-download \
            --output "$tmp" \
            2>&1 | tee -a "$LOG_FILE"
    done

    merge_json_results "$output_file" "${tmp_files[@]}"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_streaming_benchmark() {
    local label="streaming_5s"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
    local tmp_files=()

    log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, 5s chunks, threshold=0.8) ---"
    local start_time=$(date +%s)

    for meeting in "${AMI_MEETINGS[@]}"; do
        local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
        tmp_files+=("$tmp")
        log " [$label] $meeting"
        "$CLI" diarization-benchmark \
            --mode streaming \
            --dataset ami-sdm \
            --single-file "$meeting" \
            --chunk-seconds 5.0 \
            --overlap-seconds 0.0 \
            --threshold 0.8 \
            --auto-download \
            --output "$tmp" \
            2>&1 | tee -a "$LOG_FILE"
    done

    merge_json_results "$output_file" "${tmp_files[@]}"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_sortformer_benchmark() {
    local label="sortformer"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
    local tmp_files=()

    log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, NVIDIA high-latency) ---"
    local start_time=$(date +%s)

    for meeting in "${AMI_MEETINGS[@]}"; do
        local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
        tmp_files+=("$tmp")
        log " [$label] $meeting"
        "$CLI" sortformer-benchmark \
            --nvidia-high-latency \
            --hf \
            --dataset ami \
            --single-file "$meeting" \
            --auto-download \
            --output "$tmp" \
            2>&1 | tee -a "$LOG_FILE"
    done

    merge_json_results "$output_file" "${tmp_files[@]}"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_lseend_benchmark() {
    local label="lseend_ami"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"
    local tmp_files=()

    log "--- $label: starting (${#AMI_MEETINGS[@]} meetings, AMI SDM, AMI variant) ---"
    local start_time=$(date +%s)

    for meeting in "${AMI_MEETINGS[@]}"; do
        local tmp="$RESULTS_DIR/${label}_tmp_${meeting}.json"
        tmp_files+=("$tmp")
        log " [$label] $meeting"
        "$CLI" lseend-benchmark \
            --variant ami \
            --dataset ami \
            --single-file "$meeting" \
            --auto-download \
            --output "$tmp" \
            2>&1 | tee -a "$LOG_FILE"
    done

    merge_json_results "$output_file" "${tmp_files[@]}"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

# ---------------------------------------------------------------------------
# Run all 4 benchmarks
# ---------------------------------------------------------------------------
SUITE_START=$(date +%s)

run_offline_benchmark
run_streaming_benchmark
run_sortformer_benchmark
run_lseend_benchmark

SUITE_END=$(date +%s)
SUITE_ELAPSED=$(( SUITE_END - SUITE_START ))

log "=== All benchmarks complete in ${SUITE_ELAPSED}s ==="
log "Results:"
ls -lh "$RESULTS_DIR"/*_${TIMESTAMP}.json 2>/dev/null | tee -a "$LOG_FILE"

# ---------------------------------------------------------------------------
# Extract DER and RTFx from JSON results
# ---------------------------------------------------------------------------

# Streaming diarization benchmark: JSON is array of per-meeting results with "der" and "rtfx"
extract_streaming_metrics() {
    local json_file="$1"
    if [[ -f "$json_file" ]]; then
        # The path is passed as an argument rather than interpolated into the
        # Python source, so quotes in the path cannot break the snippet.
        python3 -c "
import json, sys
with open(sys.argv[1]) as f:
    results = json.load(f)
if not results:
    print('N/A N/A')
    sys.exit()
avg_der = sum(r['der'] for r in results) / len(results)
avg_rtfx = sum(r['rtfx'] for r in results) / len(results)
print(f'{avg_der:.1f} {avg_rtfx:.1f}')
" "$json_file" 2>/dev/null || echo "N/A N/A"
    else
        echo "N/A N/A"
    fi
}

# Sortformer/LS-EEND: same JSON array format via DiarizationBenchmarkUtils
extract_shared_metrics() {
    local json_file="$1"
    if [[ -f "$json_file" ]]; then
        # The path is passed as an argument rather than interpolated into the
        # Python source, so quotes in the path cannot break the snippet.
        python3 -c "
import json, sys
with open(sys.argv[1]) as f:
    results = json.load(f)
if not results:
    print('N/A N/A')
    sys.exit()
avg_der = sum(r['der'] for r in results) / len(results)
avg_rtfx = sum(r['rtfx'] for r in results) / len(results)
print(f'{avg_der:.1f} {avg_rtfx:.1f}')
" "$json_file" 2>/dev/null || echo "N/A N/A"
    else
        echo "N/A N/A"
    fi
}

# ---------------------------------------------------------------------------
# Compare DER & RTFx against Benchmarks.md baselines
# ---------------------------------------------------------------------------

# Baselines from Documentation/Benchmarks.md (AMI SDM, all 16 meetings)
# Note: when running a subset (--max-files <16), DER will differ from these baselines
# due to per-meeting variance. Baselines are for full 16-meeting runs only.
#   Offline: no AMI SDM baseline yet — first --all run establishes it.
#   Streaming: 5s/0s/0.8 on AMI SDM (7 meetings) = 26.2% DER, 223.1x RTFx
#   Sortformer: NVIDIA high-latency on AMI SDM (16 meetings) = 31.7% DER, 126.7x RTFx
#   LS-EEND: AMI variant on AMI SDM (16 meetings) = 20.7% DER, 74.5x RTFx
BASELINE_STREAMING_DER="26.2"
BASELINE_STREAMING_RTFX="223.1"
BASELINE_SORTFORMER_DER="31.7"
BASELINE_SORTFORMER_RTFX="126.7"
BASELINE_LSEEND_DER="20.7"
BASELINE_LSEEND_RTFX="74.5"

OFFLINE_FILE="$RESULTS_DIR/offline_vbx_${TIMESTAMP}.json"
STREAMING_FILE="$RESULTS_DIR/streaming_5s_${TIMESTAMP}.json"
SORTFORMER_FILE="$RESULTS_DIR/sortformer_${TIMESTAMP}.json"
LSEEND_FILE="$RESULTS_DIR/lseend_ami_${TIMESTAMP}.json"

read OFFLINE_DER OFFLINE_RTFX <<< $(extract_streaming_metrics "$OFFLINE_FILE")
read STREAMING_DER STREAMING_RTFX <<< $(extract_streaming_metrics "$STREAMING_FILE")
read SORTFORMER_DER SORTFORMER_RTFX <<< $(extract_shared_metrics "$SORTFORMER_FILE")
read LSEEND_DER LSEEND_RTFX <<< $(extract_shared_metrics "$LSEEND_FILE")

log ""
log "=== DER & RTFx Comparison vs Benchmarks.md baselines (AMI SDM, ${#AMI_MEETINGS[@]} meetings) ==="
log ""
printf "%-25s %12s %12s %12s %12s %12s\n" \
    "System" "Base DER" "DER" "Delta" "Base RTFx" "RTFx" | tee -a "$LOG_FILE"
printf "%-25s %12s %12s %12s %12s %12s\n" \
    "-------------------------" "------------" "------------" "------------" "------------" "------------" | tee -a "$LOG_FILE"

compare_der_rtfx() {
    local label="$1" base_der="$2" current_der="$3" base_rtfx="$4" current_rtfx="$5"

    if [[ "$current_der" == "N/A" ]]; then
        printf "%-25s %11s%% %12s %12s %11sx %12s\n" \
            "$label" "$base_der" "N/A" "—" "$base_rtfx" "N/A" | tee -a "$LOG_FILE"
        return
    fi

    local delta marker=""
    delta=$(python3 -c "print(f'{$current_der - $base_der:+.1f}')" 2>/dev/null || echo "?")
    local regression
    regression=$(python3 -c "print('YES' if $current_der > $base_der + 2.0 else 'NO')" 2>/dev/null || echo "NO")
    if [[ "$regression" == "YES" ]]; then
        marker=" <- REGRESSION"
    fi

    printf "%-25s %11s%% %11s%% %11s%% %11sx %11sx%s\n" \
        "$label" "$base_der" "$current_der" "$delta" "$base_rtfx" "$current_rtfx" "$marker" | tee -a "$LOG_FILE"
}

# Offline has no AMI SDM baseline yet — show as "new"
if [[ "$OFFLINE_DER" != "N/A" ]]; then
    printf "%-25s %12s %11s%% %12s %12s %11sx\n" \
        "Offline (VBx)" "—" "$OFFLINE_DER" "(new)" "—" "$OFFLINE_RTFX" | tee -a "$LOG_FILE"
else
    printf "%-25s %12s %12s %12s %12s %12s\n" \
        "Offline (VBx)" "—" "N/A" "—" "—" "N/A" | tee -a "$LOG_FILE"
fi

compare_der_rtfx "Streaming (5s/0.8)" "$BASELINE_STREAMING_DER" "$STREAMING_DER" "$BASELINE_STREAMING_RTFX" "$STREAMING_RTFX"
compare_der_rtfx "Sortformer (high-lat)" "$BASELINE_SORTFORMER_DER" "$SORTFORMER_DER" "$BASELINE_SORTFORMER_RTFX" "$SORTFORMER_RTFX"
compare_der_rtfx "LS-EEND (AMI)" "$BASELINE_LSEEND_DER" "$LSEEND_DER" "$BASELINE_LSEEND_RTFX" "$LSEEND_RTFX"

log ""

# Check for any DER regressions (>2.0% increase — diarization is noisier than ASR)
ANY_REGRESSION=$(python3 -c "
baselines = [
    ($BASELINE_STREAMING_DER, '$STREAMING_DER'),
    ($BASELINE_SORTFORMER_DER, '$SORTFORMER_DER'),
    ($BASELINE_LSEEND_DER, '$LSEEND_DER'),
]
for b, c in baselines:
    if c != 'N/A' and float(c) > b + 2.0:
        print('YES'); exit()
print('NO')
" 2>/dev/null || echo "NO")

if [[ "$ANY_REGRESSION" == "YES" ]]; then
    log "WARNING: DER REGRESSION DETECTED (>2.0% above baseline) — investigate before merging"
else
    log "No DER regressions (all within 2.0% of baseline)"
fi

# caffeinate will exit automatically since the parent process ($$) exits
@@ -0,0 +1,398 @@
#!/bin/bash
# Run all Parakeet model benchmarks (100 files each) with sleep prevention.
#
# Benchmarks:
#   1. ASR v3 — parakeet-tdt-0.6b-v3 on LibriSpeech test-clean
#   2. ASR v2 — parakeet-tdt-0.6b-v2 on LibriSpeech test-clean
#   3. ASR tdt-ctc-110m — parakeet-tdt-ctc-110m on LibriSpeech test-clean
#   4. CTC custom vocab — ctc-earnings-benchmark (v2 TDT + CTC 110m keyword spotting)
#   5. EOU streaming — parakeet-eou 320ms on LibriSpeech test-clean
#   6. Nemotron streaming — nemotron 1120ms on LibriSpeech test-clean
#
# Usage:
#   ./Scripts/parakeet_subset_benchmark.sh             # verify + run
#   ./Scripts/parakeet_subset_benchmark.sh --download  # download missing assets, then exit
#
# The script verifies all models and dataset files exist locally before running.
# If anything is missing it will tell you exactly what and exit (unless --download).
# Uses caffeinate to prevent sleep so you can close the lid.
# Results are saved to benchmark_results/ with timestamps.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
RESULTS_DIR="$PROJECT_DIR/benchmark_results"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="$RESULTS_DIR/benchmark_${TIMESTAMP}.log"
MAX_FILES=100
SUBSET="test-clean"

MODELS_DIR="$HOME/Library/Application Support/FluidAudio/Models"
DATASETS_DIR="$HOME/Library/Application Support/FluidAudio/Datasets"
EARNINGS_DIR="$HOME/Library/Application Support/FluidAudio/earnings22-kws/test-dataset"

mkdir -p "$RESULTS_DIR"

log() {
    echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"
}
|
||||
|
||||
# ---------------------------------------------------------------------------
# Verify local assets
# ---------------------------------------------------------------------------
verify_assets() {
    local missing=0

    # --- Parakeet v3 ---
    local v3_dir="$MODELS_DIR/parakeet-tdt-0.6b-v3"
    for f in Preprocessor.mlmodelc Encoder.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
        if [[ ! -e "$v3_dir/$f" ]]; then
            log "MISSING v3: $v3_dir/$f"
            missing=1
        fi
    done

    # --- Parakeet v2 (folder may have -coreml suffix) ---
    local v2_dir=""
    if [[ -d "$MODELS_DIR/parakeet-tdt-0.6b-v2-coreml" ]]; then
        v2_dir="$MODELS_DIR/parakeet-tdt-0.6b-v2-coreml"
    elif [[ -d "$MODELS_DIR/parakeet-tdt-0.6b-v2" ]]; then
        v2_dir="$MODELS_DIR/parakeet-tdt-0.6b-v2"
    fi
    if [[ -z "$v2_dir" ]]; then
        log "MISSING v2: no parakeet-tdt-0.6b-v2* directory found"
        missing=1
    else
        for f in Preprocessor.mlmodelc Encoder.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
            if [[ ! -e "$v2_dir/$f" ]]; then
                log "MISSING v2: $v2_dir/$f"
                missing=1
            fi
        done
    fi

    # --- TDT-CTC-110M (fused: no separate Encoder) ---
    local tdt_ctc_dir="$MODELS_DIR/parakeet-tdt-ctc-110m"
    for f in Preprocessor.mlmodelc Decoder.mlmodelc JointDecision.mlmodelc parakeet_vocab.json; do
        if [[ ! -e "$tdt_ctc_dir/$f" ]]; then
            log "MISSING tdt-ctc-110m: $tdt_ctc_dir/$f"
            missing=1
        fi
    done

    # --- CTC 110M model (for custom vocabulary / keyword spotting) ---
    local ctc_dir="$MODELS_DIR/parakeet-ctc-110m-coreml"
    for f in MelSpectrogram.mlmodelc AudioEncoder.mlmodelc vocab.json; do
        if [[ ! -e "$ctc_dir/$f" ]]; then
            log "MISSING ctc-110m: $ctc_dir/$f"
            missing=1
        fi
    done

    # --- EOU streaming models (320ms chunks) ---
    local eou_dir="$MODELS_DIR/parakeet-eou-streaming/320ms"
    if [[ ! -d "$eou_dir" ]]; then
        log "MISSING eou-320ms: $eou_dir"
        missing=1
    fi

    # --- Nemotron models (uses v3 encoder + nemotron-specific models) ---
    # Nemotron reuses the v3 models directory; no separate check needed beyond v3 above.

    # --- LibriSpeech test-clean ---
    local ls_dir="$DATASETS_DIR/LibriSpeech/$SUBSET"
    local trans_count
    trans_count=$(find "$ls_dir" -name "*.trans.txt" 2>/dev/null | wc -l | tr -d ' ')
    if [[ "$trans_count" -lt 5 ]]; then
        log "MISSING LibriSpeech $SUBSET: found $trans_count transcript files (need >= 5)"
        missing=1
    fi

    # --- Earnings22 KWS dataset ---
    local earnings_wav_count
    earnings_wav_count=$(find "$EARNINGS_DIR" -maxdepth 1 -name "*.wav" 2>/dev/null | wc -l | tr -d ' ')
    if [[ "$earnings_wav_count" -lt 10 ]]; then
        log "MISSING Earnings22 KWS: found $earnings_wav_count wav files (need >= 10)"
        missing=1
    fi

    return $missing
}

# ---------------------------------------------------------------------------
# Phase 1: --download (verify first, download only what's missing)
# ---------------------------------------------------------------------------
if [[ "${1:-}" == "--download" ]]; then
    log "=== Checking local assets ==="

    if verify_assets; then
        log "All models and datasets already present locally. Nothing to download."
        exit 0
    fi

    log "Some assets are missing — downloading..."

    log "Building release binary..."
    cd "$PROJECT_DIR" && swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
    CLI="$PROJECT_DIR/.build/release/fluidaudiocli"

    log "Downloading LibriSpeech $SUBSET dataset..."
    "$CLI" download --dataset "librispeech-$SUBSET" 2>&1 | tee -a "$LOG_FILE"

    log "Downloading Earnings22 KWS dataset..."
    "$CLI" download --dataset earnings22-kws 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading Parakeet v3 models (triggers download if missing)..."
    "$CLI" asr-benchmark --model-version v3 --subset "$SUBSET" --max-files 1 \
        --output "$RESULTS_DIR/warmup_v3.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading Parakeet v2 models..."
    "$CLI" asr-benchmark --model-version v2 --subset "$SUBSET" --max-files 1 \
        --output "$RESULTS_DIR/warmup_v2.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading CTC earnings models..."
    "$CLI" ctc-earnings-benchmark --max-files 1 --auto-download \
        --output "$RESULTS_DIR/warmup_ctc.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading EOU streaming models..."
    "$CLI" parakeet-eou --benchmark --chunk-size 320 --max-files 1 \
        --output "$RESULTS_DIR/warmup_eou.json" 2>&1 | tee -a "$LOG_FILE"

    log "Pre-loading Nemotron streaming models..."
    "$CLI" nemotron-benchmark --max-files 1 2>&1 | tee -a "$LOG_FILE"

    rm -f "$RESULTS_DIR"/warmup_*.json /tmp/nemotron_*_benchmark.json
    log "=== Downloads complete ==="
    exit 0
fi

# ---------------------------------------------------------------------------
# Phase 2: Run benchmarks (offline-safe, sleep-prevented)
# ---------------------------------------------------------------------------
log "=== Verifying local assets before offline run ==="
if ! verify_assets; then
    log ""
    log "ERROR: Missing assets — cannot run offline."
    log "Run with --download first while connected to the internet:"
    log " ./Scripts/parakeet_subset_benchmark.sh --download"
    exit 1
fi
log "All assets verified locally."

log "=== Parakeet benchmark suite: $MAX_FILES files x 6 benchmarks ==="
log "Results directory: $RESULTS_DIR"

cd "$PROJECT_DIR"

# Build release if not already built
if [[ ! -x ".build/release/fluidaudiocli" ]]; then
    log "Building release binary..."
    swift build -c release 2>&1 | tail -1 | tee -a "$LOG_FILE"
fi
CLI="$PROJECT_DIR/.build/release/fluidaudiocli"

# caffeinate -s: prevent system sleep (this assertion is only honored on AC power)
# caffeinate -i: prevent idle sleep
# caffeinate -w $$: wait on this script's PID, so caffeinate exits when the script ends.
caffeinate -si -w $$ &
CAFFEINATE_PID=$!
log "caffeinate started (PID $CAFFEINATE_PID) — safe to close the lid"

run_asr_benchmark() {
    local model_version="$1"
    local label="$2"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"

    log "--- $label: starting ($MAX_FILES files, $SUBSET) ---"
    local start_time=$(date +%s)

    "$CLI" asr-benchmark \
        --model-version "$model_version" \
        --subset "$SUBSET" \
        --max-files "$MAX_FILES" \
        --no-auto-download \
        --output "$output_file" \
        2>&1 | tee -a "$LOG_FILE"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_ctc_earnings_benchmark() {
    local label="ctc_earnings_vocab"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"

    log "--- $label: starting ($MAX_FILES files, v2 TDT + CTC keyword spotting) ---"
    local start_time=$(date +%s)

    # TDT v2 is used for transcription to match benchmarks100.md baseline
    "$CLI" ctc-earnings-benchmark \
        --ctc-variant 110m \
        --max-files "$MAX_FILES" \
        --output "$output_file" \
        2>&1 | tee -a "$LOG_FILE"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_eou_benchmark() {
    local label="eou_320ms"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"

    log "--- $label: starting ($MAX_FILES files, $SUBSET, 320ms chunks) ---"
    local start_time=$(date +%s)

    "$CLI" parakeet-eou \
        --benchmark \
        --chunk-size 320 \
        --max-files "$MAX_FILES" \
        --use-cache \
        --output "$output_file" \
        2>&1 | tee -a "$LOG_FILE"

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

run_nemotron_benchmark() {
    local label="nemotron_1120ms"
    local output_file="$RESULTS_DIR/${label}_${TIMESTAMP}.json"

    log "--- $label: starting ($MAX_FILES files, $SUBSET, 1120ms chunks) ---"
    local start_time=$(date +%s)

    "$CLI" nemotron-benchmark \
        --max-files "$MAX_FILES" \
        2>&1 | tee -a "$LOG_FILE"

    # Nemotron writes to /tmp; copy to our results dir
    local tmp_file="/tmp/nemotron_1120ms_benchmark.json"
    if [[ -f "$tmp_file" ]]; then
        cp "$tmp_file" "$output_file"
    fi

    local end_time=$(date +%s)
    local elapsed=$(( end_time - start_time ))
    log "--- $label: finished in ${elapsed}s — $output_file ---"
}

SUITE_START=$(date +%s)

run_asr_benchmark "v3" "parakeet_v3"
run_asr_benchmark "v2" "parakeet_v2"
run_asr_benchmark "tdt-ctc-110m" "parakeet_tdt_ctc_110m"
run_ctc_earnings_benchmark
run_eou_benchmark
run_nemotron_benchmark

SUITE_END=$(date +%s)
SUITE_ELAPSED=$(( SUITE_END - SUITE_START ))

log "=== All benchmarks complete in ${SUITE_ELAPSED}s ==="
log "Results:"
ls -lh "$RESULTS_DIR"/*_${TIMESTAMP}.json 2>/dev/null | tee -a "$LOG_FILE"

# ---------------------------------------------------------------------------
# Compare WER against benchmarks100.md baselines
# ---------------------------------------------------------------------------
# Baselines from Documentation/ASR/benchmarks100.md (main column)
BASELINE_V3_WER="2.6"
BASELINE_V2_WER="3.8"
BASELINE_TDT_CTC_WER="3.6"
BASELINE_EARNINGS_WER="16.54"
BASELINE_EOU_WER="7.11"
BASELINE_NEMOTRON_WER="1.99"

extract_wer() {
    local json_file="$1"
    local field="$2"
    if [[ -f "$json_file" ]]; then
        # Pass the path and field via argv so quotes in either cannot break the Python source.
        python3 -c "import json,sys; d=json.load(open(sys.argv[1])); print(round(d['summary'][sys.argv[2]]*100, 2))" "$json_file" "$field" 2>/dev/null || echo "N/A"
    else
        echo "N/A"
    fi
}

# For JSON fields that already store WER as a percentage (not decimal)
extract_wer_pct() {
    local json_file="$1"
    local section="$2"
    local field="$3"
    if [[ -f "$json_file" ]]; then
        # Path, section, and field are passed via argv rather than interpolated into the Python source.
        if [[ -n "$section" ]]; then
            python3 -c "import json,sys; d=json.load(open(sys.argv[1])); print(round(d[sys.argv[2]][sys.argv[3]], 2))" "$json_file" "$section" "$field" 2>/dev/null || echo "N/A"
        else
            python3 -c "import json,sys; d=json.load(open(sys.argv[1])); print(round(d[sys.argv[2]], 2))" "$json_file" "$field" 2>/dev/null || echo "N/A"
        fi
    else
        echo "N/A"
    fi
}

V3_FILE="$RESULTS_DIR/parakeet_v3_${TIMESTAMP}.json"
V2_FILE="$RESULTS_DIR/parakeet_v2_${TIMESTAMP}.json"
TDT_CTC_FILE="$RESULTS_DIR/parakeet_tdt_ctc_110m_${TIMESTAMP}.json"
EARNINGS_FILE="$RESULTS_DIR/ctc_earnings_vocab_${TIMESTAMP}.json"
EOU_FILE="$RESULTS_DIR/eou_320ms_${TIMESTAMP}.json"
NEMOTRON_FILE="$RESULTS_DIR/nemotron_1120ms_${TIMESTAMP}.json"

V3_WER=$(extract_wer "$V3_FILE" "averageWER")
V2_WER=$(extract_wer "$V2_FILE" "averageWER")
TDT_CTC_WER=$(extract_wer "$TDT_CTC_FILE" "averageWER")
EARNINGS_WER=$(extract_wer_pct "$EARNINGS_FILE" "summary" "avgWer")
EOU_WER=$(extract_wer "$EOU_FILE" "averageWER")
NEMOTRON_WER=$(extract_wer_pct "$NEMOTRON_FILE" "" "wer")

log ""
log "=== WER Comparison vs benchmarks100.md baselines ==="
log ""
printf "%-25s %10s %10s %10s\n" "Model" "Baseline" "Current" "Delta" | tee -a "$LOG_FILE"
printf "%-25s %10s %10s %10s\n" "-------------------------" "----------" "----------" "----------" | tee -a "$LOG_FILE"

compare_wer() {
    local label="$1" baseline="$2" current="$3"
    if [[ "$current" == "N/A" ]]; then
        printf "%-25s %9s%% %10s %10s\n" "$label" "$baseline" "N/A" "—" | tee -a "$LOG_FILE"
        return
    fi
    local delta
    delta=$(python3 -c "print(f'{$current - $baseline:+.2f}')" 2>/dev/null || echo "?")
    local marker=""
    local regression
    regression=$(python3 -c "print('YES' if $current > $baseline + 0.3 else 'NO')" 2>/dev/null || echo "NO")
    if [[ "$regression" == "YES" ]]; then
        marker=" ← REGRESSION"
    fi
    printf "%-25s %9s%% %9s%% %9s%%%s\n" "$label" "$baseline" "$current" "$delta" "$marker" | tee -a "$LOG_FILE"
}

compare_wer "Parakeet TDT v3 (0.6B)" "$BASELINE_V3_WER" "$V3_WER"
compare_wer "Parakeet TDT v2 (0.6B)" "$BASELINE_V2_WER" "$V2_WER"
compare_wer "CTC-TDT 110M" "$BASELINE_TDT_CTC_WER" "$TDT_CTC_WER"
compare_wer "CTC Earnings" "$BASELINE_EARNINGS_WER" "$EARNINGS_WER"
compare_wer "EOU 320ms (120M)" "$BASELINE_EOU_WER" "$EOU_WER"
compare_wer "Nemotron 1120ms (0.6B)" "$BASELINE_NEMOTRON_WER" "$NEMOTRON_WER"

log ""

# Check for any regressions (>0.3% WER increase)
ANY_REGRESSION=$(python3 -c "
baselines = [($BASELINE_V3_WER, '$V3_WER'), ($BASELINE_V2_WER, '$V2_WER'), ($BASELINE_TDT_CTC_WER, '$TDT_CTC_WER'), ($BASELINE_EARNINGS_WER, '$EARNINGS_WER'), ($BASELINE_EOU_WER, '$EOU_WER'), ($BASELINE_NEMOTRON_WER, '$NEMOTRON_WER')]
for b, c in baselines:
    if c != 'N/A' and float(c) > b + 0.3:
        print('YES'); exit()
print('NO')
" 2>/dev/null || echo "NO")

if [[ "$ANY_REGRESSION" == "YES" ]]; then
    log "⚠ WER REGRESSION DETECTED — investigate before merging"
else
    log "✓ No WER regressions (all within 0.3% of baseline)"
fi

# caffeinate will exit automatically since the parent process ($$) exits
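
The per-model checks in `verify_assets` all follow one pattern: loop over a list of required entries and log each one that is absent. A minimal sketch of how that pattern could be factored into a single helper, assuming a hypothetical function name `require_files` (it is not part of the script or of the FluidAudio CLI) and plain `echo` in place of the script's `log`:

```shell
# Hypothetical helper, not part of parakeet_subset_benchmark.sh.
# Prints one "MISSING <label>: <path>" line per absent entry and
# returns non-zero if at least one entry is missing.
require_files() {
    # $1 = label used in messages, $2 = directory, remaining args = required entries
    rf_label="$1"
    rf_dir="$2"
    shift 2
    rf_missing=0
    for rf_name in "$@"; do
        # -e accepts both plain files and directories such as *.mlmodelc bundles
        if [ ! -e "$rf_dir/$rf_name" ]; then
            echo "MISSING $rf_label: $rf_dir/$rf_name"
            rf_missing=1
        fi
    done
    return "$rf_missing"
}
```

Under `set -e`, a caller would likely want `require_files v3 "$v3_dir" Encoder.mlmodelc Decoder.mlmodelc || missing=1`, so that a missing file is recorded without aborting the remaining checks.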
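
The regression rule applied by `compare_wer` and by the final summary is an absolute one: a run counts as a regression when its WER exceeds the baseline by more than 0.3 percentage points, and a missing result (`N/A`) is never flagged. A sketch of the same rule without inline Python, assuming a hypothetical helper named `wer_regression` (not in the script) and an `awk` on `PATH`:

```shell
# Hypothetical helper, not part of parakeet_subset_benchmark.sh.
# Prints YES when current > baseline + tolerance, otherwise NO.
wer_regression() {
    # $1 = baseline WER (%), $2 = current WER (%) or "N/A",
    # $3 = tolerance in percentage points (defaults to 0.3)
    if [ "$2" = "N/A" ]; then
        echo "NO"
        return 0
    fi
    # "c + 0" forces a numeric comparison even if awk treats the input as a string
    awk -v b="$1" -v c="$2" -v t="${3:-0.3}" \
        'BEGIN { r = (c + 0 > b + t) ? "YES" : "NO"; print r }'
}
```

With the v3 baseline of 2.6, `wer_regression 2.6 3.0` prints `YES` and `wer_regression 2.6 2.7` prints `NO`. Values are passed to `awk` with `-v` rather than spliced into the program text, so an unexpected string cannot change the program.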
@@ -1,166 +0,0 @@
import Accelerate
import CoreML
import Foundation
import Metal

/// Neural Engine optimization utilities for ASR pipeline
public enum ANEOptimizer {

    // Use shared ANE constants
    public static let aneAlignment = ANEMemoryUtils.aneAlignment
    public static let aneTileSize = ANEMemoryUtils.aneTileSize

    /// Create ANE-aligned MLMultiArray with optimized memory layout
    public static func createANEAlignedArray(
        shape: [NSNumber],
        dataType: MLMultiArrayDataType
    ) throws -> MLMultiArray {
        do {
            return try ANEMemoryUtils.createAlignedArray(
                shape: shape,
                dataType: dataType,
                zeroClear: false // ASR doesn't need zero-cleared memory
            )
        } catch ANEMemoryUtils.ANEMemoryError.allocationFailed {
            throw NSError(
                domain: "ANEOptimizer", code: -1,
                userInfo: [NSLocalizedDescriptionKey: "Failed to allocate ANE-aligned memory"])
        } catch {
            throw NSError(
                domain: "ANEOptimizer", code: -1,
                userInfo: [NSLocalizedDescriptionKey: "ANE memory allocation error: \(error)"])
        }
    }

    /// Calculate optimal strides for ANE tile processing
    public static func calculateOptimalStrides(
        for shape: [NSNumber],
        dataType: MLMultiArrayDataType
    ) -> [NSNumber] {
        return ANEMemoryUtils.calculateOptimalStrides(for: shape)
    }

    /// Configure optimal compute units for each model type
    public static func optimalComputeUnits(for modelType: ModelType) -> MLComputeUnits {
        return .cpuAndNeuralEngine
    }

    /// Create zero-copy memory view between models
    public static func createZeroCopyView(
        from sourceArray: MLMultiArray,
        shape: [NSNumber],
        offset: Int = 0
    ) throws -> MLMultiArray {
        // Ensure we have enough data
        let sourceElements = sourceArray.shape.map { $0.intValue }.reduce(1, *)
        let viewElements = shape.map { $0.intValue }.reduce(1, *)

        guard offset + viewElements <= sourceElements else {
            throw NSError(
                domain: "ANEOptimizer", code: -2,
                userInfo: [NSLocalizedDescriptionKey: "View exceeds source array bounds"])
        }

        // Calculate byte offset
        let elementSize = ANEMemoryUtils.getElementSize(for: sourceArray.dataType)

        let byteOffset = offset * elementSize
        let offsetPointer = sourceArray.dataPointer.advanced(by: byteOffset)

        // Create view with same data but new shape
        return try MLMultiArray(
            dataPointer: offsetPointer,
            shape: shape,
            dataType: sourceArray.dataType,
            strides: calculateOptimalStrides(for: shape, dataType: sourceArray.dataType),
            deallocator: nil // No deallocation since it's a view
        )
    }

    /// Prefetch data to Neural Engine
    public static func prefetchToNeuralEngine(_ array: MLMultiArray) {
        // Trigger ANE prefetch by accessing first and last elements
        // This causes the ANE to initiate DMA transfer
        if array.count > 0 {
            _ = array[0]
            _ = array[array.count - 1]
        }
    }

    /// Convert float32 array to float16 for ANE efficiency
    public static func convertToFloat16(_ input: MLMultiArray) throws -> MLMultiArray {
        guard input.dataType == .float32 else {
            throw NSError(
                domain: "ANEOptimizer", code: -3,
                userInfo: [NSLocalizedDescriptionKey: "Input must be float32"])
        }

        // Create float16 array with ANE alignment
        let float16Array = try createANEAlignedArray(
            shape: input.shape,
            dataType: .float16
        )

        // Convert using Accelerate with platform-specific handling
        let sourcePtr = input.dataPointer.bindMemory(to: Float.self, capacity: input.count)

        var sourceBuffer = vImage_Buffer(
            data: sourcePtr,
            height: 1,
            width: vImagePixelCount(input.count),
            rowBytes: input.count * MemoryLayout<Float>.stride
        )

        // Use UInt16 as storage type for cross-platform compatibility
        let destPtr = float16Array.dataPointer.bindMemory(to: UInt16.self, capacity: input.count)

        var destBuffer = vImage_Buffer(
            data: destPtr,
            height: 1,
            width: vImagePixelCount(input.count),
            rowBytes: input.count * MemoryLayout<UInt16>.stride
        )

        vImageConvert_PlanarFtoPlanar16F(&sourceBuffer, &destBuffer, 0)

        return float16Array
    }

    /// Model type enumeration for compute unit selection
    public enum ModelType {
        case encoder
        case decoder
        case joint
    }
}

/// Extension for MLFeatureProvider to enable zero-copy chaining
public class ZeroCopyFeatureProvider: NSObject, MLFeatureProvider {
    private let features: [String: MLFeatureValue]

    public init(features: [String: MLFeatureValue]) {
        self.features = features
        super.init()
    }

    public var featureNames: Set<String> {
        Set(features.keys)
    }

    public func featureValue(for featureName: String) -> MLFeatureValue? {
        features[featureName]
    }

    /// Create a provider that chains output from one model to input of another
    public static func chain(
        from outputProvider: MLFeatureProvider,
        outputName: String,
        to inputName: String
    ) -> ZeroCopyFeatureProvider? {
        guard let outputValue = outputProvider.featureValue(for: outputName) else {
            return nil
        }

        return ZeroCopyFeatureProvider(features: [inputName: outputValue])
    }
}
@@ -3,11 +3,6 @@ import AVFoundation
 import Foundation
 import OSLog

-public enum AudioSource: Sendable {
-    case microphone
-    case system
-}
-
 public actor AsrManager {

     internal let logger = AppLogger(category: "ASR")
@@ -24,14 +19,12 @@ public actor AsrManager {

     internal let progressEmitter = ProgressEmitter()

-    /// Get the number of decoder layers for the current model.
+    /// Number of decoder layers for the current model.
     /// Returns 2 if models not loaded (v2/v3 default, tdtCtc110m uses 1).
-    internal func getDecoderLayers() -> Int {
-        return asrModels?.version.decoderLayers ?? 2
+    internal var decoderLayerCount: Int {
+        asrModels?.version.decoderLayers ?? 2
     }

-    /// Token duration optimization model
-
     /// Cached vocabulary loaded once during initialization
     internal var vocabulary: [Int: String] = [:]
     #if DEBUG
@@ -41,17 +34,25 @@ public actor AsrManager {
     }
     #endif

-    // TODO:: the decoder state should be moved higher up in the API interface
+    // Per-source decoder states are actor-internal; callers reset via resetDecoderState().
     internal var microphoneDecoderState: TdtDecoderState
     internal var systemDecoderState: TdtDecoderState

-    // Vocabulary boosting state (configured via configureVocabularyBoosting)
-    // Internal access required for AsrTranscription extension (separate file)
-    internal var customVocabulary: CustomVocabularyContext?
-    internal var ctcSpotter: CtcKeywordSpotter?
-    internal var vocabularyRescorer: VocabularyRescorer?
-    internal var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig?
-    internal var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil }
+    /// Get decoder state for a given audio source.
+    internal func decoderState(for source: AudioSource) -> TdtDecoderState {
+        switch source {
+        case .microphone: return microphoneDecoderState
+        case .system: return systemDecoderState
+        }
+    }
+
+    /// Set decoder state for a given audio source.
+    internal func setDecoderState(_ state: TdtDecoderState, for source: AudioSource) {
+        switch source {
+        case .microphone: microphoneDecoderState = state
+        case .system: systemDecoderState = state
+        }
+    }

     // Cached CTC logits from fused Preprocessor (unified custom vocabulary)
     internal var cachedCtcLogits: MLMultiArray?
@@ -61,6 +62,13 @@ public actor AsrManager {
     /// Whether the Preprocessor outputs CTC logits (unified custom vocabulary model).
     public var hasCachedCtcLogits: Bool { cachedCtcLogits != nil }

+    /// Clear all cached CTC data (logits, frame duration, valid frames).
+    internal func clearCachedCtcData() {
+        cachedCtcLogits = nil
+        cachedCtcFrameDuration = nil
+        cachedCtcValidFrames = nil
+    }
+
     /// Get cached CTC raw logits as [[Float]] for external use (e.g. benchmarks).
     /// These are raw logits — callers must apply `CtcKeywordSpotter.applyLogSoftmax()`
     /// to convert to log-probabilities before use in keyword detection.
@@ -120,7 +128,7 @@ public actor AsrManager {
     /// Only one session is supported at a time.
     public var transcriptionProgressStream: AsyncThrowingStream<Double, Error> {
         get async {
-            await progressEmitter.currentStream()
+            await progressEmitter.ensureSession()
         }
     }

@@ -157,55 +165,6 @@ public actor AsrManager {
         logger.info("AsrManager initialized successfully with provided models")
     }

-    /// Configure vocabulary boosting for batch transcription.
-    ///
-    /// When configured, vocabulary terms will be automatically rescored after each `transcribe()` call
-    /// using CTC-based constrained decoding. The resulting `ASRResult` will have `ctcDetectedTerms`
-    /// and `ctcAppliedTerms` populated.
-    ///
-    /// - Parameters:
-    /// - vocabulary: Custom vocabulary context with terms to detect
-    /// - ctcModels: Pre-loaded CTC models for keyword spotting
-    /// - config: Optional rescorer configuration (default: vocabulary-size-aware config)
-    /// - Throws: Error if rescorer initialization fails
-    public func configureVocabularyBoosting(
-        vocabulary: CustomVocabularyContext,
-        ctcModels: CtcModels,
-        config: VocabularyRescorer.Config? = nil
-    ) async throws {
-        self.customVocabulary = vocabulary
-
-        let blankId = ctcModels.vocabulary.count
-        self.ctcSpotter = CtcKeywordSpotter(models: ctcModels, blankId: blankId)
-
-        let vocabSize = vocabulary.terms.count
-        let vocabConfig = ContextBiasingConstants.rescorerConfig(forVocabSize: vocabSize)
-        self.vocabSizeConfig = vocabConfig
-        let effectiveConfig = config ?? .default
-
-        let ctcModelDir = CtcModels.defaultCacheDirectory(for: ctcModels.variant)
-        self.vocabularyRescorer = try await VocabularyRescorer.create(
-            spotter: ctcSpotter!,
-            vocabulary: vocabulary,
-            config: effectiveConfig,
-            ctcModelDirectory: ctcModelDir
-        )
-
-        let isLargeVocab = vocabSize > ContextBiasingConstants.largeVocabThreshold
-        logger.info(
-            "Vocabulary boosting configured with \(vocabSize) terms (isLargeVocab: \(isLargeVocab))"
-        )
-    }
-
-    /// Disable vocabulary boosting and release CTC models.
-    public func disableVocabularyBoosting() {
-        customVocabulary = nil
-        ctcSpotter = nil
-        vocabularyRescorer = nil
-        vocabSizeConfig = nil
-        logger.info("Vocabulary boosting disabled")
-    }
-
     private func createFeatureProvider(
         features: [(name: String, array: MLMultiArray)]
     ) throws
@@ -275,16 +234,7 @@ public actor AsrManager {
             throw ASRError.notInitialized
         }

-        // Get the appropriate decoder state
-        var state: TdtDecoderState
-        switch source {
-        case .microphone:
-            state = microphoneDecoderState
-        case .system:
-            state = systemDecoderState
-        }
-
-        // Reset the existing decoder state to clear all cached values including predictorOutput
+        var state = decoderState(for: source)
         state.reset()

         let initDecoderInput = try prepareDecoderInput(
@@ -298,40 +248,7 @@ public actor AsrManager {
         )

         state.update(from: initDecoderOutput)
-
-        // Store back
-        switch source {
-        case .microphone:
-            microphoneDecoderState = state
-        case .system:
-            systemDecoderState = state
-        }
-    }
-
-    private func loadModel(
-        path: URL,
-        name: String,
-        configuration: MLModelConfiguration
-    ) async throws -> MLModel {
-        do {
-            let model = try MLModel(contentsOf: path, configuration: configuration)
-            return model
-        } catch {
-            logger.error("Failed to load \(name) model: \(error)")
-
-            throw ASRError.modelLoadFailed
-        }
-    }
-    private static func getDefaultModelsDirectory() -> URL {
-        let applicationSupportURL = FileManager.default.urls(
-            for: .applicationSupportDirectory, in: .userDomainMask
-        ).first!
-        let appDirectory = applicationSupportURL.appendingPathComponent(
-            "FluidAudio", isDirectory: true)
-        let directory = appDirectory.appendingPathComponent("Models/Parakeet", isDirectory: true)
-
-        try? FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true)
-        return directory.standardizedFileURL
+        setDecoderState(state, for: source)
     }

     public func resetState() {
@@ -339,9 +256,7 @@ public actor AsrManager {
         let layers = asrModels?.version.decoderLayers ?? 2
         microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
         systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
-        cachedCtcLogits = nil
-        cachedCtcFrameDuration = nil
-        cachedCtcValidFrames = nil
+        clearCachedCtcData()
         Task { await sharedMLArrayCache.clear() }
     }

@@ -356,11 +271,7 @@ public actor AsrManager {
         // Reset decoder states using fresh allocations for deterministic behavior
         microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
         systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
-        // Release vocabulary boosting resources and cached CTC data
-        cachedCtcLogits = nil
-        cachedCtcFrameDuration = nil
-        cachedCtcValidFrames = nil
-        disableVocabularyBoosting()
+        clearCachedCtcData()
         Task { await sharedMLArrayCache.clear() }
         logger.info("AsrManager resources cleaned up")
     }
@@ -467,7 +378,7 @@ public actor AsrManager {
             let estimatedSamples = Int((Double(audioFile.length) * sampleRateRatio).rounded(.up))

             if estimatedSamples > config.streamingThreshold {
-                return try await transcribeStreaming(url, source: source)
+                return try await transcribeDiskBacked(url, source: source)
             }
         }

@@ -476,30 +387,30 @@ public actor AsrManager {
         return result
     }

-    /// Transcribe audio from a file URL using streaming mode.
+    /// Transcribe audio from a file URL using disk-backed chunked processing.
     ///
-    /// Memory-efficient transcription that processes audio in chunks, maintaining constant
-    /// memory usage (~1.2MB) regardless of file size. Ideal for long audio files.
+    /// Memory-efficient transcription that memory-maps the file and processes audio in chunks,
+    /// maintaining constant memory usage (~1.2MB) regardless of file size. Ideal for long audio files.
     ///
     /// - Parameters:
     /// - url: The URL to the audio file
     /// - source: The audio source type (defaults to .system)
     /// - Returns: An ASRResult containing the transcribed text and token timings
     /// - Throws: ASRError if transcription fails, models are not initialized, or the file cannot be read
-    public func transcribeStreaming(_ url: URL, source: AudioSource = .system) async throws -> ASRResult {
+    public func transcribeDiskBacked(_ url: URL, source: AudioSource = .system) async throws -> ASRResult {
         guard isAvailable else { throw ASRError.notInitialized }

         let startTime = Date()

         // Create a disk-backed source for memory-efficient access
-        let factory = StreamingAudioSourceFactory()
+        let factory = AudioSourceFactory()
         let (sampleSource, _) = try factory.makeDiskBackedSource(
             from: url,
             targetSampleRate: config.sampleRate
         )

         let totalSamples = sampleSource.sampleCount
-        guard totalSamples >= 16_000 else {
+        guard totalSamples >= config.sampleRate else {
             sampleSource.cleanup()
             throw ASRError.invalidAudioData
         }
@@ -589,53 +500,18 @@ public actor AsrManager {
         try await initializeDecoderState(for: source)
     }

-    internal func normalizedTimingToken(_ token: String) -> String {
+    nonisolated internal func normalizedTimingToken(_ token: String) -> String {
         token.replacingOccurrences(of: "▁", with: " ")
     }

-    internal func convertTokensWithExistingTimings(
-        _ tokenIds: [Int], timings: [TokenTiming]
-    ) -> (
-        text: String, timings: [TokenTiming]
-    ) {
-        guard !tokenIds.isEmpty else { return ("", []) }
+    /// Decode token IDs to text using SentencePiece conventions.
+    internal func convertTokensToText(_ tokenIds: [Int]) -> String {
+        guard !tokenIds.isEmpty else { return "" }

-        // SentencePiece-compatible decoding algorithm:
-        // 1. Convert token IDs to token strings
-        var tokens: [String] = []
-        var tokenInfos: [(token: String, tokenId: Int, timing: TokenTiming?)] = []
-
-        for (index, tokenId) in tokenIds.enumerated() {
-            if let token = vocabulary[tokenId], !token.isEmpty {
-                tokens.append(token)
-                let timing = index < timings.count ? timings[index] : nil
-                tokenInfos.append((token: token, tokenId: tokenId, timing: timing))
-            }
-        }
-
-        // 2. Concatenate all tokens (this is how SentencePiece works)
-        let concatenated = tokens.joined()
-
-        // 3. Replace ▁ with space (SentencePiece standard)
-        let text = concatenated.replacingOccurrences(of: "▁", with: " ")
+        let tokens = tokenIds.compactMap { vocabulary[$0] }.filter { !$0.isEmpty }
+        return tokens.joined()
+            .replacingOccurrences(of: "▁", with: " ")
             .trimmingCharacters(in: .whitespaces)
-
-        // 4. For now, return original timings as-is
-        // Note: Proper timing alignment would require tracking character positions
-        // through the concatenation and replacement process
-        let adjustedTimings = tokenInfos.compactMap { info in
-            info.timing.map { timing in
-                TokenTiming(
-                    token: normalizedTimingToken(info.token),
-                    tokenId: info.tokenId,
-                    startTime: timing.startTime,
-                    endTime: timing.endTime,
-                    confidence: timing.confidence
-                )
-            }
-        }
-
-        return (text, adjustedTimings)
     }

     nonisolated internal func extractFeatureValue(
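The SentencePiece-style decoding in `convertTokensToText` can be tried in isolation. The following is a minimal sketch, assuming a small hypothetical vocabulary table in place of the manager's `vocabulary` property; the free function and its extra `vocabulary` parameter are illustrative only and not part of the PR.

```swift
import Foundation

// Hypothetical vocabulary; the real table is loaded by AsrManager.
let sampleVocabulary: [Int: String] = [0: "▁hello", 1: "▁wor", 2: "ld", 3: ""]

/// Decode token IDs by concatenating pieces and turning "▁" into spaces.
/// Unknown IDs and empty pieces are skipped, as in the diff above.
func decodeTokens(_ tokenIds: [Int], vocabulary: [Int: String]) -> String {
    guard !tokenIds.isEmpty else { return "" }
    let tokens = tokenIds.compactMap { vocabulary[$0] }.filter { !$0.isEmpty }
    return tokens.joined()
        .replacingOccurrences(of: "▁", with: " ")
        .trimmingCharacters(in: .whitespaces)
}

// ID 3 maps to an empty piece and ID 99 is unknown; both are dropped.
let decodedText = decodeTokens([0, 1, 2, 3, 99], vocabulary: sampleVocabulary)
print(decodedText)
```

Because word boundaries are carried by the `▁` marker inside each piece, the pieces are joined without separators first and only then converted to spaces.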
@@ -1,6 +1,5 @@
 @preconcurrency import CoreML
 import Foundation
-import OSLog

 /// ASR model version enum
 public enum AsrModelVersion: Sendable {
@@ -331,8 +330,6 @@ extension AsrModels {
         return try await load(from: targetDir, configuration: configuration, progressHandler: progressHandler)
     }

-    /// Load models with ANE-optimized configurations
-
     private static func describeComputeUnits(_ units: MLComputeUnits) -> String {
         switch units {
         case .cpuOnly:
@@ -348,41 +345,20 @@ extension AsrModels {
         }
     }

-    public static func loadWithANEOptimization(
-        from directory: URL? = nil,
-        enableFP16: Bool = true
-    ) async throws -> AsrModels {
-        let targetDir = directory ?? defaultCacheDirectory()
-
-        logger.info("Loading ASR models with ANE optimization from: \(targetDir.path)")
-
-        // Use the load method that already applies per-model optimizations
-        return try await load(from: targetDir, configuration: nil)
-    }
-
     public static func defaultConfiguration() -> MLModelConfiguration {
-        let config = MLModelConfiguration()
-        config.allowLowPrecisionAccumulationOnGPU = true
-        // Prefer Neural Engine across platforms for ASR inference to avoid GPU dispatch.
-        config.computeUnits = .cpuAndNeuralEngine
-        return config
+        MLModelConfigurationUtils.defaultConfiguration(computeUnits: .cpuAndNeuralEngine)
     }

-    /// Create optimized configuration for specific model type
+    /// Create optimized configuration for model inference
     public static func optimizedConfiguration(
-        for modelType: ANEOptimizer.ModelType,
         enableFP16: Bool = true
     ) -> MLModelConfiguration {
-        let config = MLModelConfiguration()
-        config.allowLowPrecisionAccumulationOnGPU = enableFP16
-        config.computeUnits = ANEOptimizer.optimalComputeUnits(for: modelType)
-
-        // Enable model-specific optimizations
         let isCI = ProcessInfo.processInfo.environment["CI"] != nil
-        if isCI {
-            config.computeUnits = .cpuOnly
-        }

+        let config = MLModelConfigurationUtils.defaultConfiguration(
+            computeUnits: isCI ? .cpuOnly : .cpuAndNeuralEngine
+        )
+        config.allowLowPrecisionAccumulationOnGPU = enableFP16
         return config
     }

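The new `optimizedConfiguration(enableFP16:)` builds a shared default and then applies the caller's flag last, so the default cannot silently win. A minimal sketch of that ordering follows, assuming an Apple platform where Core ML and `.cpuAndNeuralEngine` are available (macOS 13 or iOS 16 and later). `makeInferenceConfiguration` is a hypothetical name used only for this illustration; it does not call `MLModelConfigurationUtils`, whose implementation is not shown in this diff.

```swift
import CoreML
import Foundation

/// Hypothetical helper: pick compute units first, then apply the caller's
/// precision flag as the final step so nothing overrides it afterwards.
func makeInferenceConfiguration(enableFP16: Bool = true) -> MLModelConfiguration {
    // CI runners are assumed to lack a usable Neural Engine, as in the diff.
    let isCI = ProcessInfo.processInfo.environment["CI"] != nil
    let config = MLModelConfiguration()
    config.computeUnits = isCI ? .cpuOnly : .cpuAndNeuralEngine
    config.allowLowPrecisionAccumulationOnGPU = enableFP16
    return config
}

let fullPrecision = makeInferenceConfiguration(enableFP16: false)
print(fullPrecision.allowLowPrecisionAccumulationOnGPU)
```

Setting the flag after the factory call is the design point: any value the factory hardcodes is overwritten by the parameter.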
@@ -536,14 +512,7 @@ extension AsrModels {
     }

     public static func defaultCacheDirectory(for version: AsrModelVersion = .v3) -> URL {
-        let appSupport = FileManager.default.urls(
-            for: .applicationSupportDirectory, in: .userDomainMask
-        ).first!
-        return
-            appSupport
-            .appendingPathComponent("FluidAudio", isDirectory: true)
-            .appendingPathComponent("Models", isDirectory: true)
-            .appendingPathComponent(version.repo.folderName, isDirectory: true)
+        MLModelConfigurationUtils.defaultModelsDirectory(for: version.repo)
     }

     // Legacy method for backward compatibility
@@ -1,6 +1,5 @@
 @preconcurrency import CoreML
 import Foundation
-import OSLog

 extension AsrManager {

@@ -8,34 +7,15 @@ extension AsrManager {
         _ audioSamples: [Float], source: AudioSource
     ) async throws -> ASRResult {
         guard isAvailable else { throw ASRError.notInitialized }
-        guard audioSamples.count >= 16_000 else { throw ASRError.invalidAudioData }
+        guard audioSamples.count >= config.sampleRate else { throw ASRError.invalidAudioData }

         let startTime = Date()

-        // Get the appropriate decoder state
-        var decoderState: TdtDecoderState
-        switch source {
-        case .microphone:
-            decoderState = microphoneDecoderState
-        case .system:
-            decoderState = systemDecoderState
-        }
+        var decoderState = decoderState(for: source)

         // Route to appropriate processing method based on audio length
         if audioSamples.count <= ASRConstants.maxModelSamples {
-            let originalLength = audioSamples.count
-            let frameAlignedCandidate =
-                ((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
-                / ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
-            let frameAlignedLength: Int
-            let alignedSamples: [Float]
-            if frameAlignedCandidate > originalLength && frameAlignedCandidate <= ASRConstants.maxModelSamples {
-                frameAlignedLength = frameAlignedCandidate
-                alignedSamples = audioSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
-            } else {
-                frameAlignedLength = originalLength
-                alignedSamples = audioSamples
-            }
+            let (alignedSamples, frameAlignedLength) = frameAlignedAudio(audioSamples)
             let paddedAudio: [Float] = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
             let (hypothesis, encoderSequenceLength) = try await executeMLInferenceWithTimings(
                 paddedAudio,
@@ -45,7 +25,7 @@ extension AsrManager {
                 isLastChunk: true  // Single-chunk: always first and last
             )

-            var result = processTranscriptionResult(
+            let result = processTranscriptionResult(
                 tokenIds: hypothesis.ySequence,
                 timestamps: hypothesis.timestamps,
                 confidences: hypothesis.tokenConfidences,
@@ -55,25 +35,14 @@ extension AsrManager {
                 processingTime: Date().timeIntervalSince(startTime)
             )

-            // Auto-apply vocabulary rescoring when configured
-            if vocabBoostingEnabled {
-                result = await applyVocabularyRescoring(result: result, audioSamples: audioSamples)
-            }
-
-            // Store decoder state back
-            switch source {
-            case .microphone:
-                microphoneDecoderState = decoderState
-            case .system:
-                systemDecoderState = decoderState
-            }
+            setDecoderState(decoderState, for: source)

             return result
         }

         // ChunkProcessor handles stateless chunked transcription for long audio
         let processor = ChunkProcessor(audioSamples: audioSamples)
-        var result = try await processor.process(
+        let result = try await processor.process(
             using: self,
             startTime: startTime,
             progressHandler: { [weak self] progress in
@@ -82,18 +51,7 @@ extension AsrManager {
             }
         )

-        // Auto-apply vocabulary rescoring when configured
-        if vocabBoostingEnabled {
-            result = await applyVocabularyRescoring(result: result, audioSamples: audioSamples)
-        }
-
-        // Store decoder state back (ChunkProcessor uses the stored state directly)
-        switch source {
-        case .microphone:
-            microphoneDecoderState = decoderState
-        case .system:
-            systemDecoderState = decoderState
-        }
+        setDecoderState(decoderState, for: source)

         return result
     }
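The repeated `switch source` blocks in these hunks collapse into a getter/setter pair, `decoderState(for:)` and `setDecoderState(_:for:)`. Their bodies are not part of this excerpt, so the following is only a sketch of the pattern, assuming one stored state per source. `DecoderState` and `DecoderStateStore` are hypothetical stand-ins for `TdtDecoderState` and the `AsrManager` actor.

```swift
enum AudioSource { case microphone, system }

// Hypothetical stand-in for TdtDecoderState (a value type).
struct DecoderState: Equatable {
    var lastToken: Int? = nil
}

final class DecoderStateStore {
    private var microphoneDecoderState = DecoderState()
    private var systemDecoderState = DecoderState()

    /// Return a copy of the state slot that belongs to `source`.
    func decoderState(for source: AudioSource) -> DecoderState {
        switch source {
        case .microphone: return microphoneDecoderState
        case .system: return systemDecoderState
        }
    }

    /// Write an updated state back to the slot that belongs to `source`.
    func setDecoderState(_ state: DecoderState, for source: AudioSource) {
        switch source {
        case .microphone: microphoneDecoderState = state
        case .system: systemDecoderState = state
        }
    }
}

// Copy out, mutate locally, write back: the same shape as the call sites above.
let store = DecoderStateStore()
var workingState = store.decoderState(for: .microphone)
workingState.lastToken = 42
store.setDecoderState(workingState, for: .microphone)
```

Keeping the switch in one place means a new `AudioSource` case produces a compile error in two functions instead of at every call site.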
@@ -167,20 +125,14 @@ extension AsrManager {
                     cachedCtcFrameDuration = 0.04  // 40ms per frame
                     cachedCtcValidFrames = encoderSequenceLength
                 } else {
-                    cachedCtcLogits = nil
-                    cachedCtcFrameDuration = nil
-                    cachedCtcValidFrames = nil
+                    clearCachedCtcData()
                 }
             } catch {
                 logger.warning("CTC head inference failed: \(error.localizedDescription)")
-                cachedCtcLogits = nil
-                cachedCtcFrameDuration = nil
-                cachedCtcValidFrames = nil
+                clearCachedCtcData()
             }
         } else {
-            cachedCtcLogits = nil
-            cachedCtcFrameDuration = nil
-            cachedCtcValidFrames = nil
+            clearCachedCtcData()
         }

         // Calculate actual audio frames if not provided using shared constants
@@ -250,33 +202,18 @@ extension AsrManager {
         return try MLDictionaryFeatureProvider(dictionary: features)
     }

-    /// Streaming-friendly chunk transcription that preserves decoder state and supports start-frame offset.
-    /// This is used by both sliding window chunking and streaming paths to unify behavior.
-    public func transcribeStreamingChunk(
+    /// Chunk transcription that preserves decoder state between calls.
+    /// Used by SlidingWindowAsrManager for overlapping-window processing with token deduplication.
+    public func transcribeChunk(
         _ chunkSamples: [Float],
         source: AudioSource,
         previousTokens: [Int] = [],
         isLastChunk: Bool = false
     ) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], encoderSequenceLength: Int) {
-        // Select and copy decoder state for the source
-        var state = (source == .microphone) ? microphoneDecoderState : systemDecoderState
+        var state = decoderState(for: source)

-        let originalLength = chunkSamples.count
-        let frameAlignedCandidate =
-            ((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
-            / ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
-        let frameAlignedLength: Int
-        let alignedSamples: [Float]
-        if previousTokens.isEmpty
-            && frameAlignedCandidate > originalLength
-            && frameAlignedCandidate <= ASRConstants.maxModelSamples
-        {
-            frameAlignedLength = frameAlignedCandidate
-            alignedSamples = chunkSamples + Array(repeating: 0, count: frameAlignedLength - originalLength)
-        } else {
-            frameAlignedLength = originalLength
-            alignedSamples = chunkSamples
-        }
+        let (alignedSamples, frameAlignedLength) = frameAlignedAudio(
+            chunkSamples, allowAlignment: previousTokens.isEmpty)
         let padded = padAudioIfNeeded(alignedSamples, targetLength: ASRConstants.maxModelSamples)
         let (hypothesis, encLen) = try await executeMLInferenceWithTimings(
             padded,
@@ -287,12 +224,7 @@ extension AsrManager {
             isLastChunk: isLastChunk
         )

-        // Persist updated state back to the source-specific slot
-        if source == .microphone {
-            microphoneDecoderState = state
-        } else {
-            systemDecoderState = state
-        }
+        setDecoderState(state, for: source)

         // Apply token deduplication if previous tokens are provided
         if !previousTokens.isEmpty && hypothesis.hasTokens {
@@ -317,23 +249,16 @@ extension AsrManager {
|
||||
tokenDurations: [Int] = [],
|
||||
encoderSequenceLength: Int,
|
||||
audioSamples: [Float],
|
||||
processingTime: TimeInterval,
|
||||
tokenTimings: [TokenTiming] = []
|
||||
processingTime: TimeInterval
|
||||
) -> ASRResult {
|
||||
|
||||
let (text, finalTimings) = convertTokensWithExistingTimings(tokenIds, timings: tokenTimings)
|
||||
let text = convertTokensToText(tokenIds)
|
||||
let duration = TimeInterval(audioSamples.count) / TimeInterval(config.sampleRate)
|
||||
|
||||
// Convert timestamps to TokenTiming objects if provided
|
||||
let timingsFromTimestamps = createTokenTimings(
|
||||
let resultTimings = createTokenTimings(
|
||||
from: tokenIds, timestamps: timestamps, confidences: confidences, tokenDurations: tokenDurations)
|
||||
|
||||
// Use existing timings if provided, otherwise use timings from timestamps
|
||||
let resultTimings = tokenTimings.isEmpty ? timingsFromTimestamps : finalTimings
|
||||
|
||||
// Calculate confidence based on actual model confidence scores from TDT decoder
|
||||
let confidence = calculateConfidence(
|
||||
duration: duration,
|
||||
tokenCount: tokenIds.count,
|
||||
isEmpty: text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty,
|
||||
tokenConfidences: confidences
|
||||
@@ -348,6 +273,27 @@ extension AsrManager {
         )
     }

+    /// Align audio samples to encoder frame boundaries by zero-padding to the next frame boundary.
+    /// Returns the aligned samples and the frame-aligned length.
+    /// - Parameters:
+    ///   - audioSamples: Raw audio samples
+    ///   - allowAlignment: When false, skip alignment (e.g. when previous context exists)
+    nonisolated internal func frameAlignedAudio(
+        _ audioSamples: [Float], allowAlignment: Bool = true
+    ) -> (samples: [Float], frameAlignedLength: Int) {
+        let originalLength = audioSamples.count
+        let frameAlignedCandidate =
+            ((originalLength + ASRConstants.samplesPerEncoderFrame - 1)
+            / ASRConstants.samplesPerEncoderFrame) * ASRConstants.samplesPerEncoderFrame
+        if allowAlignment && frameAlignedCandidate > originalLength
+            && frameAlignedCandidate <= ASRConstants.maxModelSamples
+        {
+            let aligned = audioSamples + Array(repeating: 0, count: frameAlignedCandidate - originalLength)
+            return (aligned, frameAlignedCandidate)
+        }
+        return (audioSamples, originalLength)
+    }
+
     nonisolated internal func padAudioIfNeeded(_ audioSamples: [Float], targetLength: Int) -> [Float] {
         guard audioSamples.count < targetLength else { return audioSamples }
         return audioSamples + Array(repeating: 0, count: targetLength - audioSamples.count)
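The rounding in `frameAlignedAudio` is ordinary integer ceiling division to the next multiple of the encoder frame size. The sketch below reproduces it as a free function so the edge cases can be checked. The constant values are assumptions (80 ms frames at 16 kHz give 1,280 samples per frame, and a 15 s model window gives 240,000 samples); the actual values live in `ASRConstants`, which this diff does not show.

```swift
// Assumed values; the real ones come from ASRConstants.
let assumedSamplesPerEncoderFrame = 1_280
let assumedMaxModelSamples = 240_000

/// Zero-pad `audioSamples` up to the next frame boundary when that is
/// allowed and the padded length still fits the model window.
func alignToFrames(
    _ audioSamples: [Float], allowAlignment: Bool = true
) -> (samples: [Float], frameAlignedLength: Int) {
    let originalLength = audioSamples.count
    // Ceiling division: (n + k - 1) / k * k rounds n up to a multiple of k.
    let candidate =
        ((originalLength + assumedSamplesPerEncoderFrame - 1)
        / assumedSamplesPerEncoderFrame) * assumedSamplesPerEncoderFrame
    if allowAlignment && candidate > originalLength && candidate <= assumedMaxModelSamples {
        let padding = [Float](repeating: 0, count: candidate - originalLength)
        return (audioSamples + padding, candidate)
    }
    return (audioSamples, originalLength)
}

let shortInput = alignToFrames([Float](repeating: 0.5, count: 1_000))
let exactInput = alignToFrames([Float](repeating: 0.5, count: 1_280))
let skippedInput = alignToFrames([Float](repeating: 0.5, count: 1_000), allowAlignment: false)
let nearLimitInput = alignToFrames([Float](repeating: 0.5, count: 239_500))
```

Under these assumed constants, 1,000 samples pad to 1,280, an already aligned input is returned unchanged, and 239,500 samples are left alone because the next boundary (240,640) would exceed the window.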
@@ -356,8 +302,8 @@ extension AsrManager {
     /// Calculate confidence score based purely on TDT model token confidence scores
     /// Returns the average of token-level softmax probabilities from the decoder
     /// Range: 0.1 (empty transcription) to 1.0 (perfect confidence)
-    private func calculateConfidence(
-        duration: Double, tokenCount: Int, isEmpty: Bool, tokenConfidences: [Float]
+    nonisolated private func calculateConfidence(
+        tokenCount: Int, isEmpty: Bool, tokenConfidences: [Float]
     ) -> Float {
         // Empty transcription gets low confidence
         if isEmpty {
@@ -401,27 +347,25 @@ extension AsrManager {
|
||||
// Sort by timestamp to ensure chronological order
|
||||
let sortedData = combinedData.sorted { $0.timestamp < $1.timestamp }
|
||||
|
||||
let frameDuration = ASRConstants.secondsPerEncoderFrame
|
||||
|
||||
for i in 0..<sortedData.count {
|
||||
let data = sortedData[i]
|
||||
let tokenId = data.tokenId
|
||||
let frameIndex = data.timestamp
|
||||
|
||||
// Convert encoder frame index to time (80ms per frame)
|
||||
let startTime = TimeInterval(frameIndex) * 0.08
|
||||
let startTime = TimeInterval(frameIndex) * frameDuration
|
||||
|
||||
// Calculate end time using actual token duration if available
|
||||
let endTime: TimeInterval
|
||||
if !tokenDurations.isEmpty && data.duration > 0 {
|
||||
// Use actual token duration (convert frames to time: duration * 0.08)
|
||||
let durationInSeconds = TimeInterval(data.duration) * 0.08
|
||||
endTime = startTime + max(durationInSeconds, 0.08) // Minimum 80ms duration
|
||||
let durationInSeconds = TimeInterval(data.duration) * frameDuration
|
||||
endTime = startTime + max(durationInSeconds, frameDuration)
|
||||
} else if i < sortedData.count - 1 {
|
||||
// Fallback: Use next token's start time as this token's end time
|
||||
let nextStartTime = TimeInterval(sortedData[i + 1].timestamp) * 0.08
|
||||
endTime = max(nextStartTime, startTime + 0.08) // Ensure end > start
|
||||
let nextStartTime = TimeInterval(sortedData[i + 1].timestamp) * frameDuration
|
||||
endTime = max(nextStartTime, startTime + frameDuration)
|
||||
} else {
|
||||
// Last token: assume minimum duration
|
||||
endTime = startTime + 0.08
|
||||
endTime = startTime + frameDuration
|
||||
}
|
||||
|
||||
// Validate that end time is after start time
|
||||
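The timing hunk above converts encoder frame indices to seconds with a single `frameDuration` constant and then picks an end time by one of three rules. A minimal sketch of those rules follows, assuming `ASRConstants.secondsPerEncoderFrame` is 0.08 (the value of the literal it replaces). `TokenSpan` and `tokenSpans` are hypothetical names, and the real `createTokenTimings` also carries token IDs and confidences.

```swift
import Foundation

// Assumed value of ASRConstants.secondsPerEncoderFrame (80 ms).
let assumedSecondsPerFrame: TimeInterval = 0.08

struct TokenSpan {
    let start: TimeInterval
    let end: TimeInterval
}

/// `frames` must be sorted ascending. `durations` is empty or parallel to `frames`.
func tokenSpans(frames: [Int], durations: [Int] = []) -> [TokenSpan] {
    var spans: [TokenSpan] = []
    for i in frames.indices {
        let start = TimeInterval(frames[i]) * assumedSecondsPerFrame
        let end: TimeInterval
        if i < durations.count && durations[i] > 0 {
            // Rule 1: use the decoder's duration, never shorter than one frame.
            let seconds = TimeInterval(durations[i]) * assumedSecondsPerFrame
            end = start + max(seconds, assumedSecondsPerFrame)
        } else if i < frames.count - 1 {
            // Rule 2: end where the next token starts, but keep end > start.
            let nextStart = TimeInterval(frames[i + 1]) * assumedSecondsPerFrame
            end = max(nextStart, start + assumedSecondsPerFrame)
        } else {
            // Rule 3: last token gets the minimum duration of one frame.
            end = start + assumedSecondsPerFrame
        }
        spans.append(TokenSpan(start: start, end: end))
    }
    return spans
}

let fallbackSpans = tokenSpans(frames: [0, 5, 5])
let durationSpans = tokenSpans(frames: [10], durations: [3])
```

Two tokens on the same frame (frames 5 and 5 here) still get a positive length because of the one-frame minimum in rule 2.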
@@ -447,43 +391,12 @@ extension AsrManager {
         return timings
     }

-    /// Slice encoder output to remove left context frames (following NeMo approach)
-    private func sliceEncoderOutput(
-        _ encoderOutput: MLMultiArray,
-        from startFrame: Int,
-        newLength: Int
-    ) throws -> MLMultiArray {
-        let shape = encoderOutput.shape
-        let batchSize = shape[0].intValue
-        let hiddenSize = shape[2].intValue
-
-        // Create new array with sliced dimensions
-        let slicedArray = try MLMultiArray(
-            shape: [batchSize, newLength, hiddenSize] as [NSNumber],
-            dataType: encoderOutput.dataType
-        )
-
-        // Copy data from startFrame onwards
-        let sourcePtr = encoderOutput.dataPointer.bindMemory(to: Float.self, capacity: encoderOutput.count)
-        let destPtr = slicedArray.dataPointer.bindMemory(to: Float.self, capacity: slicedArray.count)
-
-        for t in 0..<newLength {
-            for h in 0..<hiddenSize {
-                let sourceIndex = (startFrame + t) * hiddenSize + h
-                let destIndex = t * hiddenSize + h
-                destPtr[destIndex] = sourcePtr[sourceIndex]
-            }
-        }
-
-        return slicedArray
-    }
-
     /// Remove duplicate token sequences at the start of the current list that overlap
     /// with the tail of the previous accumulated tokens. Returns deduplicated current tokens
     /// and the number of removed leading tokens so caller can drop aligned timestamps.
     /// Ideally this is not needed. We need to make some more fixes to the TDT decoding logic,
     /// this should be a temporary workaround.
-    internal func removeDuplicateTokenSequence(
+    nonisolated internal func removeDuplicateTokenSequence(
         previous: [Int], current: [Int], maxOverlap: Int = 12
     ) -> (deduped: [Int], removedCount: Int) {

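The body of `removeDuplicateTokenSequence` is not part of this excerpt, so the following is only a simplified sketch of the idea its doc comment describes: find the longest run at the start of `current` that repeats the tail of `previous`, capped at `maxOverlap`, and drop it. `removeLeadingOverlap` is a hypothetical name, and the real function may apply additional matching rules.

```swift
/// Drop the longest prefix of `current` that equals a suffix of `previous`.
/// Returns the remaining tokens and how many leading tokens were removed,
/// so the caller can drop the same number of timestamps.
func removeLeadingOverlap(
    previous: [Int], current: [Int], maxOverlap: Int = 12
) -> (deduped: [Int], removedCount: Int) {
    let limit = min(maxOverlap, previous.count, current.count)
    guard limit > 0 else { return (current, 0) }
    // Try the longest candidate first so a short accidental match cannot win.
    for length in stride(from: limit, through: 1, by: -1) {
        if previous.suffix(length).elementsEqual(current.prefix(length)) {
            return (Array(current.dropFirst(length)), length)
        }
    }
    return (current, 0)
}

let overlapping = removeLeadingOverlap(previous: [7, 8, 9, 10], current: [9, 10, 11, 12])
let disjoint = removeLeadingOverlap(previous: [1, 2], current: [3, 4])
let firstChunk = removeLeadingOverlap(previous: [], current: [3, 4])
```

The returned `removedCount` matters as much as the token list: timestamps and confidences are parallel arrays and have to be trimmed by the same amount.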
@@ -551,139 +464,4 @@ extension AsrManager {
         return (workingCurrent, removedCount)
     }

-    /// Calculate start frame offset for a sliding window segment (deprecated - now handled by timeJump)
-    nonisolated internal func calculateStartFrameOffset(segmentIndex: Int, leftContextSeconds: Double) -> Int {
-        // This method is deprecated as frame tracking is now handled by the decoder's timeJump mechanism
-        // Kept for test compatibility
-        return 0
-    }
-
-    // MARK: - Vocabulary Rescoring
-
-    /// Apply vocabulary rescoring to an ASRResult using CTC-based constrained decoding.
-    ///
-    /// Runs CTC inference on the audio samples and applies vocabulary rescoring to correct
-    /// misrecognized words. Returns an updated ASRResult with rescored text and populated
-    /// `ctcDetectedTerms`/`ctcAppliedTerms` fields.
-    ///
-    /// - Parameters:
-    ///   - result: The original ASRResult from transcription
-    ///   - audioSamples: Audio samples used for CTC inference
-    /// - Returns: An ASRResult with rescored text and CTC metadata, or the original result if rescoring was skipped
-    internal func applyVocabularyRescoring(
-        result: ASRResult, audioSamples: [Float]
-    ) async -> ASRResult {
-        guard let rescorer = vocabularyRescorer,
-            let vocab = customVocabulary,
-            let tokenTimings = result.tokenTimings, !tokenTimings.isEmpty
-        else {
-            return result
-        }
-
-        do {
-            // Try to use cached CTC logits from unified Preprocessor first
-            let logProbs: [[Float]]
-            let frameDuration: Double
-
-            if let cached = cachedCtcLogits, let duration = cachedCtcFrameDuration {
-                // Convert MLMultiArray to [[Float]]
-                logProbs = convertCtcLogitsToArray(cached)
-                frameDuration = duration
-                logger.debug("Using cached CTC logits from Preprocessor (unified model)")
-            } else if let spotter = ctcSpotter {
-                // Fallback: run separate CTC encoder
-                let spotResult = try await spotter.spotKeywordsWithLogProbs(
-                    audioSamples: audioSamples,
-                    customVocabulary: vocab,
-                    minScore: nil
-                )
-                logProbs = spotResult.logProbs
-                frameDuration = spotResult.frameDuration
-                logger.debug("Using separate CTC encoder (legacy dual-model approach)")
-            } else {
-                logger.warning("Vocabulary rescoring skipped: no CTC logits available")
-                return result
-            }
-
-            guard !logProbs.isEmpty else {
-                logger.debug("Vocabulary rescoring skipped: no log probs from CTC")
-                return result
-            }
-
-            let vocabConfig = vocabSizeConfig ?? ContextBiasingConstants.rescorerConfig(forVocabSize: 0)
-            // Use the higher of the size-based default and the caller-specified threshold
-            // so that CustomVocabularyContext.minSimilarity is respected when stricter.
-            let effectiveMinSimilarity = max(vocabConfig.minSimilarity, vocab.minSimilarity)
-
-            let rescoreOutput = rescorer.ctcTokenRescore(
-                transcript: result.text,
-                tokenTimings: tokenTimings,
-                logProbs: logProbs,
-                frameDuration: frameDuration,
-                cbw: vocabConfig.cbw,
-                marginSeconds: 0.5,
-                minSimilarity: effectiveMinSimilarity
-            )
-
-            guard rescoreOutput.wasModified else {
-                return result
-            }
-
-            let detected = rescoreOutput.replacements.compactMap { $0.replacementWord }
-            let applied = rescoreOutput.replacements.filter { $0.shouldReplace }.compactMap {
-                $0.replacementWord
-            }
-
-            logger.info(
-                "Vocabulary rescoring applied \(applied.count) replacement(s)"
-            )
-
-            return result.withRescoring(
-                text: rescoreOutput.text,
-                detected: detected.isEmpty ? nil : detected,
-                applied: applied.isEmpty ? nil : applied
-            )
-        } catch {
-            logger.warning("Vocabulary rescoring failed: \(error.localizedDescription)")
-            return result
-        }
-    }
-
-    /// Convert CTC logits MLMultiArray to log-probabilities [[Float]] for rescoring.
-    /// Applies log-softmax with temperature scaling and blank bias to match
-    /// the processing done in `CtcKeywordSpotter.computeLogProbs`.
-    private func convertCtcLogitsToArray(_ ctcLogits: MLMultiArray) -> [[Float]] {
-        // Expected shape: [1, T, V] where T = frames, V = vocab size
-        let shape = ctcLogits.shape
-        guard shape.count == 3 else {
-            logger.warning("Unexpected CTC logits shape: \(shape)")
-            return []
-        }
-
-        let numFrames = min(shape[1].intValue, cachedCtcValidFrames ?? shape[1].intValue)
-        let vocabSize = shape[2].intValue
-
-        // Extract raw logits
-        var rawLogits: [[Float]] = []
-        rawLogits.reserveCapacity(numFrames)
-
-        for t in 0..<numFrames {
-            var frameLogits: [Float] = []
-            frameLogits.reserveCapacity(vocabSize)
-
-            for v in 0..<vocabSize {
-                let index = [0, t, v] as [NSNumber]
-                frameLogits.append(ctcLogits[index].floatValue)
-            }
-
-            rawLogits.append(frameLogits)
-        }
-
-        // Apply log-softmax + temperature + blank bias (same as CtcKeywordSpotter.makeLogProbs)
-        return CtcKeywordSpotter.applyLogSoftmax(
-            rawLogits: rawLogits,
-            blankId: ContextBiasingConstants.defaultBlankId
-        )
-    }
-
 }
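The removed `convertCtcLogitsToArray` hands its per-frame logits to `CtcKeywordSpotter.applyLogSoftmax`, whose implementation is outside this diff. For readers unfamiliar with that step, here is a minimal sketch of a plain, numerically stable log-softmax over one frame. It deliberately omits the temperature scaling and blank bias that the doc comment mentions, since their values and exact form are not shown here; `logSoftmax` is a hypothetical helper.

```swift
import Foundation

/// Numerically stable log-softmax: x_i - (max + log(sum(exp(x_j - max)))).
/// Subtracting the maximum first keeps exp() from overflowing.
func logSoftmax(_ logits: [Float]) -> [Float] {
    guard let maxLogit = logits.max() else { return [] }
    let sumExp = logits.reduce(Float(0)) { $0 + exp($1 - maxLogit) }
    let logSumExp = maxLogit + log(sumExp)
    return logits.map { $0 - logSumExp }
}

// Two toy frames: shape [T = 2, V = 2 or 3], standing in for [1, T, V] logits.
let toyFrames: [[Float]] = [[0, 0], [2, 1, 0]]
let toyLogProbs = toyFrames.map(logSoftmax)
// Exponentiating a frame's log-probabilities should give values that sum to 1.
let secondFrameTotal = toyLogProbs[1].reduce(Float(0)) { $0 + exp($1) }
```

Working in log space lets a rescorer add scores across frames instead of multiplying small probabilities.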
@@ -1,9 +1,7 @@
-import CoreML
 import Foundation
-import OSLog

 struct ChunkProcessor {
-    let sampleSource: StreamingAudioSampleSource
+    let sampleSource: AudioSampleSource
     let totalSamples: Int

     private let logger = AppLogger(category: "ChunkProcessor")
@@ -18,7 +16,6 @@ struct ChunkProcessor {
     // Stateless chunking aligned with CoreML reference:
     // - process ~14.96s of audio per window (frame-aligned) to stay under encoder limit
     // - 2.0s overlap (frame-aligned) to give the decoder slack when merging windows
-    private let sampleRate: Int = 16000
     private let overlapSeconds: Double = 2.0

     /// Context samples prepended from previous chunk for mel spectrogram stability (80ms = 1 encoder frame).
@@ -36,7 +33,7 @@ struct ChunkProcessor {
         return raw / ASRConstants.samplesPerEncoderFrame * ASRConstants.samplesPerEncoderFrame
     }
     private var overlapSamples: Int {
-        let requested = Int(overlapSeconds * Double(sampleRate))
+        let requested = Int(overlapSeconds * Double(ASRConstants.sampleRate))
         let capped = min(requested, chunkSamples / 2)
         return capped / ASRConstants.samplesPerEncoderFrame * ASRConstants.samplesPerEncoderFrame
     }
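`overlapSamples` converts the requested overlap to samples, caps it at half a chunk, and rounds down to a whole number of encoder frames. The sketch below works through that arithmetic as a free function. The constants are assumptions (16 kHz sample rate, 1,280 samples per 80 ms frame), and the 239,360-sample chunk is the frame-aligned equivalent of the "~14.96s" window mentioned in the comments (187 frames).

```swift
// Assumed values; the real ones come from ASRConstants.
let assumedSampleRate = 16_000
let assumedFrameSamples = 1_280

/// Overlap in samples: requested seconds, capped at half a chunk,
/// rounded down to a whole number of encoder frames.
func overlapSampleCount(overlapSeconds: Double, chunkSamples: Int) -> Int {
    let requested = Int(overlapSeconds * Double(assumedSampleRate))
    let capped = min(requested, chunkSamples / 2)
    // Integer division followed by multiplication floors to a frame multiple.
    return capped / assumedFrameSamples * assumedFrameSamples
}

let defaultOverlap = overlapSampleCount(overlapSeconds: 2.0, chunkSamples: 239_360)
let cappedOverlap = overlapSampleCount(overlapSeconds: 2.0, chunkSamples: 3_000)
```

With these assumptions a 2.0 s overlap is exactly 25 frames (32,000 samples), while a tiny 3,000-sample chunk caps the overlap at 1,500 samples and floors it to one frame.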
@@ -46,7 +43,7 @@ struct ChunkProcessor {
     }

     /// Initialize with a streaming audio sample source for memory-efficient processing.
-    init(sampleSource: StreamingAudioSampleSource) {
+    init(sampleSource: AudioSampleSource) {
         self.sampleSource = sampleSource
         self.totalSamples = sampleSource.sampleCount
     }
@@ -66,7 +63,7 @@ struct ChunkProcessor {
         var chunkStart = 0
         var chunkIndex = 0
         var chunkDecoderState = TdtDecoderState.make(
-            decoderLayers: await manager.getDecoderLayers()
+            decoderLayers: await manager.decoderLayerCount
         )

         while chunkStart < totalSamples {
@@ -219,7 +216,7 @@ struct ChunkProcessor {
         if left.isEmpty { return right }
         if right.isEmpty { return left }

-        let frameDuration = Double(ASRConstants.samplesPerEncoderFrame) / Double(sampleRate)
+        let frameDuration = ASRConstants.secondsPerEncoderFrame
         let overlapDuration = overlapSeconds
         let halfOverlapWindow = overlapDuration / 2

@@ -433,7 +430,7 @@ struct ChunkProcessor {
         frameDuration: Double
     ) -> [TokenWindow] {
         let cutoff = (leftEndTime + rightStartTime) / 2
-        let trimmedLeft = left.filter { Double($0.timestamp) * frameDuration <= cutoff }
+        let trimmedLeft = left.filter { Double($0.timestamp) * frameDuration < cutoff }
         let trimmedRight = right.filter { Double($0.timestamp) * frameDuration >= cutoff }
         return trimmedLeft + trimmedRight
     }
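The one-character change from `<=` to `<` makes the two filters a strict partition: a token exactly at the cutoff now belongs only to the right chunk. The sketch below reproduces both variants so the duplicate can be observed. `MergeToken` and the `inclusiveLeft` switch are hypothetical additions for the demonstration; the real `TokenWindow` type and the surrounding merge logic are not shown in this diff.

```swift
// Hypothetical token record; `timestamp` is an encoder frame index.
struct MergeToken {
    let id: Int
    let timestamp: Int
}

/// Keep left tokens before the midpoint and right tokens at or after it.
/// `inclusiveLeft: true` reproduces the old `<=` comparison for comparison only.
func mergeAtMidpoint(
    left: [MergeToken], right: [MergeToken],
    leftEndTime: Double, rightStartTime: Double,
    frameDuration: Double, inclusiveLeft: Bool = false
) -> [MergeToken] {
    let cutoff = (leftEndTime + rightStartTime) / 2
    let trimmedLeft = left.filter {
        let time = Double($0.timestamp) * frameDuration
        return inclusiveLeft ? time <= cutoff : time < cutoff
    }
    let trimmedRight = right.filter { Double($0.timestamp) * frameDuration >= cutoff }
    return trimmedLeft + trimmedRight
}

let frameSeconds = 0.08
// Both chunks decoded the same token (id 3) at frame 14, exactly on the cutoff.
let boundaryTime = Double(14) * frameSeconds
let leftChunk = [MergeToken(id: 1, timestamp: 10), MergeToken(id: 2, timestamp: 12), MergeToken(id: 3, timestamp: 14)]
let rightChunk = [MergeToken(id: 3, timestamp: 14), MergeToken(id: 4, timestamp: 16)]

let fixedMerge = mergeAtMidpoint(
    left: leftChunk, right: rightChunk,
    leftEndTime: boundaryTime, rightStartTime: boundaryTime, frameDuration: frameSeconds)
let oldMerge = mergeAtMidpoint(
    left: leftChunk, right: rightChunk,
    leftEndTime: boundaryTime, rightStartTime: boundaryTime,
    frameDuration: frameSeconds, inclusiveLeft: true)
```

The cutoff is computed from the same `Double(14) * frameSeconds` expression as the token times, so the equality at the boundary is exact and the example does not depend on floating-point rounding.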
@@ -32,11 +32,11 @@ struct TdtDecoderState: Sendable {
     init(decoderLayers: Int = 2) throws {
         // Use ANE-aligned arrays for optimal performance
         let decoderHiddenSize = ASRConstants.decoderHiddenSize
-        hiddenState = try ANEOptimizer.createANEAlignedArray(
+        hiddenState = try ANEMemoryUtils.createAlignedArray(
             shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
             dataType: .float32
         )
-        cellState = try ANEOptimizer.createANEAlignedArray(
+        cellState = try ANEMemoryUtils.createAlignedArray(
             shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
             dataType: .float32
         )
@@ -175,11 +175,11 @@ internal struct TdtDecoderV3 {
         // Preallocate joint input tensors and a reusable provider to avoid per-step allocations.
         let encoderHidden = expectedEncoderHidden
         let decoderHidden = ASRConstants.decoderHiddenSize
-        let reusableEncoderStep = try ANEOptimizer.createANEAlignedArray(
+        let reusableEncoderStep = try ANEMemoryUtils.createAlignedArray(
             shape: [1, NSNumber(value: encoderHidden), 1],
             dataType: .float32
         )
-        let reusableDecoderStep = try ANEOptimizer.createANEAlignedArray(
+        let reusableDecoderStep = try ANEMemoryUtils.createAlignedArray(
             shape: [1, NSNumber(value: decoderHidden), 1],
             dataType: .float32
         )
@@ -617,8 +617,8 @@ internal struct TdtDecoderV3 {
             try encoderFrames.copyFrame(at: timeIndex, into: encoderDestPtr, destinationStride: encoderDestStride)

             // Prefetch arrays for ANE
-            ANEOptimizer.prefetchToNeuralEngine(encoderStep)
-            ANEOptimizer.prefetchToNeuralEngine(preparedDecoderStep)
+            encoderStep.prefetchToNeuralEngine()
+            preparedDecoderStep.prefetchToNeuralEngine()

             // Reuse tiny output tensors for joint prediction (provide raw MLMultiArray backings)
             predictionOptions.outputBackings = [
@@ -702,7 +702,7 @@ internal struct TdtDecoderV3 {
             }
             out = destination
         } else {
-            out = try ANEOptimizer.createANEAlignedArray(
+            out = try ANEMemoryUtils.createAlignedArray(
                 shape: [1, NSNumber(value: hiddenSize), 1],
                 dataType: .float32
             )
@@ -829,7 +829,7 @@ internal struct TdtDecoderV3 {
             encoderOutput: encoderOutput,
             validLength: encoderOutput.count,
             expectedHiddenSize: config.encoderHiddenSize)
-        let encoderStep = try ANEOptimizer.createANEAlignedArray(
+        let encoderStep = try ANEMemoryUtils.createAlignedArray(
             shape: [1, NSNumber(value: encoderFrames.hiddenSize), 1],
             dataType: .float32)
         let encoderPtr = encoderStep.dataPointer.bindMemory(to: Float.self, capacity: encoderFrames.hiddenSize)
@@ -1,155 +0,0 @@
-import Foundation
-import MachTaskSelfWrapper
-import os
-
-/// Performance metrics for ASR processing
-public struct ASRPerformanceMetrics: Codable, Sendable {
-    public let preprocessorTime: TimeInterval
-    public let encoderTime: TimeInterval
-    public let decoderTime: TimeInterval
-    public let totalProcessingTime: TimeInterval
-    public let rtfx: Float  // Real-time factor
-    public let peakMemoryMB: Float
-    public let gpuUtilization: Float?
-
-    public var summary: String {
-        """
-        Performance Metrics:
-        - Preprocessor: \(String(format: "%.3f", preprocessorTime))s
-        - Encoder: \(String(format: "%.3f", encoderTime))s
-        - Decoder: \(String(format: "%.3f", decoderTime))s
-        - Total: \(String(format: "%.3f", totalProcessingTime))s
-        - RTFx: \(String(format: "%.1f", rtfx))x real-time
-        - Peak Memory: \(String(format: "%.1f", peakMemoryMB)) MB
-        - GPU Utilization: \(gpuUtilization.map { String(format: "%.1f%%", $0) } ?? "N/A")
-        """
-    }
-}
-
-/// Performance monitor for tracking ASR metrics
-public actor PerformanceMonitor {
-
-    public init() {}
-    private let logger = AppLogger(category: "Performance")
-    private var metrics: [ASRPerformanceMetrics] = []
-    private let signpostLogger = OSSignposter(subsystem: AppLogger.defaultSubsystem, category: "Performance")
-
-    /// Track performance for a processing session
-    public func trackSession<T: Sendable>(
-        operation: String,
-        audioLengthSeconds: Float,
-        block: @escaping () async throws -> T
-    ) async throws -> (result: T, metrics: ASRPerformanceMetrics) {
-        let sessionID = signpostLogger.makeSignpostID()
-        let state = signpostLogger.beginInterval("ASR.Operation", id: sessionID)
-
-        let startTime = Date()
-        let startMemory = getCurrentMemoryUsage()
-
-        // Track individual components
-        let preprocessorTime: TimeInterval = 0
-        let encoderTime: TimeInterval = 0
-        let decoderTime: TimeInterval = 0
-
-        // Execute the operation
-        let result = try await block()
-
-        let totalTime = Date().timeIntervalSince(startTime)
-        let peakMemory = max(startMemory, getCurrentMemoryUsage())
-        let rtfx = audioLengthSeconds / Float(totalTime)
-
-        signpostLogger.endInterval("ASR.Operation", state)
-
-        let metrics = ASRPerformanceMetrics(
-            preprocessorTime: preprocessorTime,
-            encoderTime: encoderTime,
-            decoderTime: decoderTime,
-            totalProcessingTime: totalTime,
-            rtfx: rtfx,
-            peakMemoryMB: peakMemory,
-            gpuUtilization: nil  // Would require Metal performance counters
-        )
-
-        self.metrics.append(metrics)
-        logger.info("\(operation) completed: \(metrics.summary)")
-
-        return (result, metrics)
-    }
-
-    /// Track a specific component's execution time
-    public func trackComponent<T: Sendable>(
-        _ component: String,
-        block: @escaping () async throws -> T
-    ) async throws -> (result: T, time: TimeInterval) {
-        let componentID = signpostLogger.makeSignpostID()
-        let state = signpostLogger.beginInterval("ASR.Component", id: componentID)
-
-        let startTime = Date()
-        let result = try await block()
-        let time = Date().timeIntervalSince(startTime)
-
-        signpostLogger.endInterval("ASR.Component", state)
-
-        return (result, time)
-    }
-
-    /// Get aggregated metrics
-    public func getAggregatedMetrics() -> AggregatedMetrics? {
-        guard !metrics.isEmpty else { return nil }
-
-        let avgRTFx = metrics.map { $0.rtfx }.reduce(0, +) / Float(metrics.count)
-        let avgProcessingTime = metrics.map { $0.totalProcessingTime }.reduce(0, +) / Double(metrics.count)
-        let maxMemory = metrics.map { $0.peakMemoryMB }.max() ?? 0
-
-        return AggregatedMetrics(
-            averageRTFx: avgRTFx,
-            averageProcessingTime: avgProcessingTime,
-            maxMemoryMB: maxMemory,
-            sampleCount: metrics.count
-        )
-    }
-
-    /// Clear all stored metrics
-    public func reset() {
-        metrics.removeAll()
-    }
-
-    /// Get current memory usage in MB
-    private func getCurrentMemoryUsage() -> Float {
-        var info = mach_task_basic_info()
-        var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4
-
-        let result = withUnsafeMutablePointer(to: &info) {
-            $0.withMemoryRebound(to: integer_t.self, capacity: 1) {
-                task_info(
-                    get_current_task_port(),
-                    task_flavor_t(MACH_TASK_BASIC_INFO),
-                    $0,
-                    &count)
-            }
-        }
-
-        if result == KERN_SUCCESS {
-            return Float(info.resident_size) / 1024.0 / 1024.0
-        }
-
-        return 0
-    }
-}
-
-/// Aggregated performance metrics
-public struct AggregatedMetrics: Sendable {
-    public let averageRTFx: Float
-    public let averageProcessingTime: TimeInterval
-    public let maxMemoryMB: Float
-    public let sampleCount: Int
-
-    public var summary: String {
-        """
-        Aggregated Metrics (\(sampleCount) samples):
-        - Average RTFx: \(String(format: "%.1f", averageRTFx))x real-time
-        - Average Processing Time: \(String(format: "%.3f", averageProcessingTime))s
-        - Max Memory Usage: \(String(format: "%.1f", maxMemoryMB)) MB
-        """
-    }
-}
+1
-4
@@ -248,10 +248,7 @@ extension CtcModels {
 
     /// Default CoreML configuration for CTC inference.
     public static func defaultConfiguration() -> MLModelConfiguration {
-        let config = MLModelConfiguration()
-        config.allowLowPrecisionAccumulationOnGPU = true
-        config.computeUnits = .cpuAndNeuralEngine
-        return config
+        MLModelConfigurationUtils.defaultConfiguration(computeUnits: .cpuAndNeuralEngine)
     }
 
     /// Check whether required CTC model bundles and vocabulary exist at a directory.
@@ -48,11 +48,9 @@ public actor SlidingWindowAsrManager {
 
     // Vocabulary boosting
     // These are initialized via configureVocabularyBoosting() before start()
-    // CtcKeywordSpotter and VocabularyRescorer contain CoreML models which are not Sendable.
-    // We manage the safety ourselves by only accessing them from within the actor.
     private var customVocabulary: CustomVocabularyContext?
-    nonisolated(unsafe) private var ctcSpotter: CtcKeywordSpotter?
-    nonisolated(unsafe) private var vocabularyRescorer: VocabularyRescorer?
+    private var ctcSpotter: CtcKeywordSpotter?
+    private var vocabularyRescorer: VocabularyRescorer?
     private var vocabSizeConfig: ContextBiasingConstants.VocabSizeConfig?
     private var vocabBoostingEnabled: Bool { customVocabulary != nil && vocabularyRescorer != nil }
 
@@ -376,7 +374,7 @@ public actor SlidingWindowAsrManager {
         // Start frame offset is now handled by decoder's timeJump mechanism
 
         // Call AsrManager directly with deduplication
-        let (tokens, timestamps, confidences, _) = try await asrManager.transcribeStreamingChunk(
+        let (tokens, timestamps, confidences, _) = try await asrManager.transcribeChunk(
             windowSamples,
             source: audioSource,
             previousTokens: accumulatedTokens,
@@ -89,7 +89,10 @@ public final class RnntDecoder {
             let decoderInput = try prepareDecoderInput(lastToken: lastToken, h: hState, c: cState)
             let decoderOutput = try decoderModel.prediction(from: decoderInput)
 
-            var decoderStep = decoderOutput.featureValue(for: "decoder")!.multiArrayValue!
+            guard let decoderArray = decoderOutput.featureValue(for: "decoder")?.multiArrayValue else {
+                throw RnntDecoderError.missingOutput("decoder")
+            }
+            var decoderStep = decoderArray
             // Decoder outputs [1, 640, 2] - NeMo uses the LAST frame
             if decoderStep.shape.count == 3 && decoderStep.shape[2].intValue > 1 {
                 // Slice to keep only the last frame [1, 640, 1]
@@ -106,7 +109,9 @@ public final class RnntDecoder {
 
             // 3. Get Token ID
             // Output "token_id" is [1, 1, 1] (argmax)
-            let tokenIdMultiArray = jointOutput.featureValue(for: "token_id")!.multiArrayValue!
+            guard let tokenIdMultiArray = jointOutput.featureValue(for: "token_id")?.multiArrayValue else {
+                throw RnntDecoderError.missingOutput("token_id")
+            }
             let tokenId = tokenIdMultiArray[0].int32Value
 
             if tokenId == blankId {
@@ -120,8 +125,12 @@ public final class RnntDecoder {
                 lastToken = tokenId
 
                 // Update State
-                let newH = decoderOutput.featureValue(for: "h_out")!.multiArrayValue!
-                let newC = decoderOutput.featureValue(for: "c_out")!.multiArrayValue!
+                guard let newH = decoderOutput.featureValue(for: "h_out")?.multiArrayValue else {
+                    throw RnntDecoderError.missingOutput("h_out")
+                }
+                guard let newC = decoderOutput.featureValue(for: "c_out")?.multiArrayValue else {
+                    throw RnntDecoderError.missingOutput("c_out")
+                }
 
                 hState = newH
                 cState = newC
@@ -222,3 +231,14 @@ public final class RnntDecoder {
     }
 
 }
+
+enum RnntDecoderError: Error, LocalizedError {
+    case missingOutput(String)
+
+    var errorDescription: String? {
+        switch self {
+        case .missingOutput(let name):
+            return "RNNT decoder missing expected output: \(name)"
+        }
+    }
+}
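The `RnntDecoder` hunks above all apply one pattern: look up a named model output and throw a typed error when it is absent, instead of force-unwrapping. A minimal sketch of that pattern follows. It assumes a plain dictionary as a stand-in for the CoreML feature provider; `SketchDecoderError` and `requireOutput` are illustrative names, not part of the PR.

```swift
import Foundation

// Illustrative error type with the same shape as RnntDecoderError in the diff.
enum SketchDecoderError: Error, LocalizedError, Equatable {
    case missingOutput(String)

    var errorDescription: String? {
        switch self {
        case .missingOutput(let name):
            return "decoder missing expected output: \(name)"
        }
    }
}

// Hypothetical helper: returns the named output or throws instead of crashing.
// A [String: [Float]] dictionary stands in for an MLFeatureProvider (assumption).
func requireOutput(_ name: String, in outputs: [String: [Float]]) throws -> [Float] {
    guard let value = outputs[name] else {
        throw SketchDecoderError.missingOutput(name)
    }
    return value
}

let outputs: [String: [Float]] = ["decoder": [0.1, 0.2], "h_out": [0.0]]

// Present output: returned as-is.
let decoder = try requireOutput("decoder", in: outputs)

// Absent output: surfaces as a catchable, descriptive error.
var caught: SketchDecoderError?
do {
    _ = try requireOutput("c_out", in: outputs)
} catch let error as SketchDecoderError {
    caught = error
}
```

With this shape, a model bundle that lacks an expected output produces an error the caller can log or recover from, where a force unwrap would terminate the process.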
@@ -98,21 +98,12 @@ extension DiarizerModels {
     }
 
     public static func defaultModelsDirectory() -> URL {
-        let applicationSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
-        return
-            applicationSupport
-            .appendingPathComponent("FluidAudio", isDirectory: true)
-            .appendingPathComponent("Models", isDirectory: true)
-            .appendingPathComponent(Repo.diarizer.folderName, isDirectory: true)
+        MLModelConfigurationUtils.defaultModelsDirectory(for: .diarizer)
     }
 
     static func defaultConfiguration() -> MLModelConfiguration {
-        let config = MLModelConfiguration()
-        // Enable Float16 optimization for ~2x speedup
-        config.allowLowPrecisionAccumulationOnGPU = true
         let isCI = ProcessInfo.processInfo.environment["CI"] != nil
-        config.computeUnits = isCI ? .cpuAndNeuralEngine : .all
-        return config
+        return MLModelConfigurationUtils.defaultConfiguration(computeUnits: isCI ? .cpuAndNeuralEngine : .all)
     }
 }
 
@@ -100,7 +100,7 @@ public final class OfflineDiarizerManager {
     /// - Parameter url: Path to the audio file
     /// - Returns: Diarization result with speaker segments
     public func process(_ url: URL) async throws -> DiarizationResult {
-        let factory = StreamingAudioSourceFactory()
+        let factory = AudioSourceFactory()
         let (source, loadDuration) = try factory.makeDiskBackedSource(
             from: url,
             targetSampleRate: config.segmentation.sampleRate
@@ -114,7 +114,7 @@ public final class OfflineDiarizerManager {
     }
 
     public func process(
-        audioSource: StreamingAudioSampleSource,
+        audioSource: AudioSampleSource,
         audioLoadingSeconds: TimeInterval
     ) async throws -> DiarizationResult {
         try config.validate()
@@ -68,18 +68,11 @@ public struct OfflineDiarizerModels: Sendable {
     }
 
     public static func defaultModelsDirectory() -> URL {
-        let base = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
-        return
-            base
-            .appendingPathComponent("FluidAudio", isDirectory: true)
-            .appendingPathComponent("Models", isDirectory: true)
+        MLModelConfigurationUtils.defaultModelsDirectory()
     }
 
     private static func defaultConfiguration() -> MLModelConfiguration {
-        let configuration = MLModelConfiguration()
-        configuration.allowLowPrecisionAccumulationOnGPU = true
-        configuration.computeUnits = .all
-        return configuration
+        MLModelConfigurationUtils.defaultConfiguration(computeUnits: .all)
     }
 
     public static func load(
@@ -179,7 +179,7 @@ struct OfflineEmbeddingExtractor {
     }
 
     func extractEmbeddings(
-        audioSource: StreamingAudioSampleSource,
+        audioSource: AudioSampleSource,
         segmentation: SegmentationOutput
     ) async throws -> [TimedEmbedding] {
         let stream = AsyncThrowingStream<SegmentationChunk, Error> { continuation in
@@ -221,7 +221,7 @@ struct OfflineEmbeddingExtractor {
     }
 
     func extractEmbeddings<S: AsyncSequence>(
-        audioSource: StreamingAudioSampleSource,
+        audioSource: AudioSampleSource,
         segmentationStream: S
     ) async throws -> [TimedEmbedding] where S.Element == SegmentationChunk {
         var embeddings: [TimedEmbedding] = []
@@ -42,7 +42,7 @@ struct OfflineSegmentationProcessor {
     }
 
     func process(
-        audioSource: StreamingAudioSampleSource,
+        audioSource: AudioSampleSource,
         segmentationModel: MLModel,
         config: OfflineDiarizerConfig,
         chunkHandler: SegmentationChunkHandler? = nil
@@ -91,11 +91,8 @@ extension SortformerModels {
 
     /// Default MLModel configuration
     public static func defaultConfiguration() -> MLModelConfiguration {
-        let config = MLModelConfiguration()
-        config.allowLowPrecisionAccumulationOnGPU = true
         let isCI = ProcessInfo.processInfo.environment["CI"] != nil
-        config.computeUnits = isCI ? .cpuAndNeuralEngine : .all
-        return config
+        return MLModelConfigurationUtils.defaultConfiguration(computeUnits: isCI ? .cpuAndNeuralEngine : .all)
     }
 
     /// Load Sortformer models from HuggingFace.
@@ -129,6 +129,12 @@ public enum Repo: String, CaseIterable {
             return "nemotron-streaming/560ms"
         case .sortformer:
             return "sortformer"
+        case .parakeetCtc110m:
+            return "parakeet-ctc-110m-coreml"
+        case .parakeetCtc06b:
+            return "parakeet-ctc-0.6b-coreml"
+        case .parakeetTdtCtc110m:
+            return "parakeet-tdt-ctc-110m"
         default:
             return name.replacingOccurrences(of: "-coreml", with: "")
         }
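The `folderName` fix above works because an explicit `case` takes precedence over the `default` branch that strips the `-coreml` suffix. The sketch below reproduces that behaviour on a trimmed-down enum. `SketchRepo`, its raw values, and the `name` helper are hypothetical stand-ins; the real `Repo` raw values are not shown in this diff.

```swift
import Foundation

// Hypothetical, trimmed-down stand-in for Repo; raw values are illustrative only.
enum SketchRepo: String {
    case parakeetCtc110m = "example-org/parakeet-ctc-110m-coreml"
    case diarizer = "example-org/speaker-diarization-coreml"

    // Last path component of the raw value (assumed to mirror `name` in the diff).
    var name: String {
        rawValue.split(separator: "/").last.map { String($0) } ?? rawValue
    }

    var folderName: String {
        switch self {
        case .parakeetCtc110m:
            // Explicit case: the on-disk folder keeps its "-coreml" suffix.
            return "parakeet-ctc-110m-coreml"
        default:
            // Fallback strips the suffix; CTC repos fell through to here before the fix.
            return name.replacingOccurrences(of: "-coreml", with: "")
        }
    }
}

let ctcFolder = SketchRepo.parakeetCtc110m.folderName
let diarizerFolder = SketchRepo.diarizer.folderName
```

Removing the explicit case in this sketch makes `ctcFolder` come out as `parakeet-ctc-110m`, which is the mismatch the PR description reports.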
@@ -203,9 +209,6 @@ public enum ModelNames {
             jointFile,
         ]
 
-        /// Vocabulary filename for the 110m hybrid TDT-CTC model (JSON array format)
-        public static let vocabularyFileArray = "parakeet_vocab.json"
-
         /// Required models for fused frontend (110m hybrid: preprocessor contains encoder)
         public static let requiredModelsFused: Set<String> = [
             preprocessorFile,
@@ -215,12 +218,8 @@ public enum ModelNames {
 
         /// Get vocabulary filename for specific model version
         public static func vocabulary(for repo: Repo) -> String {
-            switch repo {
-            case .parakeetTdtCtc110m:
-                return vocabularyFileArray
-            default:
-                return vocabularyFile
-            }
+            // All Parakeet models use the same vocabulary file (format varies: dict for v2/v3, array for 110m)
+            return vocabularyFile
         }
     }
 
@@ -1,3 +1,4 @@
+import Accelerate
 import CoreML
 import Darwin
 import Foundation
@@ -143,6 +144,41 @@ public enum ANEMemoryUtils {
         )
     }
 
+    /// Convert a float32 MLMultiArray to float16 with ANE-aligned memory.
+    public static func convertToFloat16(_ input: MLMultiArray) throws -> MLMultiArray {
+        guard input.dataType == .float32 else {
+            throw ANEMemoryError.unsupportedDataType
+        }
+
+        let float16Array = try createAlignedArray(
+            shape: input.shape,
+            dataType: .float16,
+            zeroClear: false
+        )
+
+        let sourcePtr = input.dataPointer.bindMemory(to: Float.self, capacity: input.count)
+
+        var sourceBuffer = vImage_Buffer(
+            data: sourcePtr,
+            height: 1,
+            width: vImagePixelCount(input.count),
+            rowBytes: input.count * MemoryLayout<Float>.stride
+        )
+
+        let destPtr = float16Array.dataPointer.bindMemory(to: UInt16.self, capacity: input.count)
+
+        var destBuffer = vImage_Buffer(
+            data: destPtr,
+            height: 1,
+            width: vImagePixelCount(input.count),
+            rowBytes: input.count * MemoryLayout<UInt16>.stride
+        )
+
+        vImageConvert_PlanarFtoPlanar16F(&sourceBuffer, &destBuffer, 0)
+
+        return float16Array
+    }
+
     /// Stride-aware copy between two MLMultiArrays that may have different stride layouts.
     ///
     /// Copies all logical elements from `source` to `destination` (which must have the same shape
@@ -27,6 +27,9 @@ public enum ASRConstants {
     /// Each encoder frame represents ~80ms of audio at 16kHz
     public static let samplesPerEncoderFrame: Int = melHopSize * encoderSubsampling  // 1280
 
+    /// Duration of one encoder frame in seconds (80ms)
+    public static let secondsPerEncoderFrame: Double = Double(samplesPerEncoderFrame) / Double(sampleRate)  // 0.08
+
     /// WER threshold for detailed error analysis in benchmarks
     public static let highWERThreshold: Double = 0.15
 
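The new `secondsPerEncoderFrame` constant is derived rather than hard-coded, so the `0.08` in the comment follows from 1280 samples at 16 kHz. The arithmetic can be sketched as below. The hop size of 160 and subsampling factor of 8 are assumptions chosen to reproduce the 1280-sample frame quoted in the diff, and `timestamp(forFrame:)` is a hypothetical helper.

```swift
// Assumed values; only their product (1280) and the 16 kHz rate appear in the diff.
let sampleRate = 16_000        // Hz
let melHopSize = 160           // samples per mel hop (assumption)
let encoderSubsampling = 8     // encoder downsampling factor (assumption)

let samplesPerEncoderFrame = melHopSize * encoderSubsampling
let secondsPerEncoderFrame = Double(samplesPerEncoderFrame) / Double(sampleRate)

// Hypothetical helper: convert an encoder frame index into a timestamp in seconds.
func timestamp(forFrame frame: Int) -> Double {
    Double(frame) * secondsPerEncoderFrame
}

let quarterMinuteMark = timestamp(forFrame: 25)
```

Deriving the value keeps timestamps consistent if the hop size or subsampling factor ever changes, which a scattered literal `0.08` would not.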
+3
-3
@@ -1,6 +1,6 @@
 import Foundation
 
-public protocol StreamingAudioSampleSource: Sendable {
+public protocol AudioSampleSource: Sendable {
     var sampleCount: Int { get }
     func copySamples(
         into destination: UnsafeMutablePointer<Float>,
@@ -9,7 +9,7 @@ public protocol StreamingAudioSampleSource: Sendable {
     ) throws
 }
 
-public struct ArrayAudioSampleSource: StreamingAudioSampleSource {
+public struct ArrayAudioSampleSource: AudioSampleSource {
     private let samples: [Float]
 
     public init(samples: [Float]) {
@@ -39,7 +39,7 @@ public struct ArrayAudioSampleSource: StreamingAudioSampleSource {
     }
 }
 
-public struct DiskBackedAudioSampleSource: StreamingAudioSampleSource {
+public struct DiskBackedAudioSampleSource: AudioSampleSource {
     private let mappedData: Data
     private let floatStride = MemoryLayout<Float>.stride
     private let fileURL: URL
@@ -0,0 +1,6 @@
+import Foundation
+
+public enum AudioSource: Sendable {
+    case microphone
+    case system
+}
+14
-14
@@ -3,8 +3,8 @@ import Foundation
 import OSLog
 import os
 
-public struct StreamingAudioSourceFactory {
-    private let logger = AppLogger(category: "StreamingAudioSourceFactory")
+public struct AudioSourceFactory {
+    private let logger = AppLogger(category: "AudioSourceFactory")
 
     public init() {}
 
@@ -26,7 +26,7 @@ public struct StreamingAudioSourceFactory {
 
             let tempURL = try makeTemporaryURL()
             guard FileManager.default.createFile(atPath: tempURL.path, contents: nil) else {
-                throw StreamingAudioError.processingFailed("Failed to create temporary audio buffer at \(tempURL.path)")
+                throw AudioSourceError.processingFailed("Failed to create temporary audio buffer at \(tempURL.path)")
             }
 
             let handle = try FileHandle(forWritingTo: tempURL)
@@ -35,7 +35,7 @@ public struct StreamingAudioSourceFactory {
             }
 
             guard let converter = AVAudioConverter(from: inputFormat, to: targetFormat) else {
-                throw StreamingAudioError.processingFailed(
+                throw AudioSourceError.processingFailed(
                     "Unsupported audio format \(inputFormat); failed to create converter")
             }
 
@@ -75,11 +75,11 @@ public struct StreamingAudioSourceFactory {
 
             let duration = Date().timeIntervalSince(startTime)
             return (source, duration)
-        } catch let streamingError as StreamingAudioError {
+        } catch let streamingError as AudioSourceError {
             throw streamingError
         } catch {
             logger.error("Streaming audio source creation failed: \(error.localizedDescription)")
-            throw StreamingAudioError.processingFailed(
+            throw AudioSourceError.processingFailed(
                 "Streaming audio source creation failed: \(error.localizedDescription)"
             )
         }
@@ -106,7 +106,7 @@ public struct StreamingAudioSourceFactory {
                 frameCapacity: inputCapacity
             )
         else {
-            throw StreamingAudioError.failedToAllocateBuffer("Input", requestedFrames: Int(inputCapacity))
+            throw AudioSourceError.failedToAllocateBuffer("Input", requestedFrames: Int(inputCapacity))
         }
 
         let estimatedOutputFrames = AVAudioFrameCount(
@@ -118,7 +118,7 @@ public struct StreamingAudioSourceFactory {
                 frameCapacity: max(1024, estimatedOutputFrames)
             )
         else {
-            throw StreamingAudioError.failedToAllocateBuffer("Output", requestedFrames: Int(estimatedOutputFrames))
+            throw AudioSourceError.failedToAllocateBuffer("Output", requestedFrames: Int(estimatedOutputFrames))
         }
 
         var totalSamples = 0
@@ -167,13 +167,13 @@ public struct StreamingAudioSourceFactory {
             )
 
             if let conversionError {
-                throw StreamingAudioError.processingFailed(
+                throw AudioSourceError.processingFailed(
                     "Audio conversion failed: \(conversionError.localizedDescription)"
                 )
             }
 
             if let error = readError.withLock({ $0 }) {
-                throw StreamingAudioError.processingFailed(
+                throw AudioSourceError.processingFailed(
                     "Failed while reading audio: \(error.localizedDescription)"
                 )
             }
@@ -181,7 +181,7 @@ public struct StreamingAudioSourceFactory {
             let producedFrames = Int(outputBuffer.frameLength)
             if producedFrames > 0 {
                 guard let channelData = outputBuffer.floatChannelData?.pointee else {
-                    throw StreamingAudioError.processingFailed("Missing channel data during conversion")
+                    throw AudioSourceError.processingFailed("Missing channel data during conversion")
                 }
                 let byteCount = producedFrames * MemoryLayout<Float>.stride
                 let baseAddress = UnsafeRawPointer(channelData)
@@ -199,7 +199,7 @@ public struct StreamingAudioSourceFactory {
     }
 }
 
-public enum StreamingAudioError: Error, LocalizedError {
+public enum AudioSourceError: Error, LocalizedError {
     case processingFailed(String)
 
     public var errorDescription: String? {
@@ -210,8 +210,8 @@ public enum StreamingAudioError: Error, LocalizedError {
     }
 }
 
-extension StreamingAudioError {
-    fileprivate static func failedToAllocateBuffer(_ name: String, requestedFrames: Int) -> StreamingAudioError {
+extension AudioSourceError {
+    fileprivate static func failedToAllocateBuffer(_ name: String, requestedFrames: Int) -> AudioSourceError {
         .processingFailed("Failed to allocate \(name.lowercased()) buffer (\(requestedFrames) frames)")
     }
 }
+5
-28
@@ -1,12 +1,10 @@
 import CoreML
 import Foundation
-import os
 
 /// Thread-safe cache for MLMultiArray instances to reduce allocation overhead
 actor MLArrayCache {
     private var cache: [CacheKey: [MLMultiArray]] = [:]
     private let maxCacheSize: Int
-    private let logger = AppLogger(category: "MLArrayCache")
 
     struct CacheKey: Hashable {
         let shape: [Int]
@@ -24,16 +22,13 @@ actor MLArrayCache {
             dataType: dataType
         )
 
-        // Check if we have a cached array
         if var arrays = cache[key], !arrays.isEmpty {
-            // Never return the same buffer twice while it is still in use; keep the trimmed bucket so we only
-            // hand out arrays that callers have explicitly returned to the cache.
             let array = arrays.removeLast()
             cache[key] = arrays
             return array
         }
 
-        return try ANEOptimizer.createANEAlignedArray(shape: shape, dataType: dataType)
+        return try ANEMemoryUtils.createAlignedArray(shape: shape, dataType: dataType)
     }
 
     /// Return an array to the cache for reuse
@@ -47,53 +42,35 @@ actor MLArrayCache {
 
         // Limit cache size per key
         if arrays.count < maxCacheSize / max(cache.count, 1) {
-            // Reset the array data before caching
-            if array.dataType == .float32 {
-                array.resetData(to: 0)
-            }
+            array.resetData(to: 0)
             arrays.append(array)
             cache[key] = arrays
-            logger.debug("Returned array to cache for shape: \(array.shape)")
         }
     }
 
     /// Pre-warm the cache with commonly used shapes
-    func prewarm(shapes: [(shape: [NSNumber], dataType: MLMultiArrayDataType)]) async {
-        logger.info("Pre-warming cache with \(shapes.count) shapes")
-
+    func prewarm(shapes: [(shape: [NSNumber], dataType: MLMultiArrayDataType)]) {
         for (shape, dataType) in shapes {
             do {
                 var arrays: [MLMultiArray] = []
                 let prewarmCount = min(5, maxCacheSize / max(shapes.count, 1))
 
                 for _ in 0..<prewarmCount {
-                    let array = try ANEOptimizer.createANEAlignedArray(shape: shape, dataType: dataType)
+                    let array = try ANEMemoryUtils.createAlignedArray(shape: shape, dataType: dataType)
                     arrays.append(array)
                 }
 
                 let key = CacheKey(shape: shape.map { $0.intValue }, dataType: dataType)
                 cache[key] = arrays
             } catch {
-                logger.error("Failed to pre-warm shape \(shape): \(error)")
+                // Silently skip shapes that fail to allocate during pre-warm
             }
         }
     }
 
-    /// Get a Float16 array (converting from Float32 if needed)
-    func getFloat16Array(shape: [NSNumber], from float32Array: MLMultiArray? = nil) throws -> MLMultiArray {
-        if let float32Array = float32Array {
-            // Convert existing array to Float16
-            return try ANEOptimizer.convertToFloat16(float32Array)
-        } else {
-            // Get new Float16 array from cache
-            return try getArray(shape: shape, dataType: .float16)
-        }
-    }
-
     /// Clear the cache
     func clear() {
         cache.removeAll()
-        logger.info("Cache cleared")
     }
 }
 
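The `returnArray` change above removes the `float32`-only guard so that every returned array is zeroed before reuse. The sketch below shows the same idea on a simplified pool. `ElementKind`, `PooledBuffer`, and `BufferPool` are hypothetical stand-ins for `MLMultiArrayDataType`, `MLMultiArray`, and `MLArrayCache`, with raw bytes in place of CoreML storage.

```swift
// Hypothetical element kinds, standing in for MLMultiArrayDataType.
enum ElementKind: Hashable { case float32, float16, int32 }

// Simplified reusable buffer; raw bytes stand in for MLMultiArray storage.
final class PooledBuffer {
    let kind: ElementKind
    var bytes: [UInt8]

    init(kind: ElementKind, byteCount: Int) {
        self.kind = kind
        self.bytes = [UInt8](repeating: 0, count: byteCount)
    }
}

struct BufferPool {
    private struct Key: Hashable {
        let kind: ElementKind
        let byteCount: Int
    }

    private var free: [Key: [PooledBuffer]] = [:]

    // Hand out a previously returned buffer when one matches, else allocate.
    mutating func take(kind: ElementKind, byteCount: Int) -> PooledBuffer {
        let key = Key(kind: kind, byteCount: byteCount)
        if let buffer = free[key]?.popLast() { return buffer }
        return PooledBuffer(kind: kind, byteCount: byteCount)
    }

    // Zero every buffer on return, regardless of element kind, so no stale data leaks.
    mutating func give(_ buffer: PooledBuffer) {
        for index in buffer.bytes.indices { buffer.bytes[index] = 0 }
        free[Key(kind: buffer.kind, byteCount: buffer.bytes.count), default: []].append(buffer)
    }
}

var pool = BufferPool()
let first = pool.take(kind: .int32, byteCount: 8)
first.bytes[0] = 0xFF                 // simulate leftover data from a previous use
pool.give(first)
let second = pool.take(kind: .int32, byteCount: 8)
```

Here `second` is the same object as `first`, and its contents are all zero even though it is not a `float32` buffer, which is the case the old guard skipped.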
@@ -0,0 +1,36 @@
+@preconcurrency import CoreML
+import Foundation
+
+/// Shared utilities for creating `MLModelConfiguration` instances and resolving model directories.
+public enum MLModelConfigurationUtils {
+
+    /// Create a default `MLModelConfiguration` with low-precision GPU accumulation enabled.
+    ///
+    /// - Parameter computeUnits: Compute units to use (default: `.cpuAndNeuralEngine`).
+    /// - Returns: Configured `MLModelConfiguration`.
+    public static func defaultConfiguration(
+        computeUnits: MLComputeUnits = .cpuAndNeuralEngine
+    ) -> MLModelConfiguration {
+        let config = MLModelConfiguration()
+        config.allowLowPrecisionAccumulationOnGPU = true
+        config.computeUnits = computeUnits
+        return config
+    }
+
+    /// Default models directory under Application Support.
+    ///
+    /// - Parameter repo: Optional repository whose `folderName` is appended. When `nil`,
+    ///   returns `~/Library/Application Support/FluidAudio/Models/`.
+    /// - Returns: URL for the models directory.
+    public static func defaultModelsDirectory(for repo: Repo? = nil) -> URL {
+        let base = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
+        var url =
+            base
+            .appendingPathComponent("FluidAudio", isDirectory: true)
+            .appendingPathComponent("Models", isDirectory: true)
+        if let repo {
+            url = url.appendingPathComponent(repo.folderName, isDirectory: true)
+        }
+        return url
+    }
+}
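The path-building logic in `defaultModelsDirectory(for:)` can be exercised without touching the real Application Support folder by passing the base directory in. The sketch below is a hypothetical variant written for illustration. `modelsDirectory(base:folderName:)`, the `/tmp/AppSupport` base, and the `"diarizer"` folder name are all assumptions.

```swift
import Foundation

// Hypothetical variant of defaultModelsDirectory(for:) with an injected base URL.
func modelsDirectory(base: URL, folderName: String? = nil) -> URL {
    var url = base
        .appendingPathComponent("FluidAudio", isDirectory: true)
        .appendingPathComponent("Models", isDirectory: true)
    // Append the repo-specific folder only when one is supplied.
    if let folderName {
        url = url.appendingPathComponent(folderName, isDirectory: true)
    }
    return url
}

let base = URL(fileURLWithPath: "/tmp/AppSupport", isDirectory: true)
let sharedDirectory = modelsDirectory(base: base)
let diarizerDirectory = modelsDirectory(base: base, folderName: "diarizer")
```

Injecting the base also sidesteps the `.first!` force unwrap in the shared helper, at the cost of making the caller resolve the directory.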
@@ -0,0 +1,25 @@
+import Foundation
+
+/// Performance metrics for ASR processing
+public struct ASRPerformanceMetrics: Codable, Sendable {
+    public let preprocessorTime: TimeInterval
+    public let encoderTime: TimeInterval
+    public let decoderTime: TimeInterval
+    public let totalProcessingTime: TimeInterval
+    public let rtfx: Float  // Real-time factor
+    public let peakMemoryMB: Float
+    public let gpuUtilization: Float?
+
+    public var summary: String {
+        """
+        Performance Metrics:
+        - Preprocessor: \(String(format: "%.3f", preprocessorTime))s
+        - Encoder: \(String(format: "%.3f", encoderTime))s
+        - Decoder: \(String(format: "%.3f", decoderTime))s
+        - Total: \(String(format: "%.3f", totalProcessingTime))s
+        - RTFx: \(String(format: "%.1f", rtfx))x real-time
+        - Peak Memory: \(String(format: "%.1f", peakMemoryMB)) MB
+        - GPU Utilization: \(gpuUtilization.map { String(format: "%.1f%%", $0) } ?? "N/A")
+        """
+    }
+}
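The retained struct stores `rtfx` but no longer ships code that computes it. The deleted monitor computed it as audio length divided by total processing time. A minimal sketch of that calculation follows. `realTimeFactor` is a hypothetical helper, and the zero-time guard is an addition the deleted code did not have.

```swift
import Foundation

// Hypothetical helper: real-time factor = audio duration / wall-clock processing time.
// Returns nil when the processing time is not positive, to avoid dividing by zero.
func realTimeFactor(audioSeconds: Double, processingSeconds: Double) -> Double? {
    guard processingSeconds > 0 else { return nil }
    return audioSeconds / processingSeconds
}

// 60 s of audio transcribed in 0.5 s of wall-clock time.
let rtfx = realTimeFactor(audioSeconds: 60, processingSeconds: 0.5)
let undefined = realTimeFactor(audioSeconds: 60, processingSeconds: 0)
```

An RTFx above 1 means processing runs faster than real time; the first call above yields 120.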
+10
-24
@@ -7,38 +7,33 @@ actor ProgressEmitter {
 
     init() {}
 
-    func ensureSession() async -> AsyncThrowingStream<Double, Error> {
+    func ensureSession() -> AsyncThrowingStream<Double, Error> {
         if let stream = streamStorage {
             return stream
         }
-        return await startSession()
+        return startSession()
     }
 
-    func currentStream() async -> AsyncThrowingStream<Double, Error> {
-        await ensureSession()
-    }
-
-    func report(progress: Double) async {
+    func report(progress: Double) {
         guard isActive else { return }
         let clamped = min(max(progress, 0.0), 1.0)
         continuation?.yield(clamped)
     }
 
-    func finishSession() async {
-        guard isActive else {
-            _ = await ensureSession()
-            return
-        }
+    func finishSession() {
+        guard isActive else { return }
 
         continuation?.yield(1.0)
         continuation?.finish()
+        reset()
     }
 
-    func failSession(_ error: Error) async {
+    func failSession(_ error: Error) {
         continuation?.finish(throwing: error)
+        reset()
     }
 
-    private func startSession() async -> AsyncThrowingStream<Double, Error> {
+    private func startSession() -> AsyncThrowingStream<Double, Error> {
         if let stream = streamStorage {
             return stream
         }
@@ -48,22 +43,13 @@ actor ProgressEmitter {
         self.continuation = continuation
         self.isActive = true
 
-        continuation.onTermination =
-            { [weak self] (_: AsyncThrowingStream<Double, Error>.Continuation.Termination) in
-                Task { [weak self] in
-                    guard let self else { return }
-                    await self.resetAndPrepareNextSession()
-                }
-            }
-
         continuation.yield(0.0)
         return stream
     }
 
-    private func resetAndPrepareNextSession() async {
+    private func reset() {
         continuation = nil
         streamStorage = nil
         isActive = false
-        _ = await startSession()
     }
 }
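After this change the emitter's methods are synchronous inside the actor, and the session state is cleared explicitly in `finishSession()` and `failSession(_:)` rather than from an `onTermination` task. The sketch below mirrors that shape in a reduced form. `SketchProgressEmitter` is a hypothetical simplification; it omits `isActive` and `failSession` and is not the PR's implementation.

```swift
// Hypothetical, reduced version of the emitter: synchronous actor methods, explicit reset.
actor SketchProgressEmitter {
    private var continuation: AsyncThrowingStream<Double, Error>.Continuation?
    private var stream: AsyncThrowingStream<Double, Error>?

    func ensureSession() -> AsyncThrowingStream<Double, Error> {
        if let stream { return stream }
        var captured: AsyncThrowingStream<Double, Error>.Continuation?
        let newStream = AsyncThrowingStream<Double, Error> { captured = $0 }
        continuation = captured
        stream = newStream
        captured?.yield(0.0)   // every session starts at 0%
        return newStream
    }

    func report(progress: Double) {
        // Clamp to [0, 1] before yielding, as in the diff.
        continuation?.yield(min(max(progress, 0.0), 1.0))
    }

    func finishSession() {
        continuation?.yield(1.0)
        continuation?.finish()
        // Explicit reset replaces the old onTermination callback.
        continuation = nil
        stream = nil
    }
}

let emitter = SketchProgressEmitter()
let progressStream = await emitter.ensureSession()
await emitter.report(progress: 1.7)   // out-of-range input, clamped to 1.0
await emitter.finishSession()

var received: [Double] = []
for try await value in progressStream { received.append(value) }
```

Callers outside the actor still write `await`, but no method suspends internally, so state changes within a call cannot interleave with another call on the same actor.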
@@ -0,0 +1,33 @@
+import CoreML
+import Foundation
+
+/// Zero-copy MLFeatureProvider for chaining model outputs to inputs.
+public class ZeroCopyFeatureProvider: NSObject, MLFeatureProvider {
+    private let features: [String: MLFeatureValue]
+
+    public init(features: [String: MLFeatureValue]) {
+        self.features = features
+        super.init()
+    }
+
+    public var featureNames: Set<String> {
+        Set(features.keys)
+    }
+
+    public func featureValue(for featureName: String) -> MLFeatureValue? {
+        features[featureName]
+    }
+
+    /// Create a provider that chains output from one model to input of another
+    public static func chain(
+        from outputProvider: MLFeatureProvider,
+        outputName: String,
+        to inputName: String
+    ) -> ZeroCopyFeatureProvider? {
+        guard let outputValue = outputProvider.featureValue(for: outputName) else {
+            return nil
+        }
+
+        return ZeroCopyFeatureProvider(features: [inputName: outputValue])
+    }
+}
@@ -166,7 +166,7 @@ enum ProcessCommand {
 
         // Load and process audio file without materializing the full sample buffer.
         let audioURL = URL(fileURLWithPath: audioFile)
-        let factory = StreamingAudioSourceFactory()
+        let factory = AudioSourceFactory()
         let targetSampleRate = offlineConfig.segmentation.sampleRate
         let diskSourceResult = try factory.makeDiskBackedSource(
             from: audioURL,
@@ -90,54 +90,6 @@ final class AsrManagerExtensionTests: XCTestCase {
         XCTAssertEqual(Array(result.suffix(500)), Array(repeating: 0.0, count: 500))
     }

-    // MARK: - calculateStartFrameOffset Tests
-
-    func testCalculateStartFrameOffsetFirstSegment() {
-        let offset = manager.calculateStartFrameOffset(segmentIndex: 0, leftContextSeconds: 2.0)
-
-        // Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
-        XCTAssertEqual(offset, 0)
-    }
-
-    func testCalculateStartFrameOffsetSecondSegment() {
-        let leftContext = 2.0
-        let offset = manager.calculateStartFrameOffset(segmentIndex: 1, leftContextSeconds: leftContext)
-
-        // Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
-        XCTAssertEqual(offset, 0)
-    }
-
-    func testCalculateStartFrameOffsetThirdSegment() {
-        let leftContext = 1.5
-        let offset = manager.calculateStartFrameOffset(segmentIndex: 2, leftContextSeconds: leftContext)
-
-        // Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
-        XCTAssertEqual(offset, 0)
-    }
-
-    func testCalculateStartFrameOffsetVariousContexts() {
-        // Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
-        let testCases: [(leftContext: Double, expected: Int)] = [
-            (0.0, 0), // No context
-            (0.08, 0), // Method always returns 0
-            (0.16, 0), // Method always returns 0
-            (1.0, 0), // Method always returns 0
-            (3.2, 0), // Method always returns 0
-        ]
-
-        for (leftContext, expected) in testCases {
-            let offset = manager.calculateStartFrameOffset(segmentIndex: 1, leftContextSeconds: leftContext)
-            XCTAssertEqual(offset, expected, "Failed for leftContext=\(leftContext)")
-        }
-    }
-
-    func testCalculateStartFrameOffsetNegativeSegment() {
-        let offset = manager.calculateStartFrameOffset(segmentIndex: -1, leftContextSeconds: 2.0)
-
-        // Method is deprecated - now always returns 0 (frame tracking handled by timeJump mechanism)
-        XCTAssertEqual(offset, 0)
-    }
-
     // MARK: - Performance Tests

     func testPadAudioPerformance() {
@@ -151,11 +103,4 @@ final class AsrManagerExtensionTests: XCTestCase {
         }
     }

-    func testCalculateStartFrameOffsetPerformance() {
-        measure {
-            for i in 0..<10_000 {
-                _ = manager.calculateStartFrameOffset(segmentIndex: i % 100, leftContextSeconds: 2.0)
-            }
-        }
-    }
 }

@@ -209,34 +209,13 @@ final class AsrModelsTests: XCTestCase {
         // In CI environment, all compute units are overridden to .cpuOnly
         let isCI = ProcessInfo.processInfo.environment["CI"] != nil

-        // Test encoder configuration
-        let melConfig = AsrModels.optimizedConfiguration(for: .encoder)
+        let config = AsrModels.optimizedConfiguration()
         if isCI {
-            XCTAssertEqual(melConfig.computeUnits, .cpuOnly)
+            XCTAssertEqual(config.computeUnits, .cpuOnly)
         } else {
-            XCTAssertEqual(melConfig.computeUnits, .cpuAndNeuralEngine)
+            XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
         }
-        XCTAssertTrue(melConfig.allowLowPrecisionAccumulationOnGPU)
-
-        // Test decoder configuration
-        let decoderConfig = AsrModels.optimizedConfiguration(for: .decoder)
-        if isCI {
-            XCTAssertEqual(decoderConfig.computeUnits, .cpuOnly)
-        } else {
-            XCTAssertEqual(decoderConfig.computeUnits, .cpuAndNeuralEngine)
-        }
-
-        // Test joint configuration
-        let jointConfig = AsrModels.optimizedConfiguration(for: .joint)
-        if isCI {
-            XCTAssertEqual(jointConfig.computeUnits, .cpuOnly)
-        } else {
-            XCTAssertEqual(jointConfig.computeUnits, .cpuAndNeuralEngine)
-        }
-
-        // Test with FP16 disabled
-        let fp32Config = AsrModels.optimizedConfiguration(for: .encoder, enableFP16: false)
-        XCTAssertFalse(fp32Config.allowLowPrecisionAccumulationOnGPU)
+        XCTAssertTrue(config.allowLowPrecisionAccumulationOnGPU)
     }

     func testOptimizedConfigurationCIEnvironment() {
@@ -251,7 +230,7 @@ final class AsrModelsTests: XCTestCase {
             }
         }

-        let config = AsrModels.optimizedConfiguration(for: .encoder)
+        let config = AsrModels.optimizedConfiguration()
         XCTAssertEqual(config.computeUnits, .cpuOnly)
     }

@@ -288,22 +267,10 @@ final class AsrModelsTests: XCTestCase {
         XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
     }

-    func testOptimalComputeUnitsRespectsPlatform() {
-        // Test each model type
-        let modelTypes: [ANEOptimizer.ModelType] = [
-            .encoder,
-            .decoder,
-            .joint,
-        ]
-
-        for modelType in modelTypes {
-            let computeUnits = ANEOptimizer.optimalComputeUnits(for: modelType)
-
-            // All models should use CPU+ANE for optimal performance
-            XCTAssertEqual(
-                computeUnits, .cpuAndNeuralEngine,
-                "Model type \(modelType) should use CPU+ANE")
-        }
+    func testOptimalComputeUnitsDefault() {
+        // Default configuration uses CPU+ANE for optimal performance
+        let config = AsrModels.defaultConfiguration()
+        XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
     }

     // MARK: - TDT-CTC-110M Model Version Tests
@@ -384,8 +351,8 @@ final class AsrModelsTests: XCTestCase {
     }

     func testTdtCtc110mVocabularyFilename() {
-        // tdtCtc110m uses parakeet_vocab.json (array format)
-        let vocabFile = ModelNames.ASR.vocabularyFileArray
+        // tdtCtc110m uses parakeet_vocab.json (array format, same filename as v2/v3)
+        let vocabFile = ModelNames.ASR.vocabularyFile
         XCTAssertEqual(vocabFile, "parakeet_vocab.json")

         // Verify it has .json extension

@@ -101,27 +101,25 @@ final class AsrTranscriptionTests: XCTestCase {
         XCTAssertTrue(result.tokenTimings?.isEmpty == true) // No timestamps provided, should be empty array
     }

-    func testProcessTranscriptionResultWithTimings() async {
+    func testProcessTranscriptionResultWithTimestampsAndConfidences() async {
         await setupMockVocabulary()
         let tokenIds = [10, 20, 30]
         let audioSamples = Array(repeating: Float(0), count: 48_000) // 3 seconds
-        let timings = [
-            TokenTiming(token: "hello", tokenId: 10, startTime: 0.0, endTime: 1.0, confidence: 0.9),
-            TokenTiming(token: "world", tokenId: 20, startTime: 1.0, endTime: 2.0, confidence: 0.85),
-            TokenTiming(token: "test", tokenId: 30, startTime: 2.0, endTime: 3.0, confidence: 0.95),
-        ]
+        let timestamps = [0, 12, 25]
+        let confidences: [Float] = [0.9, 0.85, 0.95]

         let result = await manager.processTranscriptionResult(
             tokenIds: tokenIds,
+            timestamps: timestamps,
+            confidences: confidences,
             encoderSequenceLength: 150,
             audioSamples: audioSamples,
-            processingTime: 1.2,
-            tokenTimings: timings
+            processingTime: 1.2
         )

         XCTAssertEqual(result.duration, 3.0, accuracy: 0.01)
         XCTAssertNotNil(result.tokenTimings)
-        // Note: Actual timing count may differ due to convertTokensWithExistingTimings filtering
+        XCTAssertEqual(result.tokenTimings?.count, 3)
     }

     func testProcessTranscriptionResultWithTimestamps() async {

@@ -133,10 +133,10 @@ final class ModelNamesTests: XCTestCase {
     }

     func testParakeetTdtCtc110mVocabulary() {
-        // tdtCtc110m uses array-format vocabulary
+        // tdtCtc110m uses same vocabulary file (array-format JSON, parsed at load time)
         let vocabFile = ModelNames.ASR.vocabulary(for: .parakeetTdtCtc110m)
         XCTAssertEqual(vocabFile, "parakeet_vocab.json")
-        XCTAssertEqual(vocabFile, ModelNames.ASR.vocabularyFileArray)
+        XCTAssertEqual(vocabFile, ModelNames.ASR.vocabularyFile)
     }

     func testParakeetTdtCtc110mUsesRequiredModelsFused() {

@@ -42,81 +42,4 @@ final class PerformanceMetricsTests: XCTestCase {
         let summary = metrics.summary
         XCTAssertTrue(summary.contains("N/A"), "Summary should show N/A for nil GPU utilization")
     }
-
-    // MARK: - AggregatedMetrics
-
-    func testAggregatedMetricsSummaryFormatting() {
-        let aggregated = AggregatedMetrics(
-            averageRTFx: 8.5,
-            averageProcessingTime: 1.234,
-            maxMemoryMB: 512.0,
-            sampleCount: 10
-        )
-
-        let summary = aggregated.summary
-        XCTAssertTrue(summary.contains("10 samples"), "Summary should contain sample count")
-        XCTAssertTrue(summary.contains("8.5"), "Summary should contain average RTFx")
-        XCTAssertTrue(summary.contains("1.234"), "Summary should contain average processing time")
-        XCTAssertTrue(summary.contains("512.0"), "Summary should contain max memory")
-    }
-
-    // MARK: - PerformanceMonitor
-
-    func testAggregatedMetricsEmptyReturnsNil() async {
-        let monitor = PerformanceMonitor()
-        let result = await monitor.getAggregatedMetrics()
-        XCTAssertNil(result, "Empty monitor should return nil for aggregated metrics")
-    }
-
-    func testResetClearsMetrics() async throws {
-        let monitor = PerformanceMonitor()
-
-        // Track a session to add metrics
-        _ = try await monitor.trackSession(operation: "test", audioLengthSeconds: 1.0) {
-            return 42
-        }
-
-        // Verify metrics exist
-        let before = await monitor.getAggregatedMetrics()
-        XCTAssertNotNil(before)
-
-        // Reset and verify empty
-        await monitor.reset()
-        let after = await monitor.getAggregatedMetrics()
-        XCTAssertNil(after, "After reset, aggregated metrics should be nil")
-    }
-
-    func testTrackSessionReturnsMetrics() async throws {
-        let monitor = PerformanceMonitor()
-
-        let (result, metrics) = try await monitor.trackSession(
-            operation: "test",
-            audioLengthSeconds: 2.0
-        ) {
-            return "hello"
-        }
-
-        XCTAssertEqual(result, "hello")
-        XCTAssertGreaterThanOrEqual(metrics.totalProcessingTime, 0)
-        XCTAssertGreaterThan(metrics.rtfx, 0)
-    }
-
-    func testAggregatedMetricsComputation() async throws {
-        let monitor = PerformanceMonitor()
-
-        for i in 0..<3 {
-            _ = try await monitor.trackSession(
-                operation: "test\(i)",
-                audioLengthSeconds: Float(i + 1)
-            ) {
-                return i
-            }
-        }
-
-        let aggregated = await monitor.getAggregatedMetrics()
-        XCTAssertNotNil(aggregated)
-        XCTAssertEqual(aggregated?.sampleCount, 3)
-        XCTAssertGreaterThan(aggregated!.averageRTFx, 0)
-        XCTAssertGreaterThan(aggregated!.averageProcessingTime, 0)
-    }
 }

@@ -10,7 +10,7 @@ final class ANEOptimizerTests: XCTestCase {

     func testCreateANEAlignedArrayFloat32() throws {
         let shape: [NSNumber] = [1, 100]
-        let array = try ANEOptimizer.createANEAlignedArray(
+        let array = try ANEMemoryUtils.createAlignedArray(
             shape: shape,
             dataType: .float32
         )
@@ -22,7 +22,7 @@ final class ANEOptimizerTests: XCTestCase {
         let isCI = ProcessInfo.processInfo.environment["CI"] != nil
         if !isCI {
             // Verify memory alignment only in non-CI environment
-            let alignment = ANEOptimizer.aneAlignment
+            let alignment = ANEMemoryUtils.aneAlignment
             let pointerValue = Int(bitPattern: array.dataPointer)
             XCTAssertEqual(pointerValue % alignment, 0, "Array should be \(alignment)-byte aligned")
         }
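The alignment assertions in these tests reduce to a modulo check on the raw pointer address. A minimal standalone sketch follows, assuming a 64-byte alignment (the actual value of `ANEMemoryUtils.aneAlignment` is not shown in this diff); `assumedAlignment`, `isAligned`, and `alignmentDemo` are hypothetical names used only for illustration.

```swift
import Foundation

// Assumption: 64 bytes. The real ANEMemoryUtils.aneAlignment value is not visible in this diff.
let assumedAlignment = 64

/// Returns true when `pointer` sits on an `alignment`-byte boundary.
/// Same check the tests perform on `array.dataPointer`.
func isAligned(_ pointer: UnsafeRawPointer, to alignment: Int) -> Bool {
    Int(bitPattern: pointer) % alignment == 0
}

/// Allocates an aligned buffer sized for 100 Float32 values and checks
/// the base address and the address one byte past it.
func alignmentDemo() -> (aligned: Bool, offsetByOne: Bool) {
    let buffer = UnsafeMutableRawPointer.allocate(
        byteCount: 100 * MemoryLayout<Float>.stride,
        alignment: assumedAlignment
    )
    defer { buffer.deallocate() }
    return (
        isAligned(buffer, to: assumedAlignment),
        isAligned(buffer + 1, to: assumedAlignment)
    )
}

let demo = alignmentDemo()
print(demo.aligned, demo.offsetByOne)
```

The tests skip this check when the `CI` environment variable is set, so the sketch says nothing about how arrays are allocated in that environment.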
@@ -30,7 +30,7 @@ final class ANEOptimizerTests: XCTestCase {

     func testCreateANEAlignedArrayFloat16() throws {
         let shape: [NSNumber] = [1, 64] // Smaller shape for CI stability
-        let array = try ANEOptimizer.createANEAlignedArray(
+        let array = try ANEMemoryUtils.createAlignedArray(
             shape: shape,
             dataType: .float16
         )
@@ -46,7 +46,7 @@ final class ANEOptimizerTests: XCTestCase {

             // Verify memory alignment only in non-CI environment
             let pointerValue = Int(bitPattern: array.dataPointer)
-            XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
+            XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
         }
     }

@@ -56,9 +56,8 @@ final class ANEOptimizerTests: XCTestCase {

     func testCalculateOptimalStridesBasic() {
         let shape: [NSNumber] = [1, 3, 224, 224]
-        let strides = ANEOptimizer.calculateOptimalStrides(
-            for: shape,
-            dataType: .float32
+        let strides = ANEMemoryUtils.calculateOptimalStrides(
+            for: shape
         )

         XCTAssertEqual(strides.count, shape.count)
@@ -71,9 +70,8 @@ final class ANEOptimizerTests: XCTestCase {
     func testCalculateOptimalStridesWithPadding() {
         // Test with dimension that needs padding (not multiple of 16)
         let shape: [NSNumber] = [1, 100] // 100 is not multiple of 16
-        let strides = ANEOptimizer.calculateOptimalStrides(
-            for: shape,
-            dataType: .float32
+        let strides = ANEMemoryUtils.calculateOptimalStrides(
+            for: shape
         )

         // The stride for the first dimension should account for padding
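The body of `calculateOptimalStrides` is not part of this diff. As a rough illustration of what "the stride for the first dimension should account for padding" can mean, here is a sketch that computes row-major strides after rounding the innermost dimension up to a multiple of 16. The function name `paddedStrides`, the `multiple` parameter, and the choice to pad only the innermost dimension are all assumptions, not the library's behavior.

```swift
/// Row-major strides with the innermost dimension rounded up to a multiple of `multiple`.
/// Hypothetical helper for illustration; not the ANEMemoryUtils implementation.
func paddedStrides(for shape: [Int], multiple: Int = 16) -> [Int] {
    var strides = [Int](repeating: 1, count: shape.count)
    var running = 1
    // Walk from the innermost dimension outwards, accumulating element counts.
    for index in shape.indices.reversed() {
        strides[index] = running
        let dimension = shape[index]
        let isInnermost = index == shape.count - 1
        // Only the innermost extent is padded in this sketch.
        let extent = isInnermost
            ? (dimension + multiple - 1) / multiple * multiple
            : dimension
        running *= extent
    }
    return strides
}

print(paddedStrides(for: [1, 100]))          // [112, 1]
print(paddedStrides(for: [1, 3, 224, 224]))  // [150528, 50176, 224, 1]
```

With shape `[1, 100]` the innermost extent becomes 112, so the outer stride is 112 rather than 100; with `[1, 3, 224, 224]` no padding is needed because 224 is already a multiple of 16.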
@@ -84,22 +82,10 @@ final class ANEOptimizerTests: XCTestCase {

     // MARK: - Compute Unit Selection Tests

-    func testOptimalComputeUnits() {
-        // All models use CPU+ANE for optimal performance
-        XCTAssertEqual(
-            ANEOptimizer.optimalComputeUnits(for: .encoder),
-            .cpuAndNeuralEngine
-        )
-
-        XCTAssertEqual(
-            ANEOptimizer.optimalComputeUnits(for: .decoder),
-            .cpuAndNeuralEngine
-        )
-
-        XCTAssertEqual(
-            ANEOptimizer.optimalComputeUnits(for: .joint),
-            .cpuAndNeuralEngine
-        )
+    func testDefaultConfigurationComputeUnits() {
+        // Default configuration uses CPU+ANE
+        let config = MLModelConfigurationUtils.defaultConfiguration()
+        XCTAssertEqual(config.computeUnits, .cpuAndNeuralEngine)
     }

     // MARK: - Zero-Copy View Tests (Removed - causes crashes with memory operations)
@@ -115,7 +101,7 @@ final class ANEOptimizerTests: XCTestCase {
             float32Array[i] = NSNumber(value: Float(i) * 0.1)
         }

-        let result = try ANEOptimizer.convertToFloat16(float32Array)
+        let result = try ANEMemoryUtils.convertToFloat16(float32Array)

         XCTAssertEqual(result.shape, float32Array.shape)

@@ -129,7 +115,7 @@ final class ANEOptimizerTests: XCTestCase {

             // Verify ANE alignment only in non-CI environment
             let pointerValue = Int(bitPattern: result.dataPointer)
-            XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
+            XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
         }

         // Verify data conversion accuracy (regardless of CI)
@@ -145,11 +131,9 @@ final class ANEOptimizerTests: XCTestCase {
         let int32Array = try MLMultiArray(shape: [5], dataType: .int32)

         XCTAssertThrowsError(
-            try ANEOptimizer.convertToFloat16(int32Array)
+            try ANEMemoryUtils.convertToFloat16(int32Array)
         ) { error in
-            let nsError = error as NSError
-            XCTAssertEqual(nsError.domain, "ANEOptimizer")
-            XCTAssertEqual(nsError.code, -3)
+            XCTAssertTrue(error is ANEMemoryUtils.ANEMemoryError)
         }
     }


@@ -26,7 +26,7 @@ final class MLArrayCacheTests: XCTestCase {
         if !isCI {
             // Verify ANE alignment only in non-CI environment
             let pointerValue = Int(bitPattern: array.dataPointer)
-            XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
+            XCTAssertEqual(pointerValue % ANEMemoryUtils.aneAlignment, 0)
         }
     }

@@ -97,58 +97,6 @@ final class MLArrayCacheTests: XCTestCase {
         XCTAssertNotNil(finalArray)
     }

-    // MARK: - Float16 Support
-
-    func testGetFloat16ArrayFromScratch() async throws {
-        let shape: [NSNumber] = [1, 64] // Smaller for CI stability
-        let fp16Array = try await cache.getFloat16Array(shape: shape)
-
-        XCTAssertEqual(fp16Array.shape, shape)
-
-        // In CI, we might get Float32 instead of Float16 for stability
-        let isCI = ProcessInfo.processInfo.environment["CI"] != nil
-        if isCI {
-            // In CI, accept either Float16 or Float32
-            XCTAssertTrue(fp16Array.dataType == .float16 || fp16Array.dataType == .float32)
-        } else {
-            XCTAssertEqual(fp16Array.dataType, .float16)
-
-            // Verify ANE alignment only in non-CI environment
-            let pointerValue = Int(bitPattern: fp16Array.dataPointer)
-            XCTAssertEqual(pointerValue % ANEOptimizer.aneAlignment, 0)
-        }
-    }
-
-    func testGetFloat16ArrayFromFloat32() async throws {
-        // Create Float32 array
-        let shape: [NSNumber] = [50] // Smaller for CI stability
-        let float32Array = try MLMultiArray(shape: shape, dataType: .float32)
-
-        // Fill with test values
-        for i in 0..<float32Array.count {
-            float32Array[i] = NSNumber(value: Float(i) * 0.1)
-        }
-
-        // Convert to Float16
-        let float16Array = try await cache.getFloat16Array(shape: shape, from: float32Array)
-
-        XCTAssertEqual(float16Array.shape, shape)
-
-        // In CI, we might get Float32 instead of Float16 for stability
-        let isCI = ProcessInfo.processInfo.environment["CI"] != nil
-        if isCI {
-            // In CI, accept either Float16 or Float32
-            XCTAssertTrue(float16Array.dataType == .float16 || float16Array.dataType == .float32)
-        } else {
-            XCTAssertEqual(float16Array.dataType, .float16)
-        }
-
-        // Verify conversion accuracy (regardless of CI)
-        for i in 0..<min(5, float16Array.count) {
-            XCTAssertEqual(float16Array[i].floatValue, Float(i) * 0.1, accuracy: 0.01)
-        }
-    }
-
     // MARK: - Pre-warming Tests

     func testPrewarmCache() async {