pyannote community-1 model for offline speaker diarization pipeline (#150)

### Why is this change needed?
<!-- Explain the motivation for this change. What problem does it solve?
-->

Keeping the streaming one around as the VBx and AHC clustering gets
pretty expensive after 30mins of audio and running it constantly gets
expensive. Its still possible to support clustering between files but
will save that for another PR.

Pyannote's Bench mark is around 11% - i increased steps to 0.2s instead
of 0.1 to double the speed but also selective fp16 results in more
operations to run on ANE but also means that we lose some precision.

```
Average DER: 14.95% | Median DER: 10.89% | Average JER: 39.27% | Median JER: 40.74% (collar=0.25s, ignoreOverlap=True)
Average RTFx: 139.63 (from 232 clips)
Metrics summary saved to: /Users/brandonweng/FluidAudioDatasets/voxconverse/metrics/test_metrics_release.json
Completed. New results: 232, Skipped existing: 0, Total attempted: 232
```

See benchmark.md for more info but compared to Pytorch model, we are
100x faster than the CPU version and ~6x faster compared to the mps
backend on mb pro 4

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Brandon Weng <BrandonWeng@users.noreply.github.com>
Co-authored-by: Alex <36247722+Alex-Wengg@users.noreply.github.com>
Co-authored-by: Alex-Wengg <hanweng9@gmail.com>
This commit is contained in:
Brandon Weng
2025-10-22 15:11:57 -04:00
committed by GitHub
parent bdab5be361
commit 7fd5ac5446
53 changed files with 10134 additions and 316 deletions
@@ -2,6 +2,7 @@
name: apple-neural-performance-expert name: apple-neural-performance-expert
description: Use this agent when you need expert guidance on optimizing neural network operations on Apple platforms, including Metal Performance Shaders (MPS), MLX framework optimization, low-level array operations, GPU kernel optimization, memory management for ML workloads, or performance profiling of neural network code. This agent should be consulted for questions about matrix multiplication optimization, convolution implementations, memory bandwidth optimization, or any performance-critical neural network operations on Apple Silicon.\n\nExamples:\n- <example>\n Context: The user is implementing a custom neural network operation and needs optimization advice.\n user: "I'm implementing a custom attention mechanism in MLX and it's running slower than expected on M2 Max"\n assistant: "I'll use the apple-neural-performance-expert agent to analyze your implementation and suggest optimizations."\n <commentary>\n Since this involves MLX performance optimization on Apple Silicon, the apple-neural-performance-expert is the right choice.\n </commentary>\n</example>\n- <example>\n Context: The user needs help with Metal Performance Shaders for neural network operations.\n user: "How can I optimize batch matrix multiplication using MPS for my transformer model?"\n assistant: "Let me consult the apple-neural-performance-expert agent to provide specific MPS optimization strategies."\n <commentary>\n The question specifically asks about MPS optimization for neural networks, which is this agent's specialty.\n </commentary>\n</example>\n- <example>\n Context: The user is experiencing memory issues with their ML model on Apple devices.\n user: "My model keeps running out of memory on iPhone 15 Pro when processing large batches"\n assistant: "I'll engage the apple-neural-performance-expert agent to analyze memory usage patterns and suggest optimization strategies."\n <commentary>\n Memory optimization for ML workloads on Apple devices requires specialized knowledge this agent possesses.\n </commentary>\n</example> description: Use this agent when you need expert guidance on optimizing neural network operations on Apple platforms, including Metal Performance Shaders (MPS), MLX framework optimization, low-level array operations, GPU kernel optimization, memory management for ML workloads, or performance profiling of neural network code. This agent should be consulted for questions about matrix multiplication optimization, convolution implementations, memory bandwidth optimization, or any performance-critical neural network operations on Apple Silicon.\n\nExamples:\n- <example>\n Context: The user is implementing a custom neural network operation and needs optimization advice.\n user: "I'm implementing a custom attention mechanism in MLX and it's running slower than expected on M2 Max"\n assistant: "I'll use the apple-neural-performance-expert agent to analyze your implementation and suggest optimizations."\n <commentary>\n Since this involves MLX performance optimization on Apple Silicon, the apple-neural-performance-expert is the right choice.\n </commentary>\n</example>\n- <example>\n Context: The user needs help with Metal Performance Shaders for neural network operations.\n user: "How can I optimize batch matrix multiplication using MPS for my transformer model?"\n assistant: "Let me consult the apple-neural-performance-expert agent to provide specific MPS optimization strategies."\n <commentary>\n The question specifically asks about MPS optimization for neural networks, which is this agent's specialty.\n </commentary>\n</example>\n- <example>\n Context: The user is experiencing memory issues with their ML model on Apple devices.\n user: "My model keeps running out of memory on iPhone 15 Pro when processing large batches"\n assistant: "I'll engage the apple-neural-performance-expert agent to analyze memory usage patterns and suggest optimization strategies."\n <commentary>\n Memory optimization for ML workloads on Apple devices requires specialized knowledge this agent possesses.\n </commentary>\n</example>
tools: Task, Bash, Glob, Grep, LS, ExitPlanMode, Read, Edit, MultiEdit, Write, NotebookRead, NotebookEdit, WebFetch, TodoWrite, WebSearch, mcp__deepwiki__read_wiki_structure, mcp__deepwiki__read_wiki_contents, mcp__deepwiki__ask_question tools: Task, Bash, Glob, Grep, LS, ExitPlanMode, Read, Edit, MultiEdit, Write, NotebookRead, NotebookEdit, WebFetch, TodoWrite, WebSearch, mcp__deepwiki__read_wiki_structure, mcp__deepwiki__read_wiki_contents, mcp__deepwiki__ask_question
model: sonnet
--- ---
You are an elite Apple platform performance engineer specializing in neural network optimization. Your expertise spans Metal Performance Shaders (MPS), MLX framework internals, and low-level optimization techniques for mathematical operations on Apple Silicon. You are an elite Apple platform performance engineer specializing in neural network optimization. Your expertise spans Metal Performance Shaders (MPS), MLX framework internals, and low-level optimization techniques for mathematical operations on Apple Silicon.
+203
View File
@@ -0,0 +1,203 @@
name: Offline VBx Pipeline
on:
pull_request:
branches: [main]
types: [opened, synchronize, reopened]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
benchmark:
name: Offline VBx Pipeline Benchmark
runs-on: macos-latest
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Setup Swift 6.1
uses: swift-actions/setup-swift@v2
with:
swift-version: "6.1"
- name: Build package
run: swift build
- name: Run Offline Pipeline Benchmark
id: benchmark
run: |
echo "Running offline VBx pipeline benchmark..."
# Record start time
BENCHMARK_START=$(date +%s)
swift run fluidaudio diarization-benchmark --mode offline --auto-download --single-file ES2004a --output offline_results.json
# Check if results file was generated
if [ -f offline_results.json ]; then
echo "SUCCESS=true" >> $GITHUB_OUTPUT
else
echo "Benchmark failed - no results file generated"
echo "SUCCESS=false" >> $GITHUB_OUTPUT
fi
# Calculate execution time
BENCHMARK_END=$(date +%s)
EXECUTION_TIME=$((BENCHMARK_END - BENCHMARK_START))
EXECUTION_MINS=$((EXECUTION_TIME / 60))
EXECUTION_SECS=$((EXECUTION_TIME % 60))
echo "EXECUTION_TIME=${EXECUTION_MINS}m ${EXECUTION_SECS}s" >> $GITHUB_OUTPUT
timeout-minutes: 30
- name: Show offline_results.json
if: always()
run: |
echo "--- offline_results.json ---"
cat offline_results.json || echo "offline_results.json not found"
echo "-----------------------------"
- name: Extract benchmark metrics with jq
id: extract
run: |
# The output is now an array, so we need to access the first element
DER=$(jq '.[0].der' offline_results.json)
JER=$(jq '.[0].jer' offline_results.json)
RTF=$(jq '.[0].rtfx' offline_results.json)
DURATION="1049" # ES2004a duration in seconds
SPEAKER_COUNT=$(jq '.[0].detectedSpeakers' offline_results.json)
# Extract detailed timing information
TOTAL_TIME=$(jq '.[0].timings.totalProcessingSeconds' offline_results.json)
MODEL_DOWNLOAD_TIME=$(jq '.[0].timings.modelDownloadSeconds' offline_results.json)
MODEL_COMPILE_TIME=$(jq '.[0].timings.modelCompilationSeconds' offline_results.json)
AUDIO_LOAD_TIME=$(jq '.[0].timings.audioLoadingSeconds' offline_results.json)
SEGMENTATION_TIME=$(jq '.[0].timings.segmentationSeconds' offline_results.json)
EMBEDDING_TIME=$(jq '.[0].timings.embeddingExtractionSeconds' offline_results.json)
CLUSTERING_TIME=$(jq '.[0].timings.speakerClusteringSeconds' offline_results.json)
INFERENCE_TIME=$(jq '.[0].timings.totalInferenceSeconds' offline_results.json)
echo "DER=${DER}" >> $GITHUB_OUTPUT
echo "JER=${JER}" >> $GITHUB_OUTPUT
echo "RTF=${RTF}" >> $GITHUB_OUTPUT
echo "DURATION=${DURATION}" >> $GITHUB_OUTPUT
echo "SPEAKER_COUNT=${SPEAKER_COUNT}" >> $GITHUB_OUTPUT
echo "TOTAL_TIME=${TOTAL_TIME}" >> $GITHUB_OUTPUT
echo "MODEL_DOWNLOAD_TIME=${MODEL_DOWNLOAD_TIME}" >> $GITHUB_OUTPUT
echo "MODEL_COMPILE_TIME=${MODEL_COMPILE_TIME}" >> $GITHUB_OUTPUT
echo "AUDIO_LOAD_TIME=${AUDIO_LOAD_TIME}" >> $GITHUB_OUTPUT
echo "SEGMENTATION_TIME=${SEGMENTATION_TIME}" >> $GITHUB_OUTPUT
echo "EMBEDDING_TIME=${EMBEDDING_TIME}" >> $GITHUB_OUTPUT
echo "CLUSTERING_TIME=${CLUSTERING_TIME}" >> $GITHUB_OUTPUT
echo "INFERENCE_TIME=${INFERENCE_TIME}" >> $GITHUB_OUTPUT
- name: Comment PR with Offline Pipeline Results
if: always()
uses: actions/github-script@v7
with:
script: |
const der = parseFloat('${{ steps.extract.outputs.DER }}');
const jer = parseFloat('${{ steps.extract.outputs.JER }}');
const rtf = parseFloat('${{ steps.extract.outputs.RTF }}');
const duration = parseFloat('${{ steps.extract.outputs.DURATION }}').toFixed(1);
const speakerCount = '${{ steps.extract.outputs.SPEAKER_COUNT }}';
const totalTime = parseFloat('${{ steps.extract.outputs.TOTAL_TIME }}');
const inferenceTime = parseFloat('${{ steps.extract.outputs.INFERENCE_TIME }}');
const modelDownloadTime = parseFloat('${{ steps.extract.outputs.MODEL_DOWNLOAD_TIME }}');
const modelCompileTime = parseFloat('${{ steps.extract.outputs.MODEL_COMPILE_TIME }}');
const audioLoadTime = parseFloat('${{ steps.extract.outputs.AUDIO_LOAD_TIME }}');
const segmentationTime = parseFloat('${{ steps.extract.outputs.SEGMENTATION_TIME }}');
const embeddingTime = parseFloat('${{ steps.extract.outputs.EMBEDDING_TIME }}');
const clusteringTime = parseFloat('${{ steps.extract.outputs.CLUSTERING_TIME }}');
const executionTime = '${{ steps.benchmark.outputs.EXECUTION_TIME }}' || 'N/A';
let comment = '## Offline VBx Pipeline Results\n\n';
comment += '### Speaker Diarization Performance (VBx Batch Mode)\n';
comment += '_Optimal clustering with Hungarian algorithm for maximum accuracy_\n\n';
comment += '| Metric | Value | Target | Status | Description |\n';
comment += '|--------|-------|--------|---------|-------------|\n';
comment += `| **DER** | **${der.toFixed(1)}%** | <20% | ${der < 20 ? '✅' : '⚠️'} | Diarization Error Rate (lower is better) |\n`;
comment += `| **JER** | **${jer.toFixed(1)}%** | <18% | ${jer < 18 ? '✅' : '⚠️'} | Jaccard Error Rate |\n`;
comment += `| **RTFx** | **${rtf.toFixed(2)}x** | >1.0x | ${rtf > 1.0 ? '✅' : '⚠️'} | Real-Time Factor (higher is faster) |\n\n`;
comment += '### Offline VBx Pipeline Timing Breakdown\n';
comment += '_Time spent in each stage of batch diarization_\n\n';
comment += '| Stage | Time (s) | % | Description |\n';
comment += '|-------|----------|---|-------------|\n';
comment += `| Model Download | ${modelDownloadTime.toFixed(3)} | ${(modelDownloadTime/totalTime*100).toFixed(1)} | Fetching diarization models |\n`;
comment += `| Model Compile | ${modelCompileTime.toFixed(3)} | ${(modelCompileTime/totalTime*100).toFixed(1)} | CoreML compilation |\n`;
comment += `| Audio Load | ${audioLoadTime.toFixed(3)} | ${(audioLoadTime/totalTime*100).toFixed(1)} | Loading audio file |\n`;
comment += `| Segmentation | ${segmentationTime.toFixed(3)} | ${(segmentationTime/totalTime*100).toFixed(1)} | VAD + speech detection |\n`;
comment += `| Embedding | ${embeddingTime.toFixed(3)} | ${(embeddingTime/totalTime*100).toFixed(1)} | Speaker embedding extraction |\n`;
comment += `| Clustering (VBx) | ${clusteringTime.toFixed(3)} | ${(clusteringTime/totalTime*100).toFixed(1)} | Hungarian algorithm + VBx clustering |\n`;
comment += `| **Total** | **${totalTime.toFixed(3)}** | **100** | **Full VBx pipeline** |\n\n`;
comment += '### Speaker Diarization Research Comparison\n';
comment += '_Offline VBx achieves competitive accuracy with batch processing_\n\n';
comment += '| Method | DER | Mode | Description |\n';
comment += '|--------|-----|------|-------------|\n';
comment += '| **FluidAudio (Offline)** | **' + der.toFixed(1) + '%** | **VBx Batch** | **On-device CoreML with optimal clustering** |\n';
comment += '| FluidAudio (Streaming) | 17.7% | Chunk-based | First-occurrence speaker mapping |\n';
comment += '| Research baseline | 18-30% | Various | Standard dataset performance |\n\n';
comment += '**Pipeline Details**:\n';
comment += '- **Mode**: Offline VBx with Hungarian algorithm for optimal speaker-to-cluster assignment\n';
comment += '- **Segmentation**: VAD-based voice activity detection\n';
comment += '- **Embeddings**: WeSpeaker-compatible speaker embeddings\n';
comment += '- **Clustering**: PowerSet with VBx refinement\n';
comment += '- **Accuracy**: Higher than streaming due to optimal post-hoc mapping\n\n';
comment += `<sub>🎯 **Offline VBx Test** • AMI Corpus ES2004a • ${duration}s meeting audio • ${inferenceTime.toFixed(1)}s processing • Test runtime: ${executionTime} • ${new Date().toLocaleString('en-US', { timeZone: 'America/New_York', year: 'numeric', month: '2-digit', day: '2-digit', hour: '2-digit', minute: '2-digit', hour12: true })} EST</sub>\n\n`;
// Add hidden identifier for reliable comment detection
comment += '<!-- fluidaudio-offline-pipeline -->';
try {
// First, try to find existing benchmark comment
const comments = await github.rest.issues.listComments({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
});
// Look for existing offline pipeline comment (identified by the hidden tag)
const existingComment = comments.data.find(comment => {
const isBot = comment.user.type === 'Bot' ||
comment.user.login === 'github-actions[bot]' ||
comment.user.login.includes('[bot]');
const hasIdentifier = comment.body.includes('<!-- fluidaudio-offline-pipeline -->');
const hasHeader = comment.body.includes('## Offline VBx Pipeline Results');
return isBot && (hasIdentifier || hasHeader);
});
if (existingComment) {
// Update existing comment
await github.rest.issues.updateComment({
comment_id: existingComment.id,
owner: context.repo.owner,
repo: context.repo.repo,
body: comment
});
console.log('Successfully updated existing offline pipeline comment');
} else {
// Create new comment if none exists
await github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: comment
});
console.log('Successfully posted new offline pipeline results comment');
}
} catch (error) {
console.error('Failed to update/post comment:', error.message);
// Don't fail the workflow just because commenting failed
}
+1
View File
@@ -35,6 +35,7 @@ swift format --in-place --recursive --configuration .swift-format Sources/ Tests
- Error handling: Use proper Swift error handling, no force unwrapping in production - Error handling: Use proper Swift error handling, no force unwrapping in production
- Documentation: Triple-slash comments (`///`) for public APIs - Documentation: Triple-slash comments (`///`) for public APIs
- Thread safety: Use actors, `@MainActor`, or proper locking - never `@unchecked Sendable` - Thread safety: Use actors, `@MainActor`, or proper locking - never `@unchecked Sendable`
- Control flow: Prefer flattened if statements with early returns/continues over nested if statements. Use guard statements and inverted conditions to exit early. Nested if statements should be absolutely avoided.
## Clean code ## Clean code
+27 -7
View File
@@ -36,13 +36,25 @@ FluidAudio is a comprehensive Swift framework for local, low-latency audio proce
- **DO NOT** implement alternatives without asking - **DO NOT** implement alternatives without asking
- Only after your approval: Implementation, then explanation of results - Only after your approval: Implementation, then explanation of results
### Code Formatting ### Code Style and Formatting
- **Swift Format**: This project uses swift-format for consistent code style - **Swift Format**: This project uses swift-format for consistent code style
- **Configuration**: See `.swift-format` for style rules - **Configuration**: See `.swift-format` for style rules
- **Auto-formatting**: PRs are automatically checked for formatting compliance - **Auto-formatting**: PRs are automatically checked for formatting compliance
- **Local formatting**: Run `swift format --in-place --recursive --configuration .swift-format Sources/ Tests/` - **Local formatting**: Run `swift format --in-place --recursive --configuration .swift-format Sources/ Tests/`
#### Style Guidelines
- **Line length**: 120 characters
- **Indentation**: 4 spaces
- **Import order**: `import CoreML`, `import Foundation`, `import OSLog` (OrderedImports rule)
- **Naming conventions**:
- lowerCamelCase for variables/functions
- UpperCamelCase for types
- **Error handling**: Use proper Swift error handling, no force unwrapping in production
- **Documentation**: Triple-slash comments (`///`) for public APIs
- **Thread safety**: Use actors, `@MainActor`, or proper locking - never `@unchecked Sendable`
- **Control flow**: Prefer flattened if statements with early returns/continues over nested if statements. Use guard statements and inverted conditions to exit early. Nested if statements should be absolutely avoided to improve readability and reduce cognitive complexity.
## Current Performance Status ## Current Performance Status
- **Achieved**: 17.7% DER - **Achieved**: 17.7% DER
@@ -262,6 +274,16 @@ The project uses GitHub Actions with the following workflows:
- **Indentation**: 4 spaces - **Indentation**: 4 spaces
- **Formatting Rules**: Automatic via swift-format, CI enforced - **Formatting Rules**: Automatic via swift-format, CI enforced
## User Preferences
- Never start responses with positive re-affirming text like "You're absolutely right!", "Good change!", "Excellent progress!", or similar
- Get straight to the point with technical facts
- For debugging, use print statements and delete them at the end when instructed
- Never create fallbacks or simplified solutions that don't actually solve the problem
- Always go for the proper solution over the "simplified" solution
- When asked to implement something specific, DO IT FIRST before explaining why it might not be optimal - implementation first, explanation second
- Just do as instructed - don't try to over-do things that aren't asked
## Development Guidelines ## Development Guidelines
1. **Testing**: Always run benchmarks on multiple files for validation 1. **Testing**: Always run benchmarks on multiple files for validation
@@ -271,8 +293,9 @@ The project uses GitHub Actions with the following workflows:
5. **Thread Safety**: Never use `@unchecked Sendable` - implement proper synchronization 5. **Thread Safety**: Never use `@unchecked Sendable` - implement proper synchronization
6. **Follow Instructions**: When the user asks to implement something specific, DO IT FIRST before explaining why it might not be optimal. Implementation first, explanation second. 6. **Follow Instructions**: When the user asks to implement something specific, DO IT FIRST before explaining why it might not be optimal. Implementation first, explanation second.
7. **Avoid Deprecated Code**: Do not add support for deprecated models or features unless explicitly requested. Keep the codebase clean by only supporting current versions. 7. **Avoid Deprecated Code**: Do not add support for deprecated models or features unless explicitly requested. Keep the codebase clean by only supporting current versions.
8. **Git Operations**: NEVER run `git push` unless explicitly requested by the user. Only commit when asked. 8. **Testing Policy**: ONLY add or run tests when explicitly requested by the user
8. **Code Formatting**: All code must pass swift-format checks before merge 9. **Git Operations**: NEVER run `git push` unless explicitly requested by the user. Only commit when asked.
10. **Code Formatting**: All code must pass swift-format checks before merge
## Next Steps ## Next Steps
@@ -317,7 +340,4 @@ swift test --filter EdgeCaseTests
- **Diarization**: [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) - **Diarization**: [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1)
- **VAD CoreML**: [FluidInference/silero-vad-coreml](https://huggingface.co/FluidInference/silero-vad-coreml) - **VAD CoreML**: [FluidInference/silero-vad-coreml](https://huggingface.co/FluidInference/silero-vad-coreml)
- **ASR Models**: [FluidInference/parakeet-tdt-0.6b-v3-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v3-coreml) - **ASR Models**: [FluidInference/parakeet-tdt-0.6b-v3-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v3-coreml)
- **Test Data**: [alexwengg/musan_mini*](https://huggingface.co/datasets/alexwengg) variants - **Test Data**: [alexwengg/musan_mini*](https://huggingface.co/datasets/alexwengg) variants
- remember to never start with "You're absolutely right!"
- remember to never start with things like "Good change!" or any positive re-affirming text. Just get straight to the point.
- remember to not use config debug, for dbeugging just print and then delete at the very end when the user tells you
+48 -4
View File
@@ -32,6 +32,43 @@ Main class for speaker diarization and "who spoke when" analysis.
- `DiarizerConfig`: Clustering threshold, minimum durations, activity thresholds - `DiarizerConfig`: Clustering threshold, minimum durations, activity thresholds
- Optimal threshold: 0.7 (17.7% DER on AMI dataset) - Optimal threshold: 0.7 (17.7% DER on AMI dataset)
### OfflineDiarizerManager
Full batch pipeline that mirrors the pyannote/Core ML exporter (powerset segmentation + VBx clustering).
> Requires macOS 14 / iOS 17 or later because the manager relies on Swift Concurrency features and C++ clustering shims that are unavailable on older OS releases.
**Key Methods:**
- `init(config: OfflineDiarizerConfig = .default)`
- Creates manager with configuration
- `prepareModels(directory:configuration:forceRedownload:) async throws`
- Downloads / compiles the Core ML bundles as needed and records timing metadata. Call once before processing when you don't already have `OfflineDiarizerModels`.
- `initialize(models: OfflineDiarizerModels)`
- Initializes with models containing segmentation, embedding, and PLDA components (useful when you hydrate the bundles yourself).
- `process(audio: [Float]) async throws -> DiarizationResult`
- Runs the full 10s window pipeline: segmentation → soft mask interpolation → embedding → VBx → timeline reconstruction.
- `process(audioSource: StreamingAudioSampleSource, audioLoadingSeconds: TimeInterval) async throws -> DiarizationResult`
- Streams audio from disk-backed sources without materializing the entire buffer in memory. Pair with `StreamingAudioSourceFactory` for large meetings.
**Supporting Types:**
- `OfflineDiarizerConfig`
- Mirrors pyannote `config.yaml` (`clusteringThreshold`, `Fa`, `Fb`, `maxVBxIterations`, `minDurationOn/off`, batch sizes, logging flags).
- `SegmentationRunner`
- Batches 160k-sample chunks through the segmentation model (589 frames per chunk).
- `Binarization`
- Converts log probabilities to soft VAD weights while retaining binary masks for diagnostics.
- `WeightInterpolation`
- Reimplements `scipy.ndimage.zoom` (half-pixel offsets) so 589-frame weights align with the embedding models pooling stride.
- `EmbeddingRunner`
- Runs the FBANK frontend + embedding backend, resamples masks to 589 frames, and emits 256-d L2-normalized embeddings.
- `PLDAScoring` / `VBxClustering`
- Apply the exported PLDA transforms and iterative VBx refinement to group embeddings into speakers.
- `TimelineReconstruction`
- Derives timestamps directly from the segmentation frame count and `OfflineDiarizerConfig.windowDuration`, then enforces minimum gap/duration constraints.
- `StreamingAudioSourceFactory`
- Creates disk-backed or in-memory `StreamingAudioSampleSource` instances so large meetings never require fully materialized `[Float]` buffers.
Use `OfflineDiarizerManager` when you need offline DER parity or want to run the new CLI offline mode (`fluidaudio process --mode offline`, `fluidaudio diarization-benchmark --mode offline`).
## Voice Activity Detection ## Voice Activity Detection
### VadManager ### VadManager
@@ -70,10 +107,12 @@ Recommended `defaultThreshold` ranges depend on your acoustic conditions:
Automatic speech recognition using Parakeet TDT models (v2 English-only, v3 multilingual). Automatic speech recognition using Parakeet TDT models (v2 English-only, v3 multilingual).
**Key Methods:** **Key Methods:**
- `transcribe(_:source:) throws -> AsrTranscription` - `transcribe(_:source:) async throws -> ASRResult`
- Process complete audio and return transcription - Accepts `[Float]` samples already converted to 16 kHz mono; returns transcription text, confidence, and token timings.
- Parameters: `RandomAccessCollection<Float>` samples, `AudioSource` (microphone/system) - `transcribe(_ url: URL, source:) async throws -> ASRResult`
- Returns: `AsrTranscription` with text, confidence, and timing - Loads the file directly and performs format conversion internally (`AudioConverter`).
- `transcribe(_ buffer: AVAudioPCMBuffer, source:) async throws -> ASRResult`
- Convenience overload for capture pipelines that already produce PCM buffers.
- `initialize(models:) async throws` - `initialize(models:) async throws`
- Load and initialize ASR models (automatic download if needed) - Load and initialize ASR models (automatic download if needed)
@@ -91,6 +130,11 @@ Automatic speech recognition using Parakeet TDT models (v2 English-only, v3 mult
- Convert a buffer to 16kHz mono (stateless conversion) - Convert a buffer to 16kHz mono (stateless conversion)
- `AudioSource`: `.microphone` or `.system` for different processing paths - `AudioSource`: `.microphone` or `.system` for different processing paths
> **Warning:** Avoid hand-decoding audio payloads (e.g., truncating WAV headers or treating bytes as raw `Int16` samples).
> The Core ML models require correctly resampled 16 kHz mono Float32 tensors; manual parsing will silently corrupt input when
> formats carry metadata chunks, different bit depths, stereo channels, or compression. Always route files and live buffers
> through `AudioConverter` before calling `AsrManager.transcribe`.
**Performance:** **Performance:**
- Real-time factor: ~120x on M4 Pro (processes 1min audio in 0.5s) - Real-time factor: ~120x on M4 Pro (processes 1min audio in 0.5s)
- Languages: 25 European languages supported - Languages: 25 European languages supported
+20
View File
@@ -40,6 +40,26 @@ Task {
} }
``` ```
> **Important:** Do not parse WAV/PCM bytes by hand (e.g., slicing headers or assuming 16-bit samples).
> Always convert with `AudioConverter` so differing bit depths, channel layouts, metadata chunks,
> or compressed formats (MP3/M4A/FLAC) get normalized to the 16 kHz mono Float32 tensors that Parakeet expects.
> Manually decoded buffers frequently contain garbage values, which shows up as empty transcripts even though the models load successfully.
### Transcribing directly from a file URL
If you already have an audio file on disk you can skip manual sample loading—`AsrManager.transcribe(_ url:source:)`
handles format conversion internally via `AudioConverter`.
```swift
let models = try await AsrModels.downloadAndLoad(version: .v3)
let asrManager = AsrManager()
try await asrManager.initialize(models: models)
let audioURL = URL(fileURLWithPath: "/path/to/audio.wav")
let result = try await asrManager.transcribe(audioURL, source: .system)
print(result.text)
```
## Manual model loading ## Manual model loading
Working offline? Follow the [Manual Model Loading guide](ManualModelLoading.md) to stage the CoreML bundles and call `AsrModels.load` without triggering HuggingFace downloads. Working offline? Follow the [Manual Model Loading guide](ManualModelLoading.md) to stage the CoreML bundles and call `AsrModels.load` without triggering HuggingFace downloads.
+29 -1
View File
@@ -222,5 +222,33 @@ swift run fluidaudio vad-benchmark --dataset musan-full --num-files all --thresh
[23:02:35.744] [INFO] [VAD] Results saved to: vad_benchmark_results.json [23:02:35.744] [INFO] [VAD] Results saved to: vad_benchmark_results.json
``` ```
## Speaker Diarization ## Speaker Diarization
The offline version uses the community-1 model, the online version uses the legacy speaker-diarization-3.1 model.
### Offline diarzing pipeline
For slightly ~1.2% worse DER we default to a higher step ratio segmentation duration than the baseline community-1 pipeline. This allows us to get nearly ~2x the speed (as expected because we're processing 1/2 of the embeddings). For highly critical use cases, one may should use step ratio = 0.1 and minSegmentDurationSeconds = 0.0
Running on the full voxconverse benchmark:
```bash
StepRatio = 0.2, minSegmentDurationSeconds= 1.0
Average DER: 15.07% | Median DER: 10.70% | Average JER: 39.40% | Median JER: 40.95% (collar=0.25s, ignoreOverlap=True)
Average RTFx: 122.06 (from 232 clips)
Completed. New results: 232, Skipped existing: 0, Total attempted: 232
Step Ratio 2, min turation 1.0
StepRatio = 0.1, minSegmentDurationSeconds= 0
Average DER: 13.89% | Median DER: 10.49% | Average JER: 42.84% | Median JER: 43.30% (collar=0.25s, ignoreOverlap=True)
Average RTFx: 64.75 (from 232 clips)
Completed. New results: 232, Skipped existing: 0, Total attempted: 232
Step Ratio 1, min duration 0 (edited)
```
Note that the baseline pytorch version is ~11% DER, we lost some precision dropping down to fp16 precision in order to run most of the emebdding model on neural engine. But as a result, we significantly out perform the baseline `mps` backend as well. the pyannote-community-1 on cpu is ~1.5-2 RTFx, on mps, it's ~20-25 RTFx.
### Streaming/online Diarization
This is more tricky and honestly a lot more fragile to clustering. Expect +10-15% worse DER for the streaming implementation. Only use this when you critically need realtime streaming speaker diarization. In most cases, offline is more than enough for most applications.
+11
View File
@@ -43,8 +43,19 @@ swift run fluidaudio diarization-benchmark --single-file ES2004a \
# Balanced throughput/quality (~10s chunks with 5s overlap) # Balanced throughput/quality (~10s chunks with 5s overlap)
swift run fluidaudio diarization-benchmark --dataset ami-sdm \ swift run fluidaudio diarization-benchmark --dataset ami-sdm \
--chunk-seconds 10 --overlap-seconds 5 --chunk-seconds 10 --overlap-seconds 5
# Run the full VBx offline pipeline
swift run fluidaudio diarization-benchmark --mode offline --dataset ami-sdm --threshold 0.6
# Process a single file with streaming vs. offline inference
swift run fluidaudio process meeting.wav --mode streaming --threshold 0.7
swift run fluidaudio process meeting.wav --mode offline --threshold 0.6 --debug
``` ```
- `--mode offline` switches the CLI to `OfflineDiarizerManager`, running the full VBx pipeline with PLDA refinement. Expect DER ≈ 1820% on AMI-SDM with threshold 0.6.
- Add `--rttm /path/to/ground_truth.rttm` to `process` to compute DER/JER in-place, or `--export-embeddings embeddings.json` for debugging speaker vectors.
- GitHub Actions workflow `offline-pipeline.yml` replays `fluidaudio diarization-benchmark --mode offline --single-file ES2004a` on every PR so failures in model downloads or clustering logic are caught early.
## VAD ## VAD
```bash ```bash
+97 -2
View File
@@ -107,6 +107,101 @@ let config = DiarizerConfig(
let diarizer = DiarizerManager(config: config) let diarizer = DiarizerManager(config: config)
``` ```
### Offline VBx Pipeline (Batch Diarization)
> Requires macOS 14 / iOS 17 or later. The offline stack uses native C++ clustering and AsyncStream coordination that are unavailable on older OS releases.
When you need full parity with the pyannote/Core ML exporter (powerset segmentation + VBx clustering), use `OfflineDiarizerManager`. It orchestrates segmentation, soft mask interpolation, WeSpeaker embedding extraction, PLDA/VBx clustering, and timeline reconstruction in one place:
```swift
import FluidAudio
let config = OfflineDiarizerConfig()
let manager = OfflineDiarizerManager(config: config)
try await manager.prepareModels() // Downloads + compiles Core ML bundles when missing
let samples = try AudioConverter().resampleAudioFile(path: "meeting.wav")
let result = try await manager.process(audio: samples)
for segment in result.segments {
print("\(segment.speakerId)\(segment.startTimeSeconds)s \(segment.endTimeSeconds)s")
}
```
For file-based processing, use the memory-mapped streaming API which automatically handles large audio files efficiently:
```swift
let url = URL(fileURLWithPath: "meeting.wav")
let result = try await manager.process(url)
```
The file-based API internally uses memory-mapped streaming to avoid materializing the entire buffer in memory.
The offline controller mirrors the reference pipeline:
- **Segmentation:** `SegmentationRunner` feeds 10s/160k sample chunks through the Core ML segmentation model. Each chunk yields 589 frame-level log probabilities over the 7 local powerset classes.
- **Binarization:** `Binarization.logProbsToWeights` converts log probabilities to soft VAD weights; binary masks are still available for diagnostics.
- **Weight interpolation:** `WeightInterpolation` applies the same half-pixel mapping as `scipy.ndimage.zoom`, preserving the Core ML exporters alignment when resampling 589-frame masks to the embedding models pooling rate.
- **Embedding extraction:** `EmbeddingRunner` batches audio + resampled weights and returns L2-normalized 256-d embeddings.
- **VBx clustering:** `VBxClustering` (with `AHCClustering` warm start and `PLDAScoring`) runs the full VBx refinement loop using the JSON parameters exported with the model bundle.
- **Timeline reconstruction:** `TimelineReconstruction` now derives frame duration from the actual segmentation output and `OfflineDiarizerConfig.windowDuration`, ensuring timestamps stay correct if you swap in models with different hop sizes.
`OfflineDiarizerConfig` groups knobs by pipeline stage:
- `segmentation`: Window length (default 10 s), step ratio, min on/off durations, and sample rate. These must align with the exported Core ML segmentation model.
- `embedding`: Batch size and overlap handling. Keep `excludeOverlap` enabled for community-1 style powerset outputs.
- `clustering`: The VBx warm-start threshold plus pyannote's Fa/Fb priors.
- `vbx`: Max iterations and convergence tolerance for the refinement loop.
- `postProcessing`: Minimum gap duration when stitching segments back together.
- `export`: Optional `embeddingsPath` for dumping per-speaker vectors to JSON.
`prepareModels` captures Core ML compilation timings (and download durations when a fresh fetch is needed), so `DiarizationResult.timings` reflects audio loading, segmentation, embedding, clustering, and post-processing costs in one place. Per-speaker embeddings are exposed in `speakerDatabase` for downstream analytics without toggling debug flags.
#### CLI shortcut
The CLI exposes the same controller via `fluidaudio process` and the diarization benchmark tooling:
```bash
swift run fluidaudio process meeting.wav --mode offline --threshold 0.6 --debug
swift run fluidaudio diarization-benchmark --mode offline --dataset ami-sdm --threshold 0.6 --auto-download
```
Add `--rttm path/to/meeting.rttm` when you have ground-truth annotations to emit DER/JER directly on the console, or `--export-embeddings embeddings.json` to inspect cluster assignments. The GitHub Actions workflow [`offline-pipeline.yml`](../.github/workflows/offline-pipeline.yml) executes the single-file AMI benchmark on every PR, keeping downloads, PLDA transforms, and VBx clustering guard-railed.
Both commands reuse the shared model cache (`OfflineDiarizerModels.defaultModelsDirectory()`) and emit JSON payloads compatible with the streaming pipeline.
#### Advanced: Manual Audio Source Control
For use cases requiring fine-grained control over memory management or audio loading, you can manually construct the audio source using `StreamingAudioSourceFactory`:
```swift
import FluidAudio
let config = OfflineDiarizerConfig()
let manager = OfflineDiarizerManager(config: config)
try await manager.prepareModels()
let factory = StreamingAudioSourceFactory()
let (source, loadDuration) = try factory.makeDiskBackedSource(
from: URL(fileURLWithPath: "meeting.wav"),
targetSampleRate: config.segmentation.sampleRate
)
defer { source.cleanup() }
let result = try await manager.process(
audioSource: source,
audioLoadingSeconds: loadDuration
)
```
This approach is useful when you need to:
- Process the same file multiple times without reloading
- Measure audio loading time separately from diarization time
- Implement custom cleanup or caching logic
For most use cases, the simpler `manager.process(url)` API is recommended.
## Streaming/Real-time Processing ## Streaming/Real-time Processing
Process audio in chunks for real-time applications: Process audio in chunks for real-time applications:
@@ -455,8 +550,8 @@ swift run fluidaudio diarization-benchmark --single-file ES2004a
| Property | Type | Description | | Property | Type | Description |
|----------|------|-------------| |----------|------|-------------|
| `segments` | `[TimedSpeakerSegment]` | Speaker segments with timing | | `segments` | `[TimedSpeakerSegment]` | Speaker segments with timing |
| `speakerDatabase` | `[String: [Float]]?` | Speaker embeddings (debug mode) | | `speakerDatabase` | `[String: [Float]]?` | Speaker embeddings keyed by speaker ID |
| `timings` | `PipelineTimings?` | Processing timings (debug mode) | | `timings` | `PipelineTimings?` | Processing timings for the diarization pass |
## Requirements ## Requirements
+52 -38
View File
@@ -8,52 +8,66 @@ Pod::Spec.new do |spec|
state-of-the-art speaker diarization, transcription, and voice activity detection state-of-the-art speaker diarization, transcription, and voice activity detection
via open-source models that can be integrated with just a few lines of code. via open-source models that can be integrated with just a few lines of code.
DESC DESC
spec.homepage = "https://github.com/FluidInference/FluidAudio" spec.homepage = "https://github.com/FluidInference/FluidAudio"
spec.license = { :type => "MIT", :file => "LICENSE" } spec.license = { :type => "Apache 2.0", :file => "LICENSE" }
spec.author = { "FluidInference" => "info@fluidinference.com" } spec.author = { "FluidInference" => "info@fluidinference.com" }
spec.ios.deployment_target = "17.0" spec.ios.deployment_target = "17.0"
spec.osx.deployment_target = "14.0" spec.osx.deployment_target = "14.0"
spec.source = { :git => "https://github.com/FluidInference/FluidAudio.git", :tag => "v#{spec.version}" } spec.source = { :git => "https://github.com/FluidInference/FluidAudio.git", :tag => "v#{spec.version}" }
spec.source_files = "Sources/FluidAudio/**/*.swift"
# iOS Configuration
# Exclude TTS module from iOS builds to avoid ESpeakNG xcframework linking issues.
# CocoaPods has known limitations with vendored xcframeworks during pod lib lint on iOS:
# the framework symbols aren't properly linked in the temporary build environment,
# causing "Undefined symbols" linker errors even though the binary is valid.
# iOS builds include: ASR (speech recognition), Diarization, and VAD (voice activity detection).
spec.ios.exclude_files = "Sources/FluidAudio/TextToSpeech/**/*"
spec.ios.frameworks = "CoreML", "AVFoundation", "Accelerate", "UIKit"
# macOS Configuration
# ESpeakNG framework is only vendored for macOS in the podspec (not a framework limitation).
# The xcframework supports iOS, but CocoaPods fails to link it during iOS validation.
# This enables TTS (text-to-speech) functionality with G2P (grapheme-to-phoneme) conversion.
# macOS builds include: ASR, Diarization, VAD, and TTS with ESpeakNG support.
spec.osx.vendored_frameworks = "Sources/FluidAudio/Frameworks/ESpeakNG.xcframework"
spec.osx.frameworks = "CoreML", "AVFoundation", "Accelerate", "Cocoa"
spec.osx.pod_target_xcconfig = {
'ARCHS[sdk=macosx*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64'
}
spec.osx.user_target_xcconfig = {
'ARCHS[sdk=macosx*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64'
}
spec.swift_versions = ["5.10"] spec.swift_versions = ["5.10"]
# Enable module definition for proper framework imports
spec.user_target_xcconfig = {
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64'
}
spec.pod_target_xcconfig = { spec.pod_target_xcconfig = {
'DEFINES_MODULE' => 'YES', 'DEFINES_MODULE' => 'YES',
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64' 'ARCHS[sdk=macosx*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64',
'ARCHS[sdk=iphonesimulator*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=iphonesimulator*]' => 'i386 x86_64',
'ARCHS[sdk=iphoneos*]' => 'arm64'
} }
spec.user_target_xcconfig = {
'ARCHS[sdk=macosx*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=macosx*]' => 'x86_64',
'ARCHS[sdk=iphonesimulator*]' => 'arm64',
'EXCLUDED_ARCHS[sdk=iphonesimulator*]' => 'i386 x86_64',
'ARCHS[sdk=iphoneos*]' => 'arm64'
}
spec.subspec "FastClusterWrapper" do |wrapper|
wrapper.requires_arc = false
wrapper.source_files = "Sources/FastClusterWrapper/**/*.{cpp,h,hpp}"
wrapper.public_header_files = "Sources/FastClusterWrapper/include/FastClusterWrapper.h"
wrapper.private_header_files = "Sources/FastClusterWrapper/fastcluster_internal.hpp"
wrapper.header_mappings_dir = "Sources/FastClusterWrapper"
wrapper.pod_target_xcconfig = {
'CLANG_CXX_LANGUAGE_STANDARD' => 'c++17'
}
end
spec.subspec "Core" do |core|
core.dependency "#{spec.name}/FastClusterWrapper"
core.source_files = "Sources/FluidAudio/**/*.swift"
# iOS Configuration
# Exclude TTS module from iOS builds to avoid ESpeakNG xcframework linking issues.
# CocoaPods has known limitations with vendored xcframeworks during pod lib lint on iOS:
# the framework symbols aren't properly linked in the temporary build environment,
# causing "Undefined symbols" linker errors even though the binary is valid.
# iOS builds include: ASR (speech recognition), Diarization, and VAD (voice activity detection).
core.ios.exclude_files = "Sources/FluidAudio/TextToSpeech/**/*"
core.ios.frameworks = "CoreML", "AVFoundation", "Accelerate", "UIKit"
# macOS Configuration
# ESpeakNG framework is only vendored for macOS in the podspec (not a framework limitation).
# The xcframework supports iOS, but CocoaPods fails to link it during iOS validation.
# This enables TTS (text-to-speech) functionality with G2P (grapheme-to-phoneme) conversion.
# macOS builds include: ASR, Diarization, VAD, and TTS with ESpeakNG support.
core.osx.vendored_frameworks = "Sources/FluidAudio/Frameworks/ESpeakNG.xcframework"
core.osx.frameworks = "CoreML", "AVFoundation", "Accelerate", "Cocoa"
end
spec.default_subspecs = ["Core"]
end end
+2
View File
@@ -11,6 +11,8 @@ This guide defines how Codex should create comprehensive plans before tackling a
## Plan Creation Workflow ## Plan Creation Workflow
Before drafting the plan, conduct any necessary research using Context7 or DeepWiki (or both). Summarize the key findings and links from that research inside the plan so reviewers can trace assumptions. During execution, keep leveraging Context7/DeepWiki whenever you need additional context or verification, and note their use in updates.
1. **Restate the Mission** 1. **Restate the Mission**
- Summarize the problem in your own words. - Summarize the problem in your own words.
- List explicit goals, non-goals, stakeholders (CLI users, diarizer pipeline, model loaders), and success criteria (latency/accuracy targets, UX expectations, guardrails). - List explicit goals, non-goals, stakeholders (CLI users, diarizer pipeline, model loaders), and success criteria (latency/accuracy targets, UX expectations, guardrails).
+12 -3
View File
@@ -26,7 +26,8 @@ let package = Package(
.target( .target(
name: "FluidAudio", name: "FluidAudio",
dependencies: [ dependencies: [
"ESpeakNG" "ESpeakNG",
"FastClusterWrapper",
], ],
path: "Sources/FluidAudio", path: "Sources/FluidAudio",
exclude: ["Frameworks"], exclude: ["Frameworks"],
@@ -36,8 +37,16 @@ let package = Package(
.unsafeFlags([ .unsafeFlags([
"-Xcc", "-DACCELERATE_NEW_LAPACK", "-Xcc", "-DACCELERATE_NEW_LAPACK",
"-Xcc", "-DACCELERATE_LAPACK_ILP64", "-Xcc", "-DACCELERATE_LAPACK_ILP64",
]) ]),
] ]
),
.target(
name: "FastClusterWrapper",
path: "Sources/FastClusterWrapper",
publicHeadersPath: "include",
cxxSettings: [
.unsafeFlags(["-std=c++17"])
]
), ),
.executableTarget( .executableTarget(
name: "FluidAudioCLI", name: "FluidAudioCLI",
+54 -7
View File
@@ -29,7 +29,7 @@ Want to convert your own model? Check [möbius](https://github.com/FluidInferenc
## Highlights ## Highlights
- **Automatic Speech Recognition (ASR)**: Parakeet TDT v3 (0.6b) for transcription; supports all 25 European languages - **Automatic Speech Recognition (ASR)**: Parakeet TDT v3 (0.6b) for transcription; supports all 25 European languages
- **Speaker Diarization**: Speaker separation with speaker clustering via Pyannote models - **Speaker Diarization (Online + Offline)**: Speaker separation and identification across audio streams. Streaming pipeline for real-time processing and offline batch pipeline with advanced clustering.
- **Speaker Embedding Extraction**: Generate speaker embeddings for voice comparison and clustering, you can use this for speaker identification - **Speaker Embedding Extraction**: Generate speaker embeddings for voice comparison and clustering, you can use this for speaker identification
- **Voice Activity Detection (VAD)**: Voice activity detection with Silero models - **Voice Activity Detection (VAD)**: Voice activity detection with Silero models
- **Real-time Processing**: Designed for near real-time workloads but also works for offline processing - **Real-time Processing**: Designed for near real-time workloads but also works for offline processing
@@ -158,13 +158,53 @@ swift run fluidaudio transcribe audio.wav --model-version v2
## Speaker Diarization ## Speaker Diarization
**AMI Benchmark Results** (Single Distant Microphone) using a subset of the files:
- **DER: 17.7%** — Competitive with Powerset BCE 2023 (18.5%) ### Offline Speaker Diarization Pipeline
- **JER: 28.0%** — Outperforms EEND 2019 (25.3%) and x-vector clustering (28.7%)
- **RTF: 0.02x** — Real-time processing with 50x speedup
### Speaker Diarization Quick Start Pyannote Community-1 pipeline (powerset segmentation + WeSpeaker + VBx) for offline speaker diarization. Use this for most use cases, see Benchmarkds.md for benchmarks.
```swift
import FluidAudio
let config = OfflineDiarizerConfig()
let manager = OfflineDiarizerManager(config: config)
try await manager.prepareModels() // Downloads + compiles Core ML bundles if they are missing
let samples = try AudioConverter().resampleAudioFile(path: "meeting.wav")
let result = try await manager.process(audio: samples)
for segment in result.segments {
print("\(segment.speakerId) \(segment.startTimeSeconds)s → \(segment.endTimeSeconds)s")
}
```
For processing audio files, use the file-based API which automatically uses memory-mapped streaming for efficiency:
```swift
let url = URL(fileURLWithPath: "meeting.wav")
let result = try await manager.process(url)
for segment in result.segments {
print("\(segment.speakerId) \(segment.startTimeSeconds)s → \(segment.endTimeSeconds)s")
}
```
```bash
# Process a meeting with full VBx clustering
swift run fluidaudio process ~/FluidAudioDatasets/ami_official/sdm/ES2004a.Mix-Headset.wav \
--mode offline --threshold 0.6 --output es2004a_offline.json
# Run the AMI single-file benchmark with automatic downloads
swift run fluidaudio diarization-benchmark --mode offline --auto-download \
--single-file ES2004a --threshold 0.6 --output offline_results.json
```
`offline_results.json` contains DER/JER/RTFx along with timing breakdowns for segmentation, embedding extraction, and VBx clustering. CI now runs this workflow on every PR to ensure the offline models stay healthy and the Hugging Face assets remain accessible.
### Streaming/Online Speaker Diarization
Use this if you need to show speaker labels while the transcription is happening, in most use cases, offline should be more than enough.
```swift ```swift
import FluidAudio import FluidAudio
@@ -172,7 +212,7 @@ import FluidAudio
// Diarize an audio file // Diarize an audio file
Task { Task {
let models = try await DiarizerModels.downloadIfNeeded() let models = try await DiarizerModels.downloadIfNeeded()
let diarizer = DiarizerManager() // Uses optimal defaults (0.7 threshold = 17.7% DER) let diarizer = DiarizerManager()
diarizer.initialize(models: models) diarizer.initialize(models: models)
// Prepare 16 kHz mono samples (see: Audio Conversion) // Prepare 16 kHz mono samples (see: Audio Conversion)
@@ -193,6 +233,7 @@ swift run fluidaudio diarization-benchmark --single-file ES2004a \
--chunk-seconds 3 --overlap-seconds 2 --chunk-seconds 3 --overlap-seconds 2
``` ```
### CLI ### CLI
```bash ```bash
@@ -357,6 +398,12 @@ Build requires eSpeak NG headers/libs for the C API discoverable via pkg-config
- Dictionary and model assets are cached under `~/.cache/fluidaudio/Models/kokoro`. - Dictionary and model assets are cached under `~/.cache/fluidaudio/Models/kokoro`.
## Continuous Integration
- `tests.yml`: Default build matrix covering SwiftPM tests and an iOS archive smoke test.
- `diarizer-benchmark.yml`: Runs the streaming diarization benchmark on ES2004a for regression tracking.
- `offline-pipeline.yml`: Executes the VBx offline pipeline end-to-end (`fluidaudio diarization-benchmark --mode offline`) and fails if DER/JER drift beyond guardrails or if models fail to download. Use this workflow as a reference for provisioning model caches in your own CI.
## Everything Else ## Everything Else
### FAQs ### FAQs
@@ -0,0 +1,244 @@
#include "FastClusterWrapper.h"
#include <cmath>
#include <cstddef>
#include <exception>
#include <new>
#include <vector>
#ifndef fc_isnan
#define fc_isnan(X) ((X) != (X))
#endif
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpragma-messages"
#endif
#include "fastcluster_internal.hpp"
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
namespace {
struct CentroidDissimilarity {
const t_float *data;
const t_index dimension;
const t_index count;
std::vector<t_float> centroidStorage;
std::vector<t_index> members;
CentroidDissimilarity(const t_float *input, t_index sampleCount, t_index dim)
: data(input),
dimension(dim),
count(sampleCount),
centroidStorage(sampleCount > 1 ? static_cast<size_t>((sampleCount - 1) * dim) : 0u),
members(sampleCount > 0 ? static_cast<size_t>(2 * sampleCount - 1) : 0u, 0) {
for (t_index i = 0; i < count; ++i) {
members[static_cast<size_t>(i)] = 1;
}
}
template <bool checkNaN>
t_float sqeuclidean(const t_index i, const t_index j) const {
const t_float *pi = basePointer(i);
const t_float *pj = basePointer(j);
t_float sum = 0;
for (t_index k = 0; k < dimension; ++k) {
const t_float diff = pi[k] - pj[k];
sum += diff * diff;
}
if constexpr (checkNaN) {
#if HAVE_DIAGNOSTIC
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
if (fc_isnan(sum)) {
#if HAVE_DIAGNOSTIC
#pragma GCC diagnostic pop
#endif
throw nan_error();
}
}
return sum;
}
t_float sqeuclidean_extended(const t_index i, const t_index j) const {
const t_float *pi = extendedPointer(i);
const t_float *pj = extendedPointer(j);
t_float sum = 0;
for (t_index k = 0; k < dimension; ++k) {
const t_float diff = pi[k] - pj[k];
sum += diff * diff;
}
#if HAVE_DIAGNOSTIC
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
if (fc_isnan(sum)) {
throw nan_error();
}
#if HAVE_DIAGNOSTIC
#pragma GCC diagnostic pop
#endif
return sum;
}
void merge(const t_index i, const t_index j, const t_index newNode) {
const t_float *pi = extendedPointer(i);
const t_float *pj = extendedPointer(j);
t_float *pn = centroidPointer(newNode);
const t_float mi = static_cast<t_float>(members[static_cast<size_t>(i)]);
const t_float mj = static_cast<t_float>(members[static_cast<size_t>(j)]);
const t_float denom = mi + mj;
for (t_index k = 0; k < dimension; ++k) {
pn[k] = (pi[k] * mi + pj[k] * mj) / denom;
}
members[static_cast<size_t>(newNode)] = members[static_cast<size_t>(i)] + members[static_cast<size_t>(j)];
}
void merge_weighted(const t_index i, const t_index j, const t_index newNode) {
const t_float *pi = extendedPointer(i);
const t_float *pj = extendedPointer(j);
t_float *pn = centroidPointer(newNode);
for (t_index k = 0; k < dimension; ++k) {
pn[k] = static_cast<t_float>(0.5) * (pi[k] + pj[k]);
}
members[static_cast<size_t>(newNode)] = members[static_cast<size_t>(i)] + members[static_cast<size_t>(j)];
}
t_float ward(const t_index i, const t_index j) const {
return sqeuclidean<true>(i, j);
}
t_float ward_initial(const t_index i, const t_index j) const {
return sqeuclidean<true>(i, j);
}
static t_float ward_initial_conversion(const t_float value) {
return value * static_cast<t_float>(0.5);
}
t_float ward_extended(const t_index i, const t_index j) const {
return sqeuclidean_extended(i, j);
}
void postprocess(cluster_result &result) const {
result.sqrt();
}
private:
const t_float *basePointer(const t_index index) const {
return data + static_cast<size_t>(index) * static_cast<size_t>(dimension);
}
const t_float *extendedPointer(const t_index index) const {
if (index < count) {
return basePointer(index);
}
return centroidStorage.data() + static_cast<size_t>(index - count) * static_cast<size_t>(dimension);
}
t_float *centroidPointer(const t_index index) {
return centroidStorage.data() + static_cast<size_t>(index - count) * static_cast<size_t>(dimension);
}
};
class LinkageOutput {
public:
explicit LinkageOutput(t_float *buffer) : cursor(buffer) {}
void append(t_index node1, t_index node2, t_float distance, t_float size) {
if (node1 < node2) {
*(cursor++) = static_cast<t_float>(node1);
*(cursor++) = static_cast<t_float>(node2);
} else {
*(cursor++) = static_cast<t_float>(node2);
*(cursor++) = static_cast<t_float>(node1);
}
*(cursor++) = distance;
*(cursor++) = size;
}
private:
t_float *cursor;
};
template <bool sorted>
void generateSciPyDendrogram(t_float *Z, cluster_result &Z2, const t_index N) {
union_find nodes(sorted ? 0 : N);
if (!sorted) {
std::stable_sort(Z2[0], Z2[N - 1]);
}
LinkageOutput output(Z);
t_index node1;
t_index node2;
for (node const *entry = Z2[0]; entry != Z2[N - 1]; ++entry) {
if (sorted) {
node1 = entry->node1;
node2 = entry->node2;
} else {
node1 = nodes.Find(entry->node1);
node2 = nodes.Find(entry->node2);
nodes.Union(node1, node2);
}
output.append(node1, node2, entry->dist,
((node1 < N) ? 1 : Z_(node1 - N, 3)) + ((node2 < N) ? 1 : Z_(node2 - N, 3)));
}
}
} // namespace
fastcluster_wrapper_status fastcluster_compute_centroid_linkage(
const double *data,
size_t pointCount,
size_t dimension,
double *dendrogramOut,
size_t dendrogramLength
) {
if (data == nullptr || dendrogramOut == nullptr) {
return FASTCLUSTER_WRAPPER_INVALID_ARGUMENT;
}
if (pointCount == 0) {
return FASTCLUSTER_WRAPPER_SUCCESS;
}
if (dimension == 0) {
return FASTCLUSTER_WRAPPER_INVALID_ARGUMENT;
}
if (pointCount > static_cast<size_t>(MAX_INDEX) || dimension > static_cast<size_t>(MAX_INDEX)) {
return FASTCLUSTER_WRAPPER_INDEX_OVERFLOW;
}
const size_t requiredLength = (pointCount > 1) ? (pointCount - 1) * 4 : 0;
if (dendrogramLength < requiredLength) {
return FASTCLUSTER_WRAPPER_OUTPUT_TOO_SMALL;
}
if (pointCount == 1) {
return FASTCLUSTER_WRAPPER_SUCCESS;
}
try {
const t_index N = static_cast<t_index>(pointCount);
const t_index dim = static_cast<t_index>(dimension);
CentroidDissimilarity dist(data, N, dim);
cluster_result result(N - 1);
generic_linkage_vector_alternative<METHOD_VECTOR_CENTROID>(N, dist, result);
dist.postprocess(result);
generateSciPyDendrogram<true>(dendrogramOut, result, N);
return FASTCLUSTER_WRAPPER_SUCCESS;
} catch (const std::bad_alloc &) {
return FASTCLUSTER_WRAPPER_ALLOCATION_FAILURE;
} catch (const nan_error &) {
return FASTCLUSTER_WRAPPER_RUNTIME_ERROR;
} catch (const std::exception &) {
return FASTCLUSTER_WRAPPER_RUNTIME_ERROR;
} catch (...) {
return FASTCLUSTER_WRAPPER_UNKNOWN_ERROR;
}
}
+48
View File
@@ -0,0 +1,48 @@
# FastCluster Wrapper
This directory contains a C wrapper around the [fastcluster](https://github.com/fastcluster/fastcluster) library, specifically exposing centroid linkage hierarchical clustering for Swift.
## Purpose
The FastCluster wrapper is required for accurate reimplementation of the **pyannote community-1 speaker diarization pipeline** in Swift. The pyannote pipeline uses agglomerative hierarchical clustering with centroid linkage to cluster speaker embeddings, and this wrapper provides an efficient C++ implementation via a C interface accessible from Swift.
## What's Included
- **`FastClusterWrapper.cpp`**: C wrapper implementation
- **`fastcluster_internal.hpp`**: Internal fastcluster algorithms (from upstream fastcluster)
- **`include/FastClusterWrapper.h`**: C API header
- **`include/module.modulemap`**: Swift module bridge
## Functionality
### `fastcluster_compute_centroid_linkage()`
```c
fastcluster_wrapper_status fastcluster_compute_centroid_linkage(
const double *data, // Feature vectors (row-major layout)
size_t pointCount, // Number of vectors
size_t dimension, // Feature dimension
double *dendrogramOut, // Output dendrogram (SciPy format)
size_t dendrogramLength // Output buffer size
);
```
Computes agglomerative hierarchical clustering using centroid linkage on the input feature vectors. Returns a dendrogram in SciPy format (4 columns: left node, right node, distance, sample count).
## Integration
Used by `Sources/FluidAudio/Diarizer/Offline/AHCClustering.swift` to perform speaker embedding clustering, which is a core component of the diarization pipeline.
## Source
- **Original Repository**: https://github.com/fastcluster/fastcluster
- **Algorithm**: Centroid linkage hierarchical clustering
- **Reference**: Based on algorithms by Daniel Müllner and Google Inc.
## License
fastcluster is licensed under the BSD 2-Clause License. See `ThirdPartyLicenses/fastcluster-LICENSE.md` for details.
Copyright:
- Until package version 1.1.23: © 2011 Daniel Müllner
- All changes from version 1.1.24 on: © Google Inc.
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,46 @@
#ifndef FASTCLUSTER_WRAPPER_H
#define FASTCLUSTER_WRAPPER_H
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
#endif
// Error codes returned by fastcluster wrapper.
typedef enum {
FASTCLUSTER_WRAPPER_SUCCESS = 0,
FASTCLUSTER_WRAPPER_INVALID_ARGUMENT = 1,
FASTCLUSTER_WRAPPER_INDEX_OVERFLOW = 2,
FASTCLUSTER_WRAPPER_OUTPUT_TOO_SMALL = 3,
FASTCLUSTER_WRAPPER_ALLOCATION_FAILURE = 4,
FASTCLUSTER_WRAPPER_RUNTIME_ERROR = 5,
FASTCLUSTER_WRAPPER_UNKNOWN_ERROR = 255
} fastcluster_wrapper_status;
/// Compute centroid linkage dendrogram for the provided feature matrix.
///
/// - Parameters:
/// - data: Pointer to `pointCount * dimension` doubles laid out row-major.
/// - pointCount: Number of vectors (>= 1).
/// - dimension: Feature dimension (> 0).
/// - dendrogramOut: Output buffer receiving `(pointCount - 1) * 4` doubles in SciPy
/// linkage format (columns: left, right, distance, sample_count).
/// - dendrogramLength: Length of `dendrogramOut` in elements.
///
/// - Returns:
/// - `FASTCLUSTER_WRAPPER_SUCCESS` on success.
/// - One of the error codes above otherwise.
fastcluster_wrapper_status fastcluster_compute_centroid_linkage(
const double *data,
size_t pointCount,
size_t dimension,
double *dendrogramOut,
size_t dendrogramLength
);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // FASTCLUSTER_WRAPPER_H
@@ -0,0 +1,4 @@
module FastClusterWrapper {
header "FastClusterWrapper.h"
export *
}
@@ -13,3 +13,11 @@ func makeBlasIndex(_ value: Int, label: String) throws -> BlasIndex {
} }
return cast return cast
} }
@inline(__always)
func makeBlasIndexOrFatal(_ value: Int, label: String) -> BlasIndex {
guard let cast = BlasIndex(exactly: value) else {
preconditionFailure("\(label) exceeds supported BLAS index range (\(value))")
}
return cast
}
@@ -1,3 +1,4 @@
import Accelerate
import CoreML import CoreML
import Foundation import Foundation
import OSLog import OSLog
@@ -7,6 +8,7 @@ public final class DiarizerManager {
internal let logger = AppLogger(category: "Diarizer") internal let logger = AppLogger(category: "Diarizer")
internal let config: DiarizerConfig internal let config: DiarizerConfig
private var models: DiarizerModels? private var models: DiarizerModels?
private var chunkBuffer: [Float] = []
/// Public getter for segmentation model (for streaming) /// Public getter for segmentation model (for streaming)
public var segmentationModel: MLModel? { public var segmentationModel: MLModel? {
@@ -39,10 +41,6 @@ public final class DiarizerManager {
models != nil models != nil
} }
public var initializationTimings: (downloadTime: TimeInterval, compilationTime: TimeInterval) {
models.map { ($0.downloadDuration, $0.compilationDuration) } ?? (0, 0)
}
public func initialize(models: consuming DiarizerModels) { public func initialize(models: consuming DiarizerModels) {
logger.info("Initializing diarization system") logger.info("Initializing diarization system")
@@ -162,7 +160,6 @@ public final class DiarizerManager {
if config.debugMode { if config.debugMode {
let timings = PipelineTimings( let timings = PipelineTimings(
modelDownloadSeconds: models.downloadDuration,
modelCompilationSeconds: models.compilationDuration, modelCompilationSeconds: models.compilationDuration,
audioLoadingSeconds: 0, audioLoadingSeconds: 0,
segmentationSeconds: segmentationTime, segmentationSeconds: segmentationTime,
@@ -213,21 +210,44 @@ public final class DiarizerManager {
let chunkSize = sampleRate * 10 let chunkSize = sampleRate * 10
let chunkCount = chunk.distance(from: chunk.startIndex, to: chunk.endIndex) let chunkCount = chunk.distance(from: chunk.startIndex, to: chunk.endIndex)
let copyCount = min(chunkCount, chunkSize)
let paddedChunk: ArraySlice<Float> if chunkBuffer.count != chunkSize {
if chunkCount < chunkSize { chunkBuffer = [Float](repeating: 0.0, count: chunkSize)
var padded = Array(repeating: 0.0 as Float, count: chunkSize)
for (idx, element) in chunk.enumerated() {
padded[idx] = element
}
paddedChunk = padded[...]
} else if let slice = chunk as? ArraySlice<Float> {
paddedChunk = slice
} else {
// Convert to ArraySlice for other collection types
paddedChunk = Array(chunk)[...]
} }
chunkBuffer.withUnsafeMutableBufferPointer { buffer in
guard let baseAddress = buffer.baseAddress else { return }
vDSP_vclr(baseAddress, 1, vDSP_Length(chunkSize))
guard copyCount > 0 else { return }
let copied =
chunk.withContiguousStorageIfAvailable { storage -> Bool in
storage.withUnsafeBufferPointer { src in
vDSP_mmov(
src.baseAddress!,
baseAddress,
vDSP_Length(copyCount),
vDSP_Length(1),
vDSP_Length(1),
vDSP_Length(chunkSize)
)
}
return true
} ?? false
if !copied {
var index = chunk.startIndex
for i in 0..<copyCount {
baseAddress.advanced(by: i).pointee = chunk[index]
index = chunk.index(after: index)
}
}
}
let paddedChunk: ArraySlice<Float> = chunkBuffer[0..<chunkSize]
let (binarizedSegments, _) = try segmentationProcessor.getSegments( let (binarizedSegments, _) = try segmentationProcessor.getSegments(
audioChunk: paddedChunk, audioChunk: paddedChunk,
segmentationModel: models.segmentationModel segmentationModel: models.segmentationModel
@@ -14,16 +14,14 @@ public struct DiarizerModels: Sendable {
public let segmentationModel: CoreMLDiarizer.SegmentationModel public let segmentationModel: CoreMLDiarizer.SegmentationModel
public let embeddingModel: CoreMLDiarizer.EmbeddingModel public let embeddingModel: CoreMLDiarizer.EmbeddingModel
public let downloadDuration: TimeInterval
public let compilationDuration: TimeInterval public let compilationDuration: TimeInterval
init( init(
segmentation: MLModel, embedding: MLModel, downloadDuration: TimeInterval = 0, segmentation: MLModel, embedding: MLModel,
compilationDuration: TimeInterval = 0 compilationDuration: TimeInterval = 0
) { ) {
self.segmentationModel = segmentation self.segmentationModel = segmentation
self.embeddingModel = embedding self.embeddingModel = embedding
self.downloadDuration = downloadDuration
self.compilationDuration = compilationDuration self.compilationDuration = compilationDuration
} }
} }
@@ -73,14 +71,11 @@ extension DiarizerModels {
let endTime = Date() let endTime = Date()
let totalDuration = endTime.timeIntervalSince(startTime) let totalDuration = endTime.timeIntervalSince(startTime)
let downloadDuration: TimeInterval = 0 // Models are typically cached
let compilationDuration = totalDuration
return DiarizerModels( return DiarizerModels(
segmentation: segmentationModel, segmentation: segmentationModel,
embedding: embeddingModel, embedding: embeddingModel,
downloadDuration: downloadDuration, compilationDuration: totalDuration)
compilationDuration: compilationDuration)
} }
public static func load( public static func load(
@@ -98,7 +93,7 @@ extension DiarizerModels {
return try await download(to: directory, configuration: configuration) return try await download(to: directory, configuration: configuration)
} }
static func defaultModelsDirectory() -> URL { public static func defaultModelsDirectory() -> URL {
let applicationSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! let applicationSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
return return
applicationSupport applicationSupport
@@ -145,7 +140,7 @@ extension DiarizerModels {
let endTime = Date() let endTime = Date()
let loadDuration = endTime.timeIntervalSince(startTime) let loadDuration = endTime.timeIntervalSince(startTime)
return DiarizerModels( return DiarizerModels(
segmentation: segmentationModel, embedding: embeddingModel, downloadDuration: 0, segmentation: segmentationModel, embedding: embeddingModel,
compilationDuration: loadDuration) compilationDuration: loadDuration)
} }
} }
@@ -58,7 +58,6 @@ public struct DiarizerConfig: Sendable {
} }
public struct PipelineTimings: Sendable, Codable { public struct PipelineTimings: Sendable, Codable {
public let modelDownloadSeconds: TimeInterval
public let modelCompilationSeconds: TimeInterval public let modelCompilationSeconds: TimeInterval
public let audioLoadingSeconds: TimeInterval public let audioLoadingSeconds: TimeInterval
public let segmentationSeconds: TimeInterval public let segmentationSeconds: TimeInterval
@@ -69,7 +68,6 @@ public struct PipelineTimings: Sendable, Codable {
public let totalProcessingSeconds: TimeInterval public let totalProcessingSeconds: TimeInterval
public init( public init(
modelDownloadSeconds: TimeInterval = 0,
modelCompilationSeconds: TimeInterval = 0, modelCompilationSeconds: TimeInterval = 0,
audioLoadingSeconds: TimeInterval = 0, audioLoadingSeconds: TimeInterval = 0,
segmentationSeconds: TimeInterval = 0, segmentationSeconds: TimeInterval = 0,
@@ -77,7 +75,6 @@ public struct PipelineTimings: Sendable, Codable {
speakerClusteringSeconds: TimeInterval = 0, speakerClusteringSeconds: TimeInterval = 0,
postProcessingSeconds: TimeInterval = 0 postProcessingSeconds: TimeInterval = 0
) { ) {
self.modelDownloadSeconds = modelDownloadSeconds
self.modelCompilationSeconds = modelCompilationSeconds self.modelCompilationSeconds = modelCompilationSeconds
self.audioLoadingSeconds = audioLoadingSeconds self.audioLoadingSeconds = audioLoadingSeconds
self.segmentationSeconds = segmentationSeconds self.segmentationSeconds = segmentationSeconds
@@ -87,7 +84,7 @@ public struct PipelineTimings: Sendable, Codable {
self.totalInferenceSeconds = self.totalInferenceSeconds =
segmentationSeconds + embeddingExtractionSeconds + speakerClusteringSeconds segmentationSeconds + embeddingExtractionSeconds + speakerClusteringSeconds
self.totalProcessingSeconds = self.totalProcessingSeconds =
modelDownloadSeconds + modelCompilationSeconds + audioLoadingSeconds modelCompilationSeconds + audioLoadingSeconds
+ segmentationSeconds + embeddingExtractionSeconds + speakerClusteringSeconds + segmentationSeconds + embeddingExtractionSeconds + speakerClusteringSeconds
+ postProcessingSeconds + postProcessingSeconds
} }
@@ -98,7 +95,6 @@ public struct PipelineTimings: Sendable, Codable {
} }
return [ return [
"Model Download": (modelDownloadSeconds / totalProcessingSeconds) * 100,
"Model Compilation": (modelCompilationSeconds / totalProcessingSeconds) * 100, "Model Compilation": (modelCompilationSeconds / totalProcessingSeconds) * 100,
"Audio Loading": (audioLoadingSeconds / totalProcessingSeconds) * 100, "Audio Loading": (audioLoadingSeconds / totalProcessingSeconds) * 100,
"Segmentation": (segmentationSeconds / totalProcessingSeconds) * 100, "Segmentation": (segmentationSeconds / totalProcessingSeconds) * 100,
@@ -110,7 +106,6 @@ public struct PipelineTimings: Sendable, Codable {
public var bottleneckStage: String { public var bottleneckStage: String {
let stages = [ let stages = [
("Model Download", modelDownloadSeconds),
("Model Compilation", modelCompilationSeconds), ("Model Compilation", modelCompilationSeconds),
("Audio Loading", audioLoadingSeconds), ("Audio Loading", audioLoadingSeconds),
("Segmentation", segmentationSeconds), ("Segmentation", segmentationSeconds),
@@ -126,10 +121,10 @@ public struct PipelineTimings: Sendable, Codable {
public struct DiarizationResult: Sendable { public struct DiarizationResult: Sendable {
public let segments: [TimedSpeakerSegment] public let segments: [TimedSpeakerSegment]
/// Speaker database with embeddings (only populated when debugMode is enabled) /// Speaker database with embeddings (populated by offline pipelines for downstream use)
public let speakerDatabase: [String: [Float]]? public let speakerDatabase: [String: [Float]]?
/// Performance timings (only populated when debugMode is enabled) /// Performance timings collected during diarization
public let timings: PipelineTimings? public let timings: PipelineTimings?
public init( public init(
@@ -0,0 +1,205 @@
import Accelerate
import Foundation
import OSLog
import os.signpost
#if canImport(FastClusterWrapper)
import FastClusterWrapper
#elseif canImport(FluidAudio_FastClusterWrapper)
import FluidAudio_FastClusterWrapper
#endif
struct AHCClustering {
private let logger = AppLogger(category: "OfflineAHC")
private let signposter = OSSignposter(
subsystem: "com.fluidaudio.diarization",
category: .pointsOfInterest
)
// MARK: - Agglomerative Hierarchical Clustering
func cluster(
embeddingFeatures: [[Double]],
threshold: Double
) -> [Int] {
let count = embeddingFeatures.count
guard count > 0 else { return [] }
guard let dimension = embeddingFeatures.first?.count, dimension > 0 else {
return Array(repeating: 0, count: count)
}
if count == 1 {
return [0]
}
let ahcState = signposter.beginInterval("Agglomerative Hierarchical Clustering")
let normalized = normalizeFeatures(embeddingFeatures, dimension: dimension)
let dendrogramLength = (count - 1) * 4
var dendrogram = [Double](repeating: 0, count: dendrogramLength)
// MARK: - Fastcluster FFI Boundary
let status = normalized.withUnsafeBufferPointer { normalizedPointer in
dendrogram.withUnsafeMutableBufferPointer { dendrogramPointer in
fastcluster_compute_centroid_linkage(
normalizedPointer.baseAddress,
count,
dimension,
dendrogramPointer.baseAddress,
dendrogramLength
)
}
}
guard status == FASTCLUSTER_WRAPPER_SUCCESS else {
logger.error("fastcluster failed with status \(status.rawValue)")
return Array(0..<count)
}
let distanceThreshold = convertThresholdToDistance(threshold)
let assignments = assignmentsFromDendrogram(
dendrogram,
count: count,
distanceThreshold: distanceThreshold
)
let result = remapClusterIds(assignments)
signposter.endInterval("Agglomerative Hierarchical Clustering", ahcState)
return result
}
// MARK: - L2 Feature Normalization
private func normalizeFeatures(_ features: [[Double]], dimension: Int) -> [Double] {
var normalized = [Double](repeating: 0, count: features.count * dimension)
for (rowIndex, vector) in features.enumerated() {
precondition(vector.count == dimension, "All feature vectors must share the same dimension")
var norm: Double = 0
vector.withUnsafeBufferPointer { pointer in
vDSP_dotprD(
pointer.baseAddress!,
1,
pointer.baseAddress!,
1,
&norm,
vDSP_Length(dimension)
)
}
let scale = norm > 0 ? 1.0 / sqrt(norm) : 0
var mutableScale = scale
vector.withUnsafeBufferPointer { source in
normalized.withUnsafeMutableBufferPointer { destination in
vDSP_vsmulD(
source.baseAddress!,
1,
&mutableScale,
destination.baseAddress!.advanced(by: rowIndex * dimension),
1,
vDSP_Length(dimension)
)
}
}
}
return normalized
}
// MARK: - Similarity-to-Distance Conversion
private func convertThresholdToDistance(_ similarity: Double) -> Double {
guard !similarity.isNaN else { return Double.infinity }
if similarity < -1.0 || similarity > 1.0 {
logger.debug("Clustering threshold \(similarity) outside cosine range; clamping to [-1, 1]")
}
let clamped = max(-1.0, min(1.0, similarity))
return sqrt(max(0, 2.0 - 2.0 * clamped))
}
// MARK: - Dendrogram Parsing & Threshold-Based Cluster Assignment
private func assignmentsFromDendrogram(
_ dendrogram: [Double],
count: Int,
distanceThreshold: Double
) -> [Int] {
guard count > 0 else { return [] }
if count == 1 {
return [0]
}
let totalNodes = count * 2 - 1
var leftChild = [Int](repeating: -1, count: totalNodes)
var rightChild = [Int](repeating: -1, count: totalNodes)
var nodeDistance = [Double](repeating: 0, count: totalNodes)
for mergeIndex in 0..<(count - 1) {
let base = mergeIndex * 4
let left = Int(dendrogram[base])
let right = Int(dendrogram[base + 1])
let dist = dendrogram[base + 2]
let newNode = count + mergeIndex
leftChild[newNode] = left
rightChild[newNode] = right
nodeDistance[newNode] = dist
}
let root = totalNodes - 1
var assignments = [Int](repeating: -1, count: count)
var stack = [root]
var nextLabel = 0
while let node = stack.popLast() {
if node < 0 {
continue
}
if node < count {
if assignments[node] == -1 {
assignments[node] = nextLabel
nextLabel += 1
}
continue
}
let distance = nodeDistance[node]
if distance <= distanceThreshold {
let label = nextLabel
nextLabel += 1
var queue = [node]
while let current = queue.popLast() {
if current < count {
assignments[current] = label
} else {
let left = leftChild[current]
let right = rightChild[current]
if left >= 0 { queue.append(left) }
if right >= 0 { queue.append(right) }
}
}
} else {
let left = leftChild[node]
let right = rightChild[node]
if left >= 0 { stack.append(left) }
if right >= 0 { stack.append(right) }
}
}
for index in 0..<assignments.count where assignments[index] == -1 {
assignments[index] = nextLabel
nextLabel += 1
}
return assignments
}
// MARK: - Cluster ID Remapping
private func remapClusterIds(_ assignments: [Int]) -> [Int] {
var mapping: [Int: Int] = [:]
var nextId = 0
return assignments.map { original in
if mapping[original] == nil {
mapping[original] = nextId
nextId += 1
}
return mapping[original]!
}
}
}
@@ -0,0 +1,736 @@
import Accelerate
import CoreML
import Foundation
import OSLog
@available(macOS 14.0, iOS 17.0, *)
public final class OfflineDiarizerManager {
private let logger = AppLogger(category: "OfflineDiarizer")
private let config: OfflineDiarizerConfig
private var models: OfflineDiarizerModels?
public init(config: OfflineDiarizerConfig = .default) {
self.config = config
}
public func initialize(models: OfflineDiarizerModels) {
self.models = models
logger.info("Offline diarizer models initialized")
}
/// Ensure offline diarizer models are available, downloading and compiling them when needed.
/// - Parameters:
/// - directory: Custom cache directory. Defaults to `OfflineDiarizerModels.defaultModelsDirectory()`.
/// - configuration: Optional CoreML configuration to use during compilation.
/// - forceRedownload: When `true`, the cached repo is deleted before attempting to load.
public func prepareModels(
directory: URL? = nil,
configuration: MLModelConfiguration? = nil,
forceRedownload: Bool = false
) async throws {
if !forceRedownload, models != nil {
logger.debug("Offline diarizer models already prepared; skipping load")
return
}
let targetDirectory =
directory?.standardizedFileURL
?? OfflineDiarizerModels.defaultModelsDirectory().standardizedFileURL
if forceRedownload {
do {
try purgeDiarizerRepo(at: targetDirectory)
} catch {
logger.warning(
"Failed to purge diarizer cache during forced reload: \(error.localizedDescription)")
}
}
do {
let loadedModels = try await OfflineDiarizerModels.load(
from: targetDirectory,
configuration: configuration
)
initialize(models: loadedModels)
await prewarmModelsIfNeeded(loadedModels)
logger.info("Offline diarizer models loaded from \(targetDirectory.path)")
} catch {
logger.error(
"Initial offline diarizer model load failed: \(error.localizedDescription)")
logger.info("Attempting fallback download and compilation")
do {
try purgeDiarizerRepo(at: targetDirectory)
} catch {
logger.warning(
"Failed to remove cached diarizer repo before fallback: \(error.localizedDescription)")
}
do {
let reloadedModels = try await OfflineDiarizerModels.load(
from: targetDirectory,
configuration: configuration
)
initialize(models: reloadedModels)
await prewarmModelsIfNeeded(reloadedModels)
let durationText = String(format: "%.2f", reloadedModels.compilationDuration)
logger.info(
"Fallback download + compile completed in \(durationText)s at \(targetDirectory.path)")
} catch {
logger.error(
"Fallback offline diarizer model load failed: \(error.localizedDescription)")
throw error
}
}
}
public func process(audio: [Float]) async throws -> DiarizationResult {
try await process(
audioSource: ArrayAudioSampleSource(samples: audio),
audioLoadingSeconds: 0
)
}
/// Process audio from a file URL using memory-mapped streaming for efficiency.
/// Automatically converts the audio to the target sample rate and processes in chunks.
/// - Parameter url: Path to the audio file
/// - Returns: Diarization result with speaker segments
public func process(_ url: URL) async throws -> DiarizationResult {
let factory = StreamingAudioSourceFactory()
let (source, loadDuration) = try factory.makeDiskBackedSource(
from: url,
targetSampleRate: config.segmentation.sampleRate
)
defer { source.cleanup() }
return try await process(
audioSource: source,
audioLoadingSeconds: loadDuration
)
}
public func process(
audioSource: StreamingAudioSampleSource,
audioLoadingSeconds: TimeInterval
) async throws -> DiarizationResult {
try config.validate()
if models == nil {
try await prepareModels()
}
guard let models else {
throw OfflineDiarizationError.modelNotLoaded("offline-diarizer")
}
let totalStart = Date()
let streamPair = AsyncThrowingStream<SegmentationChunk, Error>.makeStream()
let chunkStream = streamPair.stream
let chunkContinuation = streamPair.continuation
let segmentationTask = Task(priority: .userInitiated) { () throws -> (SegmentationOutput, TimeInterval) in
let processor = OfflineSegmentationProcessor()
let start = Date()
do {
let segmentation = try await processor.process(
audioSource: audioSource,
segmentationModel: models.segmentationModel,
config: config,
chunkHandler: { chunk in
switch chunkContinuation.yield(chunk) {
case .enqueued, .dropped:
return .continue
case .terminated:
return .stop
@unknown default:
return .stop
}
}
)
chunkContinuation.finish()
return (segmentation, Date().timeIntervalSince(start))
} catch {
chunkContinuation.finish(throwing: error)
throw error
}
}
let embeddingTask = Task(priority: .userInitiated) { () throws -> ([TimedEmbedding], TimeInterval) in
let extractor = OfflineEmbeddingExtractor(
fbankModel: models.fbankModel,
embeddingModel: models.embeddingModel,
pldaTransform: PLDATransform(pldaRhoModel: models.pldaRhoModel, psi: models.pldaPsi),
config: config
)
let start = Date()
let embeddings = try await extractor.extractEmbeddings(
audioSource: audioSource,
segmentationStream: chunkStream
)
return (embeddings, Date().timeIntervalSince(start))
}
let segmentationResult: (SegmentationOutput, TimeInterval)
let embeddingResult: ([TimedEmbedding], TimeInterval)
do {
async let awaitedSegmentation = segmentationTask.value
async let awaitedEmbeddings = embeddingTask.value
segmentationResult = try await awaitedSegmentation
embeddingResult = try await awaitedEmbeddings
} catch {
segmentationTask.cancel()
embeddingTask.cancel()
chunkContinuation.finish(throwing: error)
throw error
}
let (segmentation, segmentationTime) = segmentationResult
logger.debug("Segmentation completed in \(segmentationTime)s (async)")
let (timedEmbeddings, embeddingTime) = embeddingResult
logger.debug("Embedding extraction produced \(timedEmbeddings.count) vectors in \(embeddingTime)s (async)")
let pldaTransform = PLDATransform(pldaRhoModel: models.pldaRhoModel, psi: models.pldaPsi)
guard !timedEmbeddings.isEmpty else {
throw OfflineDiarizationError.noSpeechDetected
}
let embeddingFeatures = timedEmbeddings.map { $0.embedding256.map { Double($0) } }
let rhoFeatures = timedEmbeddings.map { $0.rho128 }
let clusteringStart = Date()
let trainingIndices = selectTrainingEmbeddings(
timedEmbeddings: timedEmbeddings
)
let trainingEmbeddings = trainingIndices.map { embeddingFeatures[$0] }
let trainingRho = trainingIndices.map { rhoFeatures[$0] }
logger.debug(
"Clustering will use \(trainingEmbeddings.count)/\(timedEmbeddings.count) embeddings (NaN filtered)"
)
let initialClusters: [Int]
if trainingEmbeddings.count >= 2 {
initialClusters = AHCClustering().cluster(
embeddingFeatures: trainingEmbeddings,
threshold: config.clusteringThreshold
)
} else {
initialClusters = Array(repeating: 0, count: trainingEmbeddings.count)
}
let vbxOutput: VBxOutput
if !trainingRho.isEmpty, !initialClusters.isEmpty {
vbxOutput = VBxClustering(config: config, pldaTransform: pldaTransform).refine(
rhoFeatures: trainingRho,
initialClusters: initialClusters
)
} else {
vbxOutput = VBxOutput(
gamma: [],
pi: [],
hardClusters: [initialClusters],
centroids: [],
numClusters: initialClusters.max().map { $0 + 1 } ?? 0,
elbos: []
)
}
let centroidComputation = computeCentroids(
trainingEmbeddings: trainingEmbeddings,
vbxOutput: vbxOutput,
initialClusters: initialClusters
)
var centroids = centroidComputation.centroids
if centroids.isEmpty {
centroids = computeFallbackCentroids(from: embeddingFeatures)
}
let assignments = assignEmbeddings(
embeddingFeatures: embeddingFeatures,
centroids: centroids
)
let chunkAssignments = buildChunkAssignments(
segmentation: segmentation,
timedEmbeddings: timedEmbeddings,
assignments: assignments,
clusterCount: centroids.count
)
let clusteringTime = Date().timeIntervalSince(clusteringStart)
if !assignments.isEmpty {
let histogram = assignments.reduce(into: [:]) { partialResult, cluster in
partialResult[cluster, default: 0] += 1
}
logger.debug(
"Clustering completed in \(clusteringTime)s with \(centroids.count) centroids (assignment histogram: \(histogram))"
)
} else {
logger.debug("Clustering completed in \(clusteringTime)s with no assignments")
}
let reconstruction = OfflineReconstruction(config: config)
let segments = reconstruction.buildSegments(
segmentation: segmentation,
hardClusters: chunkAssignments,
centroids: centroids
)
let speakerDatabase = reconstruction.buildSpeakerDatabase(segments: segments)
if let exportPath = config.embeddingExportPath {
try exportEmbeddings(
embeddings: timedEmbeddings,
assignments: assignments,
path: exportPath
)
}
let totalProcessing = Date().timeIntervalSince(totalStart)
let timings = PipelineTimings(
modelCompilationSeconds: models.compilationDuration,
audioLoadingSeconds: audioLoadingSeconds,
segmentationSeconds: segmentationTime,
embeddingExtractionSeconds: embeddingTime,
speakerClusteringSeconds: clusteringTime,
postProcessingSeconds: max(0, totalProcessing - segmentationTime - embeddingTime - clusteringTime)
)
return DiarizationResult(
segments: segments,
speakerDatabase: speakerDatabase,
timings: timings
)
}
private func purgeDiarizerRepo(at baseDirectory: URL) throws {
let repoDirectory = baseDirectory.appendingPathComponent(
Repo.diarizer.folderName,
isDirectory: true
)
if FileManager.default.fileExists(atPath: repoDirectory.path) {
try FileManager.default.removeItem(at: repoDirectory)
}
}
private func prewarmModelsIfNeeded(_ models: OfflineDiarizerModels) async {
do {
let start = Date()
try prewarmSegmentationModel(models.segmentationModel)
let elapsed = Date().timeIntervalSince(start)
let elapsedString = String(format: "%.3f", elapsed)
logger.debug("Segmentation model prewarmed in \(elapsedString)s")
} catch {
logger.debug("Segmentation prewarm skipped: \(error.localizedDescription)")
}
do {
let start = Date()
try await prewarmEmbeddingStack(models: models)
let elapsed = Date().timeIntervalSince(start)
let elapsedString = String(format: "%.3f", elapsed)
logger.debug("Embedding stack prewarmed in \(elapsedString)s")
} catch {
logger.debug("Embedding prewarm skipped: \(error.localizedDescription)")
}
}
private func prewarmSegmentationModel(_ model: MLModel) throws {
let shape: [NSNumber] = [
1,
1,
NSNumber(value: config.samplesPerWindow),
]
let array = try MLMultiArray(shape: shape, dataType: .float32)
let pointer = array.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(pointer, 1, vDSP_Length(array.count))
let provider = ZeroCopyDiarizerFeatureProvider(
features: ["audio": MLFeatureValue(multiArray: array)]
)
let options = MLPredictionOptions()
array.prefetchToNeuralEngine()
_ = try model.prediction(from: provider, options: options)
}
private func prewarmEmbeddingStack(models: OfflineDiarizerModels) async throws {
let extractor = OfflineEmbeddingExtractor(
fbankModel: models.fbankModel,
embeddingModel: models.embeddingModel,
pldaTransform: PLDATransform(pldaRhoModel: models.pldaRhoModel, psi: models.pldaPsi),
config: config
)
let dummyAudio = [Float](repeating: 0, count: config.samplesPerWindow)
let dummySegmentation = SegmentationOutput(
logProbs: [[[0]]],
speakerWeights: [[[1.0]]],
numChunks: 1,
numFrames: 1,
numSpeakers: 1,
chunkOffsets: [0],
frameDuration: max(1e-3, config.windowDuration)
)
_ = try await extractor.extractEmbeddings(
audio: dummyAudio,
segmentation: dummySegmentation
)
}
private func selectTrainingEmbeddings(
timedEmbeddings: [TimedEmbedding]
) -> [Int] {
var selected: [Int] = []
selected.reserveCapacity(timedEmbeddings.count)
for (index, embedding) in timedEmbeddings.enumerated() {
let hasNaN = embedding.embedding256.contains { $0.isNaN || $0.isInfinite }
if hasNaN {
continue
}
selected.append(index)
}
if selected.isEmpty {
return Array(timedEmbeddings.indices)
}
return selected
}
private func computeCentroids(
trainingEmbeddings: [[Double]],
vbxOutput: VBxOutput,
initialClusters: [Int]
) -> (centroids: [[Double]], mapping: [Int: Int]) {
guard let dimension = trainingEmbeddings.first?.count else {
return ([], [:])
}
let epsilon = 1e-7
let gamma = vbxOutput.gamma
let pi = vbxOutput.pi
if !gamma.isEmpty, !pi.isEmpty {
let activeSpeakers = pi.enumerated().filter { $0.element > epsilon }
if !activeSpeakers.isEmpty {
var centroids: [[Double]] = []
centroids.reserveCapacity(activeSpeakers.count)
var mapping: [Int: Int] = [:]
for (index, speaker) in activeSpeakers.enumerated() {
let speakerIdx = speaker.offset
mapping[speakerIdx] = index
var numerator = [Double](repeating: 0, count: dimension)
var denominator = 0.0
let dimensionIndex = makeBlasIndexOrFatal(dimension, label: "centroid dimension")
let unitStride = BlasIndex(1)
let frameLimit = min(gamma.count, trainingEmbeddings.count)
for frameIdx in 0..<frameLimit {
let weight = gamma[frameIdx][speakerIdx]
guard weight > 0 else { continue }
denominator += weight
let embedding = trainingEmbeddings[frameIdx]
precondition(
embedding.count == dimension,
"Jagged training embeddings are not supported"
)
embedding.withUnsafeBufferPointer { sourcePointer in
numerator.withUnsafeMutableBufferPointer { destinationPointer in
guard
let sourceBase = sourcePointer.baseAddress,
let destinationBase = destinationPointer.baseAddress
else { return }
cblas_daxpy(
dimensionIndex,
weight,
sourceBase,
unitStride,
destinationBase,
unitStride
)
}
}
}
if denominator > 0 {
centroids.append(numerator.map { $0 / denominator })
} else {
centroids.append([Double](repeating: 0, count: dimension))
}
}
return (centroids, mapping)
}
}
return computeCentroidsFromClusters(
embeddings: trainingEmbeddings,
clusters: initialClusters
)
}
private func computeCentroidsFromClusters(
embeddings: [[Double]],
clusters: [Int]
) -> (centroids: [[Double]], mapping: [Int: Int]) {
guard !embeddings.isEmpty, embeddings.count == clusters.count else {
return ([], [:])
}
var grouped: [Int: (sum: [Double], count: Int)] = [:]
for (embedding, cluster) in zip(embeddings, clusters) {
if grouped[cluster] == nil {
grouped[cluster] = (sum: [Double](repeating: 0, count: embedding.count), count: 0)
}
precondition(
embedding.count == grouped[cluster]!.sum.count,
"Jagged training embeddings are not supported"
)
var entry = grouped[cluster]!
let countIndex = makeBlasIndexOrFatal(embedding.count, label: "centroid accumulation length")
let unitStride = BlasIndex(1)
embedding.withUnsafeBufferPointer { sourcePointer in
entry.sum.withUnsafeMutableBufferPointer { destinationPointer in
guard
let sourceBase = sourcePointer.baseAddress,
let destinationBase = destinationPointer.baseAddress
else { return }
cblas_daxpy(
countIndex,
1.0,
sourceBase,
unitStride,
destinationBase,
unitStride
)
}
}
entry.count += 1
grouped[cluster] = entry
}
let sortedKeys = grouped.keys.sorted()
var centroids: [[Double]] = []
var mapping: [Int: Int] = [:]
for (newIndex, key) in sortedKeys.enumerated() {
mapping[key] = newIndex
let entry = grouped[key]!
if entry.count > 0 {
centroids.append(entry.sum.map { $0 / Double(entry.count) })
} else {
centroids.append(entry.sum)
}
}
return (centroids, mapping)
}
private func computeFallbackCentroids(from embeddings: [[Double]]) -> [[Double]] {
guard let first = embeddings.first else { return [] }
var accumulator = [Double](repeating: 0, count: first.count)
let countIndex = makeBlasIndexOrFatal(first.count, label: "fallback centroid length")
let unitStride = BlasIndex(1)
for vector in embeddings {
precondition(
vector.count == first.count,
"Jagged training embeddings are not supported"
)
vector.withUnsafeBufferPointer { sourcePointer in
accumulator.withUnsafeMutableBufferPointer { destinationPointer in
guard
let sourceBase = sourcePointer.baseAddress,
let destinationBase = destinationPointer.baseAddress
else { return }
cblas_daxpy(
countIndex,
1.0,
sourceBase,
unitStride,
destinationBase,
unitStride
)
}
}
}
var scale = 1.0 / Double(embeddings.count)
accumulator.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
vDSP_vsmulD(
baseAddress,
1,
&scale,
baseAddress,
1,
vDSP_Length(first.count)
)
}
return [accumulator]
}
private func assignEmbeddings(
embeddingFeatures: [[Double]],
centroids: [[Double]]
) -> [Int] {
guard !embeddingFeatures.isEmpty else { return [] }
guard !centroids.isEmpty else {
return Array(repeating: 0, count: embeddingFeatures.count)
}
let normalizedCentroids = centroids.map(normalize)
return embeddingFeatures.map { embedding in
let normalizedEmbedding = normalize(embedding)
var bestIndex = 0
var bestScore = -Double.infinity
for (index, centroid) in normalizedCentroids.enumerated() {
let score = dot(normalizedEmbedding, centroid)
if score > bestScore {
bestScore = score
bestIndex = index
}
}
return bestIndex
}
}
private func normalize(_ vector: [Double]) -> [Double] {
guard !vector.isEmpty else { return vector }
var sumSquares = 0.0
vector.withUnsafeBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
vDSP_dotprD(
baseAddress,
1,
baseAddress,
1,
&sumSquares,
vDSP_Length(vector.count)
)
}
if sumSquares <= 0 {
return vector
}
var scale = 1.0 / sqrt(sumSquares)
var normalized = [Double](repeating: 0, count: vector.count)
vector.withUnsafeBufferPointer { sourcePointer in
normalized.withUnsafeMutableBufferPointer { destinationPointer in
guard
let sourceBase = sourcePointer.baseAddress,
let destinationBase = destinationPointer.baseAddress
else { return }
vDSP_vsmulD(
sourceBase,
1,
&scale,
destinationBase,
1,
vDSP_Length(vector.count)
)
}
}
return normalized
}
private func dot(_ lhs: [Double], _ rhs: [Double]) -> Double {
guard lhs.count == rhs.count else { return 0 }
if lhs.isEmpty { return 0 }
var result = 0.0
lhs.withUnsafeBufferPointer { lhsPointer in
rhs.withUnsafeBufferPointer { rhsPointer in
guard
let lhsBase = lhsPointer.baseAddress,
let rhsBase = rhsPointer.baseAddress
else { return }
vDSP_dotprD(
lhsBase,
1,
rhsBase,
1,
&result,
vDSP_Length(lhs.count)
)
}
}
return result
}
private func buildChunkAssignments(
segmentation: SegmentationOutput,
timedEmbeddings: [TimedEmbedding],
assignments: [Int],
clusterCount: Int
) -> [[Int]] {
var matrix = Array(
repeating: Array(repeating: -2, count: segmentation.numSpeakers),
count: segmentation.numChunks
)
for (embedding, cluster) in zip(timedEmbeddings, assignments) {
guard
embedding.chunkIndex >= 0,
embedding.chunkIndex < matrix.count,
embedding.speakerIndex >= 0,
embedding.speakerIndex < matrix[embedding.chunkIndex].count,
cluster >= 0,
cluster < clusterCount
else {
continue
}
matrix[embedding.chunkIndex][embedding.speakerIndex] = cluster
}
return matrix
}
private func exportEmbeddings(
embeddings: [TimedEmbedding],
assignments: [Int],
path: String
) throws {
struct ExportPayload: Codable {
let chunkIndex: Int
let speakerIndex: Int
let startFrame: Int
let endFrame: Int
let startTime: Double
let endTime: Double
let embedding256: [Float]
let rho128: [Double]
let cluster: Int
}
var payload: [ExportPayload] = []
payload.reserveCapacity(embeddings.count)
for (index, embedding) in embeddings.enumerated() {
let cluster =
assignments.indices.contains(index)
? assignments[index] : -1
payload.append(
ExportPayload(
chunkIndex: embedding.chunkIndex,
speakerIndex: embedding.speakerIndex,
startFrame: embedding.startFrame,
endFrame: embedding.endFrame,
startTime: embedding.startTime,
endTime: embedding.endTime,
embedding256: embedding.embedding256,
rho128: embedding.rho128,
cluster: cluster
)
)
}
let data = try JSONEncoder().encode(payload)
let url = URL(fileURLWithPath: path)
try data.write(to: url)
logger.info("Exported \(payload.count) embeddings to \(path)")
}
}
@@ -0,0 +1,164 @@
@preconcurrency import CoreML
import Foundation
import OSLog
@available(macOS 14.0, iOS 17.0, *)
public struct OfflineDiarizerModels: Sendable {
public let segmentationModel: MLModel
public let fbankModel: MLModel
public let embeddingModel: MLModel
public let pldaRhoModel: MLModel
public let pldaPsi: [Double]
public let compilationDuration: TimeInterval
private static let logger = AppLogger(category: "OfflineDiarizerModels")
private static func loadPLDAPsi(from directory: URL) throws -> [Double] {
let candidatePaths = [
directory.appendingPathComponent("plda-parameters.json", isDirectory: false),
directory.appendingPathComponent("speaker-diarization-coreml/plda-parameters.json", isDirectory: false),
directory.appendingPathComponent("speaker-diarization-offline/plda-parameters.json", isDirectory: false),
]
guard let parametersURL = candidatePaths.first(where: { FileManager.default.fileExists(atPath: $0.path) })
else {
throw OfflineDiarizationError.processingFailed("PLDA parameters file not found in \(directory.path)")
}
let data = try Data(contentsOf: parametersURL)
let jsonObject = try JSONSerialization.jsonObject(with: data, options: [])
guard
let root = jsonObject as? [String: Any],
let tensors = root["tensors"] as? [String: Any],
let psiInfo = tensors["psi"] as? [String: Any],
let base64 = psiInfo["data_base64"] as? String,
let decoded = Data(base64Encoded: base64, options: [.ignoreUnknownCharacters])
else {
throw OfflineDiarizationError.processingFailed("Failed to decode PLDA psi parameters")
}
let floatCount = decoded.count / MemoryLayout<Float>.size
guard floatCount > 0 else {
throw OfflineDiarizationError.processingFailed("PLDA psi tensor is empty")
}
var floats = [Float](repeating: 0, count: floatCount)
_ = floats.withUnsafeMutableBytes { destination in
decoded.copyBytes(to: destination)
}
return floats.map { Double($0) }
}
public init(
segmentationModel: MLModel,
fbankModel: MLModel,
embeddingModel: MLModel,
pldaRhoModel: MLModel,
pldaPsi: [Double],
compilationDuration: TimeInterval
) {
self.segmentationModel = segmentationModel
self.fbankModel = fbankModel
self.embeddingModel = embeddingModel
self.pldaRhoModel = pldaRhoModel
self.pldaPsi = pldaPsi
self.compilationDuration = compilationDuration
}
public static func defaultModelsDirectory() -> URL {
let base = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
return
base
.appendingPathComponent("FluidAudio", isDirectory: true)
.appendingPathComponent("Models", isDirectory: true)
}
private static func defaultConfiguration() -> MLModelConfiguration {
let configuration = MLModelConfiguration()
configuration.allowLowPrecisionAccumulationOnGPU = true
configuration.computeUnits = .all
return configuration
}
public static func load(
from directory: URL? = nil,
configuration: MLModelConfiguration? = nil
) async throws -> OfflineDiarizerModels {
let modelsDirectory = directory ?? defaultModelsDirectory()
let logger = Self.logger
logger.info("Loading offline diarization models from \(modelsDirectory.path)")
let loadStart = Date()
let inferenceComputeUnits: MLComputeUnits = .all
let segmentationAndEmbeddingNames: [String] = [
ModelNames.OfflineDiarizer.segmentationPath,
ModelNames.OfflineDiarizer.embeddingPath,
ModelNames.OfflineDiarizer.pldaRhoPath,
]
let segmentationEmbeddingModels = try await DownloadUtils.loadModels(
.diarizer,
modelNames: segmentationAndEmbeddingNames,
directory: modelsDirectory,
computeUnits: inferenceComputeUnits,
variant: "offline"
)
guard let segmentation = segmentationEmbeddingModels[ModelNames.OfflineDiarizer.segmentationPath] else {
throw OfflineDiarizationError.modelNotLoaded(ModelNames.OfflineDiarizer.segmentation)
}
guard let embedding = segmentationEmbeddingModels[ModelNames.OfflineDiarizer.embeddingPath] else {
throw OfflineDiarizationError.modelNotLoaded(ModelNames.OfflineDiarizer.embedding)
}
guard let plda = segmentationEmbeddingModels[ModelNames.OfflineDiarizer.pldaRhoPath] else {
throw OfflineDiarizationError.modelNotLoaded(ModelNames.OfflineDiarizer.pldaRho)
}
let fbankComputeUnits: MLComputeUnits = .cpuOnly
let fbankModels = try await DownloadUtils.loadModels(
.diarizer,
modelNames: [ModelNames.OfflineDiarizer.fbankPath],
directory: modelsDirectory,
computeUnits: fbankComputeUnits,
variant: "offline"
)
guard let fbank = fbankModels[ModelNames.OfflineDiarizer.fbankPath] else {
throw OfflineDiarizationError.modelNotLoaded(ModelNames.OfflineDiarizer.fbank)
}
let pldaPsi = try loadPLDAPsi(from: modelsDirectory)
let compilationDuration = Date().timeIntervalSince(loadStart)
let compileString = String(format: "%.3f", compilationDuration)
logger.info(
"Offline diarization models ready (compile: \(compileString)s, computeUnits: segmentation/embedding/plda=\(inferenceComputeUnits.label), fbank=\(fbankComputeUnits.label))"
)
return OfflineDiarizerModels(
segmentationModel: segmentation,
fbankModel: fbank,
embeddingModel: embedding,
pldaRhoModel: plda,
pldaPsi: pldaPsi,
compilationDuration: compilationDuration
)
}
}
extension MLComputeUnits {
fileprivate var label: String {
switch self {
case .cpuOnly:
return ".cpuOnly"
case .cpuAndGPU:
return ".cpuAndGPU"
case .cpuAndNeuralEngine:
return ".cpuAndNeuralEngine"
case .all:
return ".all"
@unknown default:
return ".unknown"
}
}
}
@@ -0,0 +1,548 @@
import Foundation
/// Errors surfaced by the offline diarization pipeline.
public enum OfflineDiarizationError: Error, LocalizedError {
case modelNotLoaded(String)
case invalidConfiguration(String)
case invalidBatchSize(String)
case processingFailed(String)
case noSpeechDetected
case exportFailed(String)
public var errorDescription: String? {
switch self {
case .modelNotLoaded(let name):
return "Model not loaded: \(name)"
case .invalidConfiguration(let message):
return "Invalid configuration: \(message)"
case .invalidBatchSize(let message):
return "Invalid batch size: \(message)"
case .processingFailed(let message):
return "Processing failed: \(message)"
case .noSpeechDetected:
return "No speech detected in audio"
case .exportFailed(let message):
return "Failed to export data: \(message)"
}
}
}
/// Configuration values tuned to pyannote's community-1 pipeline.
/// Groups knobs by pipeline stage while keeping legacy property accessors
/// to minimize downstream churn.
public struct OfflineDiarizerConfig: Sendable {
/// Segmentation parameters. Threshold fields are ignored by powerset models like community-1 but
/// remain for compatibility with non-powerset pipelines.
public struct Segmentation: Sendable {
public var windowDurationSeconds: Double
public var sampleRate: Int
public var minDurationOn: Double
public var minDurationOff: Double
public var stepRatio: Double
public var speechOnsetThreshold: Float
public var speechOffsetThreshold: Float
public static let community = Segmentation(
windowDurationSeconds: 10.0,
sampleRate: 16_000,
minDurationOn: 0.0,
minDurationOff: 0.0,
// This with 1.0 min speech duration gives us ~1.4% worse DER but 2x the speed.
stepRatio: 0.2,
speechOnsetThreshold: 0.5,
speechOffsetThreshold: 0.5
)
public init(
windowDurationSeconds: Double,
sampleRate: Int,
minDurationOn: Double,
minDurationOff: Double,
stepRatio: Double,
speechOnsetThreshold: Float,
speechOffsetThreshold: Float
) {
self.windowDurationSeconds = windowDurationSeconds
self.sampleRate = sampleRate
self.minDurationOn = minDurationOn
self.minDurationOff = minDurationOff
self.stepRatio = stepRatio
self.speechOnsetThreshold = speechOnsetThreshold
self.speechOffsetThreshold = speechOffsetThreshold
}
}
public struct Embedding: Sendable {
public var batchSize: Int
public var excludeOverlap: Bool
public var minSegmentDurationSeconds: Double
public static let community = Embedding(
batchSize: 32,
excludeOverlap: true,
minSegmentDurationSeconds: 1.0
)
public init(
batchSize: Int,
excludeOverlap: Bool,
minSegmentDurationSeconds: Double
) {
self.batchSize = batchSize
self.excludeOverlap = excludeOverlap
self.minSegmentDurationSeconds = minSegmentDurationSeconds
}
}
public struct Clustering: Sendable {
/// Euclidean distance threshold for unit-normalized embeddings.
public var threshold: Double
/// VBx warm-start parameters (Fa controls precision, Fb controls recall)
public var warmStartFa: Double
public var warmStartFb: Double
// NOTE: minClusterSize is NOT used in community-1 (VBx-based pipeline).
// VBx is designed to handle 100+ under-clustered initial assignments from AHC
// and naturally merge them during Bayesian refinement. Pre-merging small clusters
// with minClusterSize enforcement (used in pyannote 3.1 AHC-only pipeline)
// is counterproductive for VBx and loses speaker distinctions.
public static let community = Clustering(
threshold: 0.6,
// Default 0.07
warmStartFa: 0.07,
warmStartFb: 0.8
)
public init(
threshold: Double,
warmStartFa: Double,
warmStartFb: Double
) {
self.threshold = threshold
self.warmStartFa = warmStartFa
self.warmStartFb = warmStartFb
}
}
public struct VBx: Sendable {
public var maxIterations: Int
public var convergenceTolerance: Double
// Default values from pyannote.community-1
public static let community = VBx(
maxIterations: 20,
convergenceTolerance: 1e-4
)
public init(
maxIterations: Int,
convergenceTolerance: Double
) {
self.maxIterations = maxIterations
self.convergenceTolerance = convergenceTolerance
}
}
public struct PostProcessing: Sendable {
public var minGapDurationSeconds: Double
public static let community = PostProcessing(minGapDurationSeconds: 0.1)
public init(minGapDurationSeconds: Double) {
self.minGapDurationSeconds = minGapDurationSeconds
}
}
public struct Export: Sendable {
public var embeddingsPath: String?
public init(embeddingsPath: String? = nil) {
self.embeddingsPath = embeddingsPath
}
public static let none = Export()
}
public var segmentation: Segmentation
public var embedding: Embedding
public var clustering: Clustering
public var vbx: VBx
public var postProcessing: PostProcessing
public var export: Export
public init(
segmentation: Segmentation = .community,
embedding: Embedding = .community,
clustering: Clustering = .community,
vbx: VBx = .community,
postProcessing: PostProcessing = .community,
export: Export = .none
) {
self.segmentation = segmentation
self.embedding = embedding
self.clustering = clustering
self.vbx = vbx
self.postProcessing = postProcessing
self.export = export
}
public init(
clusteringThreshold: Double = Clustering.community.threshold,
Fa: Double = Clustering.community.warmStartFa,
Fb: Double = Clustering.community.warmStartFb,
windowDuration: Double = Segmentation.community.windowDurationSeconds,
sampleRate: Int = Segmentation.community.sampleRate,
segmentationStepRatio: Double = Segmentation.community.stepRatio,
embeddingBatchSize: Int = Embedding.community.batchSize,
embeddingExcludeOverlap: Bool = Embedding.community.excludeOverlap,
minSegmentDuration: Double = Embedding.community.minSegmentDurationSeconds,
minGapDuration: Double = PostProcessing.community.minGapDurationSeconds,
speechOnsetThreshold: Float = Segmentation.community.speechOnsetThreshold,
speechOffsetThreshold: Float = Segmentation.community.speechOffsetThreshold,
segmentationMinDurationOn: Double = Segmentation.community.minDurationOn,
segmentationMinDurationOff: Double = Segmentation.community.minDurationOff,
maxVBxIterations: Int = VBx.community.maxIterations,
convergenceTolerance: Double = VBx.community.convergenceTolerance,
embeddingExportPath: String? = nil
) {
self.init(
segmentation: Segmentation(
windowDurationSeconds: windowDuration,
sampleRate: sampleRate,
minDurationOn: segmentationMinDurationOn,
minDurationOff: segmentationMinDurationOff,
stepRatio: segmentationStepRatio,
speechOnsetThreshold: speechOnsetThreshold,
speechOffsetThreshold: speechOffsetThreshold
),
embedding: Embedding(
batchSize: embeddingBatchSize,
excludeOverlap: embeddingExcludeOverlap,
minSegmentDurationSeconds: minSegmentDuration
),
clustering: Clustering(
threshold: clusteringThreshold,
warmStartFa: Fa,
warmStartFb: Fb
),
vbx: VBx(
maxIterations: maxVBxIterations,
convergenceTolerance: convergenceTolerance
),
postProcessing: PostProcessing(minGapDurationSeconds: minGapDuration),
export: Export(embeddingsPath: embeddingExportPath)
)
}
/// Number of samples processed per segmentation window.
public var samplesPerWindow: Int {
Int(Double(segmentation.sampleRate) * segmentation.windowDurationSeconds)
}
public var samplesPerStep: Int {
max(1, Int(Double(samplesPerWindow) * segmentation.stepRatio))
}
/// Validate configuration values and throw if they fall outside expected ranges.
public func validate() throws {
let maxClusteringThreshold = sqrt(2.0)
guard clustering.threshold > 0, clustering.threshold <= maxClusteringThreshold else {
throw OfflineDiarizationError.invalidConfiguration(
"clustering.threshold must be within (0, sqrt(2)], got \(clustering.threshold)"
)
}
guard clustering.warmStartFa > 0, clustering.warmStartFb > 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"clustering warm-start Fa/Fb must be positive (Fa=\(clustering.warmStartFa), Fb=\(clustering.warmStartFb))"
)
}
guard segmentation.windowDurationSeconds > 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"segmentation.windowDurationSeconds must be positive, got \(segmentation.windowDurationSeconds)"
)
}
guard segmentation.sampleRate > 0 else {
throw OfflineDiarizationError.invalidConfiguration("sampleRate must be positive")
}
guard segmentation.stepRatio > 0, segmentation.stepRatio <= 1 else {
throw OfflineDiarizationError.invalidConfiguration(
"segmentation.stepRatio must be within (0, 1], got \(segmentation.stepRatio)"
)
}
guard embedding.batchSize > 0 else {
throw OfflineDiarizationError.invalidBatchSize("embeddingBatchSize must be > 0")
}
guard embedding.batchSize <= 32 else {
throw OfflineDiarizationError.invalidBatchSize(
"embeddingBatchSize must be <= 32 to fit PLDA batch limits"
)
}
guard vbx.maxIterations > 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"maxVBxIterations must be > 0, got \(vbx.maxIterations)"
)
}
guard vbx.convergenceTolerance > 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"convergenceTolerance must be positive"
)
}
guard embedding.minSegmentDurationSeconds >= 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"embedding.minSegmentDuration must be >= 0"
)
}
guard postProcessing.minGapDurationSeconds >= 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"minGapDuration must be >= 0"
)
}
guard segmentation.minDurationOn >= 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"segmentation.minDurationOn must be >= 0"
)
}
guard segmentation.minDurationOff >= 0 else {
throw OfflineDiarizationError.invalidConfiguration(
"segmentation.minDurationOff must be >= 0"
)
}
guard segmentation.speechOnsetThreshold >= 0, segmentation.speechOnsetThreshold <= 1 else {
throw OfflineDiarizationError.invalidConfiguration(
"speechOnsetThreshold must be within [0, 1], got \(segmentation.speechOnsetThreshold)"
)
}
guard segmentation.speechOffsetThreshold >= 0,
segmentation.speechOffsetThreshold <= segmentation.speechOnsetThreshold
else {
throw OfflineDiarizationError.invalidConfiguration(
"speechOffsetThreshold must be within [0, speechOnsetThreshold], got \(segmentation.speechOffsetThreshold)"
)
}
}
public var clusteringThreshold: Double {
get { clustering.threshold }
set { clustering.threshold = newValue }
}
public var Fa: Double {
get { clustering.warmStartFa }
set { clustering.warmStartFa = newValue }
}
public var Fb: Double {
get { clustering.warmStartFb }
set { clustering.warmStartFb = newValue }
}
public var windowDuration: Double {
get { segmentation.windowDurationSeconds }
set { segmentation.windowDurationSeconds = newValue }
}
public var sampleRate: Int {
get { segmentation.sampleRate }
set { segmentation.sampleRate = newValue }
}
public var embeddingBatchSize: Int {
get { embedding.batchSize }
set { embedding.batchSize = newValue }
}
public var maxVBxIterations: Int {
get { vbx.maxIterations }
set { vbx.maxIterations = newValue }
}
public var convergenceTolerance: Double {
get { vbx.convergenceTolerance }
set { vbx.convergenceTolerance = newValue }
}
public var embeddingExcludeOverlap: Bool {
get { embedding.excludeOverlap }
set { embedding.excludeOverlap = newValue }
}
@available(*, deprecated, renamed: "embeddingExcludeOverlap")
public var shouldExcludeOverlaps: Bool {
get { embeddingExcludeOverlap }
set { embeddingExcludeOverlap = newValue }
}
public var minSegmentDuration: Double {
get { embedding.minSegmentDurationSeconds }
set { embedding.minSegmentDurationSeconds = newValue }
}
public var minGapDuration: Double {
get { postProcessing.minGapDurationSeconds }
set { postProcessing.minGapDurationSeconds = newValue }
}
public var embeddingExportPath: String? {
get { export.embeddingsPath }
set { export.embeddingsPath = newValue }
}
public var speechOnsetThreshold: Float {
get { segmentation.speechOnsetThreshold }
set { segmentation.speechOnsetThreshold = newValue }
}
public var speechOffsetThreshold: Float {
get { segmentation.speechOffsetThreshold }
set { segmentation.speechOffsetThreshold = newValue }
}
public var segmentationMinDurationOn: Double {
get { segmentation.minDurationOn }
set { segmentation.minDurationOn = newValue }
}
public var segmentationMinDurationOff: Double {
get { segmentation.minDurationOff }
set { segmentation.minDurationOff = newValue }
}
public var segmentationStepRatio: Double {
get { segmentation.stepRatio }
set { segmentation.stepRatio = newValue }
}
public static var `default`: OfflineDiarizerConfig {
OfflineDiarizerConfig()
}
}
/// Raw segmentation logits over the local powerset predictions for each chunk.
@available(macOS 13.0, iOS 16.0, *)
struct SegmentationLogits: Sendable {
let chunkIndex: Int
let startSample: Int
let endSample: Int
let logits: [[Float]] // frames × classes
}
/// Segmentation output aggregated across all processed windows.
@available(macOS 13.0, iOS 16.0, *)
public struct SegmentationOutput: Sendable {
public let logProbs: [[[Float]]]
/// Soft speaker activity weights per chunk/frame/speaker (0.0...1.0 values).
public let speakerWeights: [[[Float]]]
public let numChunks: Int
public let numFrames: Int
public let numSpeakers: Int
public let chunkOffsets: [Double]
public let frameDuration: Double
public init(
logProbs: [[[Float]]],
speakerWeights: [[[Float]]] = [],
numChunks: Int,
numFrames: Int,
numSpeakers: Int,
chunkOffsets: [Double] = [],
frameDuration: Double = 0
) {
self.logProbs = logProbs
self.speakerWeights = speakerWeights
self.numChunks = numChunks
self.numFrames = numFrames
self.numSpeakers = numSpeakers
self.chunkOffsets = chunkOffsets
self.frameDuration = frameDuration
}
}
/// Incremental segmentation chunk emitted while the model processes audio.
@available(macOS 13.0, iOS 16.0, *)
public struct SegmentationChunk: Sendable {
public let chunkIndex: Int
public let chunkOffsetSeconds: Double
public let frameDuration: Double
public let logProbs: [[Float]]
public let speakerWeights: [[Float]]
public init(
chunkIndex: Int,
chunkOffsetSeconds: Double,
frameDuration: Double,
logProbs: [[Float]],
speakerWeights: [[Float]]
) {
self.chunkIndex = chunkIndex
self.chunkOffsetSeconds = chunkOffsetSeconds
self.frameDuration = frameDuration
self.logProbs = logProbs
self.speakerWeights = speakerWeights
}
}
enum SegmentationChunkContinuation: Sendable {
case `continue`
case stop
}
typealias SegmentationChunkHandler = @Sendable (SegmentationChunk) -> SegmentationChunkContinuation
/// Result returned by the VBx refinement step.
@available(macOS 13.0, iOS 16.0, *)
public struct VBxOutput: Sendable {
public let gamma: [[Double]]
public let pi: [Double]
public let hardClusters: [[Int]]
public let centroids: [[Double]]
public let numClusters: Int
public let elbos: [Double]
public init(
gamma: [[Double]],
pi: [Double],
hardClusters: [[Int]],
centroids: [[Double]],
numClusters: Int,
elbos: [Double]
) {
self.gamma = gamma
self.pi = pi
self.hardClusters = hardClusters
self.centroids = centroids
self.numClusters = numClusters
self.elbos = elbos
}
}
/// Intermediate representation of an embedding associated with its timeline.
@available(macOS 13.0, iOS 16.0, *)
struct TimedEmbedding: Sendable {
let chunkIndex: Int
let speakerIndex: Int
let startFrame: Int
let endFrame: Int
let frameWeights: [Float]
let startTime: Double
let endTime: Double
let embedding256: [Float]
let rho128: [Double]
}
@@ -0,0 +1,887 @@
import Accelerate
import CoreML
import Foundation
import OSLog
private struct OfflineEmbeddingPending: Sendable {
let chunkIndex: Int
let speakerIndex: Int
let startFrame: Int
let endFrame: Int
let frameWeights: [Float]
let startTime: Double
let endTime: Double
let embedding256: [Float]
init(
chunkIndex: Int,
speakerIndex: Int,
startFrame: Int,
endFrame: Int,
frameWeights: [Float],
startTime: Double,
endTime: Double,
embedding256: [Float]
) {
self.chunkIndex = chunkIndex
self.speakerIndex = speakerIndex
self.startFrame = startFrame
self.endFrame = endFrame
self.frameWeights = frameWeights
self.startTime = startTime
self.endTime = endTime
self.embedding256 = embedding256
}
}
private struct OfflineChunkBatchInfo: Sendable {
let chunkIndex: Int
let chunkOffsetSeconds: Double
let frameDuration: Double
let speakerWeights: [[Float]]
init(
chunkIndex: Int,
chunkOffsetSeconds: Double,
frameDuration: Double,
speakerWeights: [[Float]]
) {
self.chunkIndex = chunkIndex
self.chunkOffsetSeconds = chunkOffsetSeconds
self.frameDuration = frameDuration
self.speakerWeights = speakerWeights
}
}
struct OfflineEmbeddingExtractor {
private let fbankModel: MLModel
private let embeddingModel: MLModel
private let pldaTransform: PLDATransform
private let config: OfflineDiarizerConfig
private let logger = AppLogger(category: "OfflineEmbedding")
private let memoryOptimizer = ANEMemoryOptimizer()
private let fbankInputName: String
private let fbankOutputName: String
private let fbankFeatureName: String
private let weightInputName: String
private let fbankInputShape: [NSNumber]
private let weightInputShape: [NSNumber]
private let audioSampleCount: Int
private let weightFrameCount: Int
private let modelBatchLimit: Int
private let embeddingOutputName: String
init(
fbankModel: MLModel,
embeddingModel: MLModel,
pldaTransform: PLDATransform,
config: OfflineDiarizerConfig
) {
self.fbankModel = fbankModel
self.embeddingModel = embeddingModel
self.pldaTransform = pldaTransform
self.config = config
// Resolve FBANK input metadata
let fbankDescription = fbankModel.modelDescription
guard
let audioInput = fbankDescription.inputDescriptionsByName["audio"],
let audioConstraint = audioInput.multiArrayConstraint
else {
logger.error("FBANK model is missing `audio` multiarray input; required for offline pipeline")
preconditionFailure("FBANK model must expose an `audio` MLMultiArray input")
}
self.fbankInputName = "audio"
let resolvedAudioSamples = OfflineEmbeddingExtractor.resolveElementCount(
from: audioConstraint,
fallback: config.samplesPerWindow
)
self.audioSampleCount = max(1, min(config.samplesPerWindow, resolvedAudioSamples))
let audioFallbackShape = OfflineEmbeddingExtractor.defaultShape(
dimensionHint: OfflineEmbeddingExtractor.dimensionHint(for: audioConstraint),
minimumCount: 3,
lastDimension: self.audioSampleCount
)
self.fbankInputShape = OfflineEmbeddingExtractor.sanitizedShape(
from: audioConstraint,
fallback: audioFallbackShape
)
// Resolve FBANK output metadata
guard
let fbankOutput = fbankDescription.outputDescriptionsByName["fbank_features"],
fbankOutput.type == .multiArray
else {
logger.error("FBANK model missing `fbank_features` multiarray output")
preconditionFailure("FBANK model must expose `fbank_features` multiarray output")
}
self.fbankOutputName = "fbank_features"
// Resolve embedding model inputs
let embeddingDescription = embeddingModel.modelDescription
let embeddingInputs = embeddingDescription.inputDescriptionsByName
guard
let fbankFeatureInput = embeddingInputs["fbank_features"],
fbankFeatureInput.type == .multiArray
else {
logger.error("Embedding model missing `fbank_features` multiarray input")
preconditionFailure("Embedding model must expose `fbank_features` multiarray input")
}
self.fbankFeatureName = "fbank_features"
guard
let weightInput = embeddingInputs["weights"],
let weightConstraint = weightInput.multiArrayConstraint
else {
logger.error("Embedding model missing `weights` multiarray input")
preconditionFailure("Embedding model must expose `weights` MLMultiArray input")
}
self.weightInputName = "weights"
let resolvedWeightFrames = OfflineEmbeddingExtractor.resolveElementCount(
from: weightConstraint,
fallback: 589
)
self.weightFrameCount = max(1, resolvedWeightFrames)
let weightFallbackShape = OfflineEmbeddingExtractor.defaultShape(
dimensionHint: OfflineEmbeddingExtractor.dimensionHint(for: weightConstraint),
minimumCount: 2,
lastDimension: self.weightFrameCount
)
self.weightInputShape = OfflineEmbeddingExtractor.sanitizedShape(
from: weightConstraint,
fallback: weightFallbackShape
)
self.modelBatchLimit = max(1, min(config.embeddingBatchSize, 32))
guard embeddingDescription.outputDescriptionsByName["embedding"] != nil else {
logger.error("Embedding model missing `embedding` multiarray output")
preconditionFailure("Embedding model must expose `embedding` multiarray output")
}
self.embeddingOutputName = "embedding"
let audioShapeString = fbankInputShape.map { "\($0.intValue)" }.joined(separator: "×")
let weightShapeString = weightInputShape.map { "\($0.intValue)" }.joined(separator: "×")
logger.debug(
"Offline embedding configured with FBANK input \(fbankInputName)[\(audioShapeString)] → \(fbankOutputName); embedding consumes \(fbankFeatureName) + \(weightInputName)[\(weightShapeString)] (frames=\(weightFrameCount)), maxBatch=\(modelBatchLimit), output=\(embeddingOutputName)"
)
}
func extractEmbeddings(
audio: [Float],
segmentation: SegmentationOutput
) async throws -> [TimedEmbedding] {
try await extractEmbeddings(
audioSource: ArrayAudioSampleSource(samples: audio),
segmentation: segmentation
)
}
func extractEmbeddings(
audioSource: StreamingAudioSampleSource,
segmentation: SegmentationOutput
) async throws -> [TimedEmbedding] {
let stream = AsyncThrowingStream<SegmentationChunk, Error> { continuation in
for chunkIndex in 0..<segmentation.numChunks {
guard segmentation.speakerWeights.indices.contains(chunkIndex) else { continue }
let chunkSpeakerWeights = segmentation.speakerWeights[chunkIndex]
guard !chunkSpeakerWeights.isEmpty else { continue }
let chunkOffsetSeconds: Double
if segmentation.chunkOffsets.indices.contains(chunkIndex) {
chunkOffsetSeconds = segmentation.chunkOffsets[chunkIndex]
} else {
chunkOffsetSeconds = Double(chunkIndex) * config.windowDuration
}
let chunkLogProbs: [[Float]]
if segmentation.logProbs.indices.contains(chunkIndex) {
chunkLogProbs = segmentation.logProbs[chunkIndex]
} else {
chunkLogProbs = []
}
let chunk = SegmentationChunk(
chunkIndex: chunkIndex,
chunkOffsetSeconds: chunkOffsetSeconds,
frameDuration: segmentation.frameDuration,
logProbs: chunkLogProbs,
speakerWeights: chunkSpeakerWeights
)
continuation.yield(chunk)
}
continuation.finish()
}
return try await extractEmbeddings(
audioSource: audioSource,
segmentationStream: stream
)
}
func extractEmbeddings<S: AsyncSequence>(
audioSource: StreamingAudioSampleSource,
segmentationStream: S
) async throws -> [TimedEmbedding] where S.Element == SegmentationChunk {
var embeddings: [TimedEmbedding] = []
embeddings.reserveCapacity(config.embeddingBatchSize * 8)
let overlapThreshold: Float = 1e-3
let maxPLDABatch = max(1, min(config.embeddingBatchSize, modelBatchLimit))
let fbankBatchLimit = min(modelBatchLimit, 32)
var pendingEmbeddings: [[Float]] = []
var pendingMetadata: [OfflineEmbeddingPending] = []
pendingEmbeddings.reserveCapacity(maxPLDABatch)
pendingMetadata.reserveCapacity(maxPLDABatch)
var processedMasks = 0
var fallbackMaskCount = 0
var emptyMaskCount = 0
var accumulatedMaskFrames: Double = 0
let chunkSize = config.samplesPerWindow
var chunkBuffer = [Float](repeating: 0, count: chunkSize)
let totalSamples = audioSource.sampleCount
var batchAudioInputs: [MLMultiArray] = []
var batchInfos: [OfflineChunkBatchInfo] = []
batchAudioInputs.reserveCapacity(fbankBatchLimit)
batchInfos.reserveCapacity(fbankBatchLimit)
let clock = ContinuousClock()
var fbankDuration: Duration = .zero
var maskPreparationDuration: Duration = .zero
var resampleDuration: Duration = .zero
var embeddingDuration: Duration = .zero
var pldaDuration: Duration = .zero
var evaluatedMaskCount = 0
var fbankWindowCount = 0
var fbankBatchCallCount = 0
var pldaOutputCount = 0
var pldaBatchCallCount = 0
func performEmbeddingWarmup() throws {
let warmupAudioArray = try memoryOptimizer.getPooledBuffer(
key: "offline_embedding_warmup_audio",
shape: fbankInputShape,
dataType: .float32
)
let warmupAudioPointer = warmupAudioArray.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(warmupAudioPointer, 1, vDSP_Length(warmupAudioArray.count))
let warmupFbankFeatures = try runFbankModel(audioArray: warmupAudioArray)
let zeroWeights = [Float](repeating: 0, count: weightFrameCount)
let warmupWeightsArray = try prepareWeightsInput(weights: zeroWeights)
_ = try runEmbeddingModel(
fbankFeatures: warmupFbankFeatures,
weightsArray: warmupWeightsArray
)
}
func resolveFrameDuration(_ chunk: SegmentationChunk) -> Double {
if chunk.frameDuration > 0 {
return chunk.frameDuration
}
let frameCount = max(1, chunk.speakerWeights.count)
guard frameCount > 0 else {
return max(1e-3, config.windowDuration)
}
return config.windowDuration / Double(frameCount)
}
func requiredMinFrames(for frameDuration: Double) -> Int {
guard frameDuration > 0 else {
return 1
}
let count = Int(ceil(config.minSegmentDuration / frameDuration))
return max(1, count)
}
func flushPending() async throws {
guard !pendingEmbeddings.isEmpty else { return }
let pldaStart = clock.now
let rhoBatch = try await pldaTransform.transform(pendingEmbeddings)
pldaDuration += pldaStart.duration(to: clock.now)
pldaOutputCount += rhoBatch.count
pldaBatchCallCount += 1
guard rhoBatch.count == pendingMetadata.count else {
throw OfflineDiarizationError.processingFailed(
"PldaRho batch size mismatch (expected \(pendingMetadata.count), got \(rhoBatch.count))"
)
}
for (info, rho) in zip(pendingMetadata, rhoBatch) {
let timedEmbedding = TimedEmbedding(
chunkIndex: info.chunkIndex,
speakerIndex: info.speakerIndex,
startFrame: info.startFrame,
endFrame: info.endFrame,
frameWeights: info.frameWeights,
startTime: info.startTime,
endTime: info.endTime,
embedding256: info.embedding256,
rho128: rho
)
embeddings.append(timedEmbedding)
}
pendingEmbeddings.removeAll(keepingCapacity: true)
pendingMetadata.removeAll(keepingCapacity: true)
}
func processChunk(info: OfflineChunkBatchInfo, fbankFeatures: MLMultiArray) async throws {
let chunkSpeakerWeights = info.speakerWeights
guard !chunkSpeakerWeights.isEmpty else { return }
let frameCount = chunkSpeakerWeights.count
guard let speakerCount = chunkSpeakerWeights.first?.count, speakerCount > 0 else { return }
let minFramesForEmbedding = requiredMinFrames(for: info.frameDuration)
var baseMask = [Float](repeating: 0, count: frameCount)
var cleanMask = [Float](repeating: 0, count: frameCount)
let overlapFrames: [Bool]
if config.embeddingExcludeOverlap {
var frames = [Bool](repeating: false, count: frameCount)
for (frame, weights) in chunkSpeakerWeights.enumerated() {
var active = 0
for value in weights where value > overlapThreshold {
active += 1
if active > 1 {
frames[frame] = true
break
}
}
}
overlapFrames = frames
} else {
overlapFrames = []
}
let totalWeightCount = frameCount * speakerCount
guard totalWeightCount > 0 else { return }
let rowMajorShape: [NSNumber] = [
NSNumber(value: frameCount),
NSNumber(value: speakerCount),
]
let rowMajorKey = "offline_embedding_row_major_\(frameCount)_\(speakerCount)"
let rowMajorBuffer = try memoryOptimizer.getPooledBuffer(
key: rowMajorKey,
shape: rowMajorShape,
dataType: .float32
)
let rowMajorPointer = rowMajorBuffer.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(rowMajorPointer, 1, vDSP_Length(totalWeightCount))
chunkSpeakerWeights.enumerated().forEach { frameIndex, weights in
let destination = rowMajorPointer.advanced(by: frameIndex * speakerCount)
weights.withUnsafeBufferPointer { rowPointer in
guard let rowPtrBase = rowPointer.baseAddress else { return }
destination.update(from: rowPtrBase, count: speakerCount)
}
}
let transposedShape: [NSNumber] = [
NSNumber(value: speakerCount),
NSNumber(value: frameCount),
]
let transposedKey = "offline_embedding_transposed_\(speakerCount)_\(frameCount)"
let transposedBuffer = try memoryOptimizer.getPooledBuffer(
key: transposedKey,
shape: transposedShape,
dataType: .float32
)
let transposedPointer = transposedBuffer.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_mtrans(
rowMajorPointer,
1,
transposedPointer,
1,
vDSP_Length(frameCount),
vDSP_Length(speakerCount)
)
for speakerIndex in 0..<speakerCount {
evaluatedMaskCount += 1
let maskStart = clock.now
let columnOffset = speakerIndex * frameCount
baseMask.withUnsafeMutableBufferPointer { pointer in
guard let destBase = pointer.baseAddress else { return }
let columnPointer = transposedPointer.advanced(by: columnOffset)
destBase.update(from: columnPointer, count: frameCount)
}
let baseSum = VDSPOperations.sum(baseMask)
if baseSum <= 0 {
maskPreparationDuration += maskStart.duration(to: clock.now)
emptyMaskCount += 1
continue
}
cleanMask = baseMask
if config.embeddingExcludeOverlap {
for frame in 0..<frameCount where overlapFrames[frame] {
cleanMask[frame] = 0
}
}
let cleanSum = VDSPOperations.sum(cleanMask)
let maskToUse: [Float]
let maskSum: Float
if cleanSum >= Float(minFramesForEmbedding) {
maskToUse = cleanMask
maskSum = cleanSum
} else {
maskToUse = baseMask
maskSum = baseSum
fallbackMaskCount += 1
}
let maskPrepEnd = clock.now
maskPreparationDuration += maskStart.duration(to: maskPrepEnd)
if maskSum <= 0 {
emptyMaskCount += 1
continue
}
let resampleStart = maskPrepEnd
let resampledMask = WeightInterpolation.resample(maskToUse, to: weightFrameCount)
let maskEnergy = VDSPOperations.sum(resampledMask)
let resampleEnd = clock.now
resampleDuration += resampleStart.duration(to: resampleEnd)
if maskEnergy <= 0 {
emptyMaskCount += 1
continue
}
let embeddingStart = resampleEnd
let weightsArray = try prepareWeightsInput(weights: resampledMask)
let embedding256 = try runEmbeddingModel(
fbankFeatures: fbankFeatures,
weightsArray: weightsArray
)
let embeddingEnd = clock.now
embeddingDuration += embeddingStart.duration(to: embeddingEnd)
let firstActive = maskToUse.firstIndex(where: { $0 > overlapThreshold }) ?? 0
let lastActive = maskToUse.lastIndex(where: { $0 > overlapThreshold }) ?? firstActive
let startTime = info.chunkOffsetSeconds + Double(firstActive) * info.frameDuration
let endTime = info.chunkOffsetSeconds + Double(lastActive + 1) * info.frameDuration
processedMasks += 1
accumulatedMaskFrames += Double(maskSum)
pendingEmbeddings.append(embedding256)
pendingMetadata.append(
OfflineEmbeddingPending(
chunkIndex: info.chunkIndex,
speakerIndex: speakerIndex,
startFrame: firstActive,
endFrame: lastActive,
frameWeights: maskToUse,
startTime: startTime,
endTime: endTime,
embedding256: embedding256
)
)
if pendingEmbeddings.count == maxPLDABatch {
try await flushPending()
}
}
}
func flushFbankBatch() async throws {
guard !batchAudioInputs.isEmpty else { return }
let fbankStart = clock.now
let fbankOutputs = try runFbankBatch(audioArrays: batchAudioInputs)
fbankDuration += fbankStart.duration(to: clock.now)
fbankBatchCallCount += 1
guard fbankOutputs.count == batchInfos.count else {
throw OfflineDiarizationError.processingFailed(
"FBANK batch produced mismatched output count (\(fbankOutputs.count) vs \(batchInfos.count))"
)
}
for index in 0..<batchInfos.count {
try await processChunk(info: batchInfos[index], fbankFeatures: fbankOutputs[index])
}
batchAudioInputs.removeAll(keepingCapacity: true)
batchInfos.removeAll(keepingCapacity: true)
}
do {
try performEmbeddingWarmup()
} catch {
logger.debug("Embedding warmup skipped due to error: \(error.localizedDescription)")
}
for try await chunk in segmentationStream {
try Task.checkCancellation()
let chunkSpeakerWeights = chunk.speakerWeights
guard !chunkSpeakerWeights.isEmpty else { continue }
let frameDuration = resolveFrameDuration(chunk)
let chunkOffsetSeconds =
chunk.chunkOffsetSeconds.isFinite
? chunk.chunkOffsetSeconds
: Double(chunk.chunkIndex) * config.windowDuration
let estimatedStartSample = Int((chunkOffsetSeconds * Double(config.sampleRate)).rounded())
let clampedStartSample = max(0, min(estimatedStartSample, totalSamples))
let endSample = min(clampedStartSample + chunkSize, totalSamples)
guard clampedStartSample < endSample else {
continue
}
chunkBuffer.withUnsafeMutableBufferPointer { pointer in
vDSP_vclr(pointer.baseAddress!, 1, vDSP_Length(pointer.count))
}
try chunkBuffer.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
try audioSource.copySamples(
into: baseAddress,
offset: clampedStartSample,
count: chunkSize
)
}
let chunkLength = endSample - clampedStartSample
let fbankInput = try chunkBuffer.withUnsafeBufferPointer { pointer -> MLMultiArray in
guard let baseAddress = pointer.baseAddress else {
throw OfflineDiarizationError.processingFailed("Failed to access chunk buffer")
}
return try prepareFbankInput(
chunkPointer: baseAddress,
length: chunkLength
)
}
fbankWindowCount += 1
batchAudioInputs.append(fbankInput)
batchInfos.append(
OfflineChunkBatchInfo(
chunkIndex: chunk.chunkIndex,
chunkOffsetSeconds: chunkOffsetSeconds,
frameDuration: frameDuration,
speakerWeights: chunkSpeakerWeights
)
)
if batchAudioInputs.count == fbankBatchLimit {
try await flushFbankBatch()
}
}
try await flushFbankBatch()
try await flushPending()
if processedMasks > 0 {
let meanMaskFrames = accumulatedMaskFrames / Double(processedMasks)
let meanString = String(format: "%.2f", meanMaskFrames)
logger.debug(
"Embedding masks generated: \(embeddings.count) (meanActiveFrames=\(meanString), fallbackMasks=\(fallbackMaskCount), emptyMasks=\(emptyMaskCount))"
)
} else {
logger.debug("Embedding extractor produced no valid speaker masks")
}
if fbankWindowCount > 0 || evaluatedMaskCount > 0 || pldaOutputCount > 0 {
let fbankMs = Self.milliseconds(from: fbankDuration)
let maskPrepMs = Self.milliseconds(from: maskPreparationDuration)
let resampleMs = Self.milliseconds(from: resampleDuration)
let embeddingMs = Self.milliseconds(from: embeddingDuration)
let pldaMs = Self.milliseconds(from: pldaDuration)
let fbankPerWindow = fbankWindowCount > 0 ? fbankMs / Double(fbankWindowCount) : 0
let maskPrepPerEval = evaluatedMaskCount > 0 ? maskPrepMs / Double(evaluatedMaskCount) : 0
let resamplePerValid = processedMasks > 0 ? resampleMs / Double(processedMasks) : 0
let embeddingPerValid = processedMasks > 0 ? embeddingMs / Double(processedMasks) : 0
let pldaPerValid = processedMasks > 0 ? pldaMs / Double(processedMasks) : 0
let message =
"""
Embedding timings: fbankTotal=\(String(format: "%.2f", fbankMs))ms (perWindow=\(String(format: "%.3f", fbankPerWindow))ms), \
maskPrepTotal=\(String(format: "%.2f", maskPrepMs))ms (perEval=\(String(format: "%.3f", maskPrepPerEval))ms), \
resampleTotal=\(String(format: "%.2f", resampleMs))ms (perValid=\(String(format: "%.3f", resamplePerValid))ms), \
embeddingTotal=\(String(format: "%.2f", embeddingMs))ms (perValid=\(String(format: "%.3f", embeddingPerValid))ms), \
pldaTotal=\(String(format: "%.2f", pldaMs))ms (perValid=\(String(format: "%.3f", pldaPerValid))ms), \
batches(fbank=\(fbankBatchCallCount), plda=\(pldaBatchCallCount))
"""
logger.debug(message)
Self.emitProfileLog(message)
}
return embeddings
}
private func runFbankModel(
audioArray: MLMultiArray
) throws -> MLMultiArray {
guard let result = try runFbankBatch(audioArrays: [audioArray]).first else {
throw OfflineDiarizationError.processingFailed("FBANK model produced no output")
}
return result
}
private func runFbankBatch(
audioArrays: [MLMultiArray]
) throws -> [MLMultiArray] {
guard !audioArrays.isEmpty else { return [] }
var providers: [MLFeatureProvider] = []
providers.reserveCapacity(audioArrays.count)
for array in audioArrays {
providers.append(
ZeroCopyDiarizerFeatureProvider(
features: [
fbankInputName: MLFeatureValue(multiArray: array)
]
)
)
}
let options = MLPredictionOptions()
if #available(macOS 14.0, iOS 17.0, *) {
for array in audioArrays {
array.prefetchToNeuralEngine()
}
}
let batchProvider = MLArrayBatchProvider(array: providers)
let outputBatch = try fbankModel.predictions(from: batchProvider, options: options)
guard outputBatch.count == audioArrays.count else {
throw OfflineDiarizationError.processingFailed(
"FBANK batch produced \(outputBatch.count) outputs for \(audioArrays.count) inputs"
)
}
var results: [MLMultiArray] = []
results.reserveCapacity(outputBatch.count)
for index in 0..<outputBatch.count {
guard
let featureArray = outputBatch.features(at: index)
.featureValue(for: fbankOutputName)?.multiArrayValue
else {
throw OfflineDiarizationError.processingFailed(
"FBANK model missing \(fbankOutputName) output at batch index \(index)"
)
}
results.append(featureArray)
}
return results
}
private func prepareFbankInput(
chunkPointer: UnsafePointer<Float>,
length: Int
) throws -> MLMultiArray {
let array = try memoryOptimizer.createAlignedArray(
shape: fbankInputShape,
dataType: .float32
)
let pointer = array.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(pointer, 1, vDSP_Length(array.count))
let copyCount = min(length, audioSampleCount, array.count)
if copyCount > 0 {
vDSP_mmov(
chunkPointer,
pointer,
vDSP_Length(copyCount),
1,
vDSP_Length(copyCount),
1
)
}
return array
}
private func prepareWeightsInput(
weights: [Float]
) throws -> MLMultiArray {
let array = try memoryOptimizer.getPooledBuffer(
key: "offline_embedding_weights_\(weightFrameCount)",
shape: weightInputShape,
dataType: .float32
)
let pointer = array.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(pointer, 1, vDSP_Length(array.count))
let copyCount = min(weights.count, weightFrameCount, array.count)
if copyCount > 0 {
weights.withUnsafeBufferPointer { buffer in
vDSP_mmov(
buffer.baseAddress!,
pointer,
vDSP_Length(copyCount),
1,
vDSP_Length(copyCount),
1
)
}
}
return array
}
private func runEmbeddingModel(
fbankFeatures: MLMultiArray,
weightsArray: MLMultiArray
) throws -> [Float] {
let provider = ZeroCopyDiarizerFeatureProvider(
features: [
fbankFeatureName: MLFeatureValue(multiArray: fbankFeatures),
weightInputName: MLFeatureValue(multiArray: weightsArray),
]
)
let options = MLPredictionOptions()
if #available(macOS 14.0, iOS 17.0, *) {
fbankFeatures.prefetchToNeuralEngine()
weightsArray.prefetchToNeuralEngine()
}
let output = try embeddingModel.prediction(from: provider, options: options)
guard let embeddingArray = output.featureValue(for: embeddingOutputName)?.multiArrayValue else {
throw OfflineDiarizationError.processingFailed("Embedding model missing \(embeddingOutputName) output")
}
let pointer = embeddingArray.dataPointer.assumingMemoryBound(to: Float.self)
return Array(UnsafeBufferPointer(start: pointer, count: embeddingArray.count))
}
private static func milliseconds(from duration: Duration) -> Double {
let components = duration.components
let secondsMs = Double(components.seconds) * 1_000
let attosecondsMs = Double(components.attoseconds) / 1_000_000_000_000_000.0
return secondsMs + attosecondsMs
}
private static func emitProfileLog(_ message: String) {
let line = "[Profiling] \(message)\n"
if let data = line.data(using: .utf8) {
FileHandle.standardError.write(data)
}
}
private static func resolveElementCount(
from constraint: MLMultiArrayConstraint?,
fallback: Int
) -> Int {
guard let shape = constraint?.shape, !shape.isEmpty else {
return fallback
}
if let last = shape.last {
let value = last.intValue
if value > 0 {
return value
}
}
if let secondLast = shape.dropLast().last {
let value = secondLast.intValue
if value > 0 {
return value
}
}
return fallback
}
private static func sanitizedShape(
from constraint: MLMultiArrayConstraint,
fallback: [Int]
) -> [NSNumber] {
if let enumerated = constraint.shapeConstraint.enumeratedShapes.first, !enumerated.isEmpty {
return sanitizedShape(enumerated, fallback: fallback)
}
let explicitShape = constraint.shape
if !explicitShape.isEmpty {
return sanitizedShape(explicitShape, fallback: fallback)
}
let ranges = constraint.shapeConstraint.sizeRangeForDimension
if !ranges.isEmpty {
var sanitized: [NSNumber] = []
sanitized.reserveCapacity(ranges.count)
for (index, rangeValue) in ranges.enumerated() {
let range = rangeValue.rangeValue
let fallbackValue = fallbackValue(fallback, index: index)
let candidate = range.location > 0 ? range.location : fallbackValue
sanitized.append(NSNumber(value: max(1, candidate)))
}
return sanitized
}
return fallback.map { NSNumber(value: max(1, $0)) }
}
private static func sanitizedShape(
_ shape: [NSNumber],
fallback: [Int]
) -> [NSNumber] {
guard !shape.isEmpty else {
return fallback.map { NSNumber(value: max(1, $0)) }
}
var sanitized: [NSNumber] = []
sanitized.reserveCapacity(shape.count)
for (index, dimension) in shape.enumerated() {
let value = dimension.intValue
if value > 0 {
sanitized.append(dimension)
} else {
let fallbackValue = fallbackValue(fallback, index: index)
sanitized.append(NSNumber(value: max(1, fallbackValue)))
}
}
return sanitized
}
private static func fallbackValue(
_ fallback: [Int],
index: Int
) -> Int {
guard !fallback.isEmpty else {
return 1
}
if index < fallback.count {
return fallback[index]
}
return fallback.last ?? 1
}
private static func defaultShape(
dimensionHint: Int,
minimumCount: Int,
lastDimension: Int
) -> [Int] {
let count = max(max(dimensionHint, minimumCount), 1)
var shape = Array(repeating: 1, count: count)
shape[count - 1] = max(1, lastDimension)
return shape
}
private static func dimensionHint(for constraint: MLMultiArrayConstraint) -> Int {
if let enumerated = constraint.shapeConstraint.enumeratedShapes.first, !enumerated.isEmpty {
return enumerated.count
}
let explicitCount = constraint.shape.count
if explicitCount > 0 {
return explicitCount
}
let rangeCount = constraint.shapeConstraint.sizeRangeForDimension.count
if rangeCount > 0 {
return rangeCount
}
return 0
}
}
@@ -0,0 +1,432 @@
import Accelerate
import Foundation
struct OfflineReconstruction {
private let config: OfflineDiarizerConfig
private let logger = AppLogger(category: "OfflineReconstruction")
private struct Accumulator {
var start: Double
var end: Double
var scoreSum: Double
var frameCount: Int
}
init(config: OfflineDiarizerConfig) {
self.config = config
}
func buildSegments(
segmentation: SegmentationOutput,
hardClusters: [[Int]],
centroids: [[Double]]
) -> [TimedSpeakerSegment] {
guard segmentation.numChunks > 0, segmentation.numFrames > 0 else { return [] }
let frameDuration = segmentation.frameDuration
guard frameDuration > 0 else { return [] }
let clusterCount = max(centroids.count, 1)
let gapThreshold = max(config.minGapDuration, config.segmentationMinDurationOff)
var maxTime = 0.0
for chunkIndex in 0..<segmentation.numChunks {
let offset = chunkStartTime(for: chunkIndex, segmentation: segmentation)
let end = offset + Double(segmentation.numFrames) * frameDuration
if end > maxTime {
maxTime = end
}
}
let totalFrames = max(1, Int(ceil(maxTime / frameDuration)))
var activationSums = Array(
repeating: Array(repeating: 0.0, count: clusterCount),
count: totalFrames
)
var activationCounts = Array(
repeating: Array(repeating: 0.0, count: clusterCount),
count: totalFrames
)
var expectedCountSums = [Double](repeating: 0, count: totalFrames)
var expectedCountWeights = [Double](repeating: 0, count: totalFrames)
for chunkIndex in 0..<segmentation.numChunks {
guard chunkIndex < segmentation.speakerWeights.count else { continue }
let chunkWeights = segmentation.speakerWeights[chunkIndex]
guard !chunkWeights.isEmpty else { continue }
let chunkOffset = chunkStartTime(for: chunkIndex, segmentation: segmentation)
let chunkAssignments =
chunkIndex < hardClusters.count
? hardClusters[chunkIndex] : Array(repeating: -2, count: segmentation.numSpeakers)
for frameIndex in 0..<chunkWeights.count {
let frameStart = chunkOffset + Double(frameIndex) * frameDuration
var globalFrame = Int((frameStart / frameDuration).rounded())
if globalFrame < 0 {
globalFrame = 0
} else if globalFrame >= totalFrames {
globalFrame = totalFrames - 1
}
let weights = chunkWeights[frameIndex]
var frameActivations = [Double](repeating: 0, count: clusterCount)
for speakerIndex in 0..<min(weights.count, chunkAssignments.count) {
let cluster = chunkAssignments[speakerIndex]
guard cluster >= 0, cluster < clusterCount else { continue }
let value = Double(weights[speakerIndex])
if value > frameActivations[cluster] {
frameActivations[cluster] = value
}
}
let expectedCount = weights.reduce(0.0) { partialSum, value in
partialSum + Double(value)
}
expectedCountSums[globalFrame] += expectedCount
expectedCountWeights[globalFrame] += 1
for cluster in 0..<clusterCount {
let value = frameActivations[cluster]
if value > 0 {
activationSums[globalFrame][cluster] += value
activationCounts[globalFrame][cluster] += 1
}
}
}
}
var activationAverages = Array(
repeating: Array(repeating: 0.0, count: clusterCount),
count: totalFrames
)
for frame in 0..<totalFrames {
let sums = activationSums[frame]
let counts = activationCounts[frame]
var averages = [Double](repeating: 0, count: clusterCount)
// Vectorized division: averages = sums / counts (where counts > 0)
sums.withUnsafeBufferPointer { sumsPtr in
counts.withUnsafeBufferPointer { countsPtr in
averages.withUnsafeMutableBufferPointer { averagesPtr in
guard let sumsBase = sumsPtr.baseAddress,
let countsBase = countsPtr.baseAddress,
let averagesBase = averagesPtr.baseAddress
else { return }
vDSP_vdivD(
countsBase,
1,
sumsBase,
1,
averagesBase,
1,
vDSP_Length(clusterCount)
)
}
}
}
// Zero out results where count was 0 (to avoid division by zero artifacts)
for cluster in 0..<clusterCount where counts[cluster] == 0 {
averages[cluster] = 0
}
activationAverages[frame] = averages
}
var speakerCountPerFrame = [Int](repeating: 0, count: totalFrames)
var speakerCountHistogram: [Int: Int] = [:]
let maxAllowedSpeakers = min(clusterCount, segmentation.numSpeakers)
for frame in 0..<totalFrames {
let weight = expectedCountWeights[frame]
guard weight > 0 else { continue }
var rounded = Int((expectedCountSums[frame] / weight).rounded(.toNearestOrEven))
if rounded < 0 { rounded = 0 }
if rounded > maxAllowedSpeakers { rounded = maxAllowedSpeakers }
speakerCountPerFrame[frame] = rounded
speakerCountHistogram[rounded, default: 0] += 1
}
if !speakerCountHistogram.isEmpty {
let histogramString =
speakerCountHistogram
.sorted { $0.key < $1.key }
.map { "\($0.key):\($0.value)" }
.joined(separator: ", ")
logger.debug("Speaker-count histogram \(histogramString)")
}
var perFrameClusters = [[Int]](repeating: [], count: totalFrames)
for frame in 0..<totalFrames {
let required = speakerCountPerFrame[frame]
guard required > 0 else { continue }
let ranked = activationSums[frame].enumerated().sorted { $0.element > $1.element }
let selected = ranked.prefix(required).map { $0.offset }
perFrameClusters[frame] = selected
}
var activeSegments: [Int: Accumulator] = [:]
var rawSegments: [TimedSpeakerSegment] = []
for frameIndex in 0..<totalFrames {
let frameStart = Double(frameIndex) * frameDuration
let frameEnd = frameStart + frameDuration
let activeClusters = Set(perFrameClusters[frameIndex])
let averageScores = activationAverages[frameIndex]
for (cluster, accumulator) in activeSegments where !activeClusters.contains(cluster) {
appendSegment(
cluster: cluster,
accumulator: accumulator,
endTime: frameStart,
centroids: centroids,
output: &rawSegments
)
}
activeSegments = activeSegments.filter { activeClusters.contains($0.key) }
for cluster in activeClusters {
let score = averageScores.indices.contains(cluster) ? averageScores[cluster] : 0
if var existing = activeSegments[cluster] {
existing.end = frameEnd
existing.scoreSum += score
existing.frameCount += 1
activeSegments[cluster] = existing
} else {
activeSegments[cluster] = Accumulator(
start: frameStart,
end: frameEnd,
scoreSum: score,
frameCount: 1
)
}
}
}
for (cluster, accumulator) in activeSegments {
appendSegment(
cluster: cluster,
accumulator: accumulator,
endTime: accumulator.end,
centroids: centroids,
output: &rawSegments
)
}
let merged = mergeSegments(rawSegments, gapThreshold: gapThreshold)
return sanitize(segments: merged)
}
func buildSpeakerDatabase(
segments: [TimedSpeakerSegment]
) -> [String: [Float]] {
var sums: [String: [Float]] = [:]
var counts: [String: Int] = [:]
for segment in segments {
if var current = sums[segment.speakerId] {
let embedding = segment.embedding
precondition(
embedding.count == current.count,
"Embedding dimensionality mismatch while accumulating speaker database"
)
let countIndex = makeBlasIndexOrFatal(embedding.count, label: "speaker embedding length")
let unitStride = BlasIndex(1)
embedding.withUnsafeBufferPointer { sourcePointer in
current.withUnsafeMutableBufferPointer { destinationPointer in
guard
let sourceBase = sourcePointer.baseAddress,
let destinationBase = destinationPointer.baseAddress
else { return }
cblas_saxpy(
countIndex,
1.0,
sourceBase,
unitStride,
destinationBase,
unitStride
)
}
}
sums[segment.speakerId] = current
} else {
sums[segment.speakerId] = segment.embedding
}
counts[segment.speakerId, default: 0] += 1
}
var database: [String: [Float]] = [:]
for (speaker, sum) in sums {
guard let count = counts[speaker], count > 0 else { continue }
var averaged = sum
var scale = 1 / Float(count)
let length = averaged.count
averaged.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
vDSP_vsmul(
baseAddress,
1,
&scale,
baseAddress,
1,
vDSP_Length(length)
)
}
database[speaker] = averaged
}
return database
}
private func excludeOverlaps(in segments: [TimedSpeakerSegment]) -> [TimedSpeakerSegment] {
guard !segments.isEmpty else { return [] }
var sanitized: [TimedSpeakerSegment] = []
for segment in segments {
var adjustedStart = segment.startTimeSeconds
let adjustedEnd = segment.endTimeSeconds
if let previous = sanitized.last {
if adjustedStart < previous.endTimeSeconds {
adjustedStart = previous.endTimeSeconds
}
}
if adjustedStart >= adjustedEnd {
continue
}
let duration = adjustedEnd - adjustedStart
if duration < Float(config.minSegmentDuration) {
continue
}
let originalDuration = segment.endTimeSeconds - segment.startTimeSeconds
let qualityScale = originalDuration > 0 ? duration / originalDuration : 1
let adjustedQuality = max(0, min(1, segment.qualityScore * qualityScale))
let trimmed = TimedSpeakerSegment(
speakerId: segment.speakerId,
embedding: segment.embedding,
startTimeSeconds: adjustedStart,
endTimeSeconds: adjustedEnd,
qualityScore: adjustedQuality
)
sanitized.append(trimmed)
}
return sanitized
}
private func appendSegment(
cluster: Int,
accumulator: Accumulator,
endTime: Double,
centroids: [[Double]],
output: inout [TimedSpeakerSegment]
) {
guard endTime > accumulator.start else { return }
let averageScore: Double
if accumulator.frameCount > 0 {
averageScore = accumulator.scoreSum / Double(accumulator.frameCount)
} else {
averageScore = accumulator.scoreSum
}
let quality = Float(min(max(averageScore, 0), 1))
let centroidDouble =
centroids.indices.contains(cluster)
? centroids[cluster]
: Array(repeating: 0, count: centroids.first?.count ?? 0)
let centroid = centroidDouble.map { Float($0) }
let segment = TimedSpeakerSegment(
speakerId: "S\(cluster + 1)",
embedding: centroid,
startTimeSeconds: Float(accumulator.start),
endTimeSeconds: Float(endTime),
qualityScore: quality
)
output.append(segment)
}
private func mergeSegments(
_ segments: [TimedSpeakerSegment],
gapThreshold: Double
) -> [TimedSpeakerSegment] {
guard !segments.isEmpty else { return [] }
let sorted = segments.sorted { $0.startTimeSeconds < $1.startTimeSeconds }
var merged: [TimedSpeakerSegment] = []
var current = sorted[0]
for segment in sorted.dropFirst() {
if segment.speakerId == current.speakerId {
let gap = Double(segment.startTimeSeconds) - Double(current.endTimeSeconds)
if gap <= gapThreshold {
let blended = blendedQuality(current, segment)
current = TimedSpeakerSegment(
speakerId: current.speakerId,
embedding: current.embedding,
startTimeSeconds: current.startTimeSeconds,
endTimeSeconds: max(current.endTimeSeconds, segment.endTimeSeconds),
qualityScore: blended
)
continue
}
}
merged.append(current)
current = segment
}
merged.append(current)
return merged
}
private func blendedQuality(_ lhs: TimedSpeakerSegment, _ rhs: TimedSpeakerSegment) -> Float {
let lhsDuration = Double(lhs.durationSeconds)
let rhsDuration = Double(rhs.durationSeconds)
let totalDuration = lhsDuration + rhsDuration
guard totalDuration > 0 else {
return min(max((lhs.qualityScore + rhs.qualityScore) / 2, 0), 1)
}
let weighted =
Double(lhs.qualityScore) * lhsDuration
+ Double(rhs.qualityScore) * rhsDuration
return Float(min(max(weighted / totalDuration, 0), 1))
}
private func sanitize(segments: [TimedSpeakerSegment]) -> [TimedSpeakerSegment] {
var ordered = segments.sorted { $0.startTimeSeconds < $1.startTimeSeconds }
let minimumDuration = max(
Float(config.minSegmentDuration),
Float(config.segmentationMinDurationOn)
)
ordered = ordered.filter {
($0.endTimeSeconds - $0.startTimeSeconds) >= minimumDuration
}
if config.embeddingExcludeOverlap {
ordered = excludeOverlaps(in: ordered)
}
return ordered
}
private func chunkStartTime(
for chunkIndex: Int,
segmentation: SegmentationOutput
) -> Double {
if segmentation.chunkOffsets.indices.contains(chunkIndex) {
return segmentation.chunkOffsets[chunkIndex]
} else {
return Double(chunkIndex) * config.windowDuration
}
}
}
@@ -0,0 +1,580 @@
import Accelerate
import CoreML
import Foundation
import OSLog
import os.signpost
struct OfflineSegmentationProcessor {
private let logger = AppLogger(category: "OfflineSegmentation")
private let signposter = OSSignposter(
subsystem: "com.fluidaudio.diarization",
category: .pointsOfInterest
)
private let memoryOptimizer = ANEMemoryOptimizer()
private let powerset: [[Int]] = [
[],
[0],
[1],
[2],
[0, 1],
[0, 2],
[1, 2],
[0, 1, 2],
]
func process(
audioSamples: [Float],
segmentationModel: MLModel,
config: OfflineDiarizerConfig,
chunkHandler: SegmentationChunkHandler? = nil
) async throws -> SegmentationOutput {
guard !audioSamples.isEmpty else {
throw OfflineDiarizationError.noSpeechDetected
}
return try await process(
audioSource: ArrayAudioSampleSource(samples: audioSamples),
segmentationModel: segmentationModel,
config: config,
chunkHandler: chunkHandler
)
}
func process(
audioSource: StreamingAudioSampleSource,
segmentationModel: MLModel,
config: OfflineDiarizerConfig,
chunkHandler: SegmentationChunkHandler? = nil
) async throws -> SegmentationOutput {
let totalSamples = audioSource.sampleCount
guard totalSamples > 0 else {
throw OfflineDiarizationError.noSpeechDetected
}
let chunkSize = config.samplesPerWindow
let stepSize = config.samplesPerStep
var logProbChunks: [[[Float]]] = []
var weightChunks: [[[Float]]] = []
var chunkOffsets: [Double] = []
var frameDuration: Double = 0
var numFrames = 0
let speakerCount = 3
let speakerClassIndices: [[Int]] = (0..<speakerCount).map { speaker in
powerset.enumerated().compactMap { index, combination in
combination.contains(speaker) ? index : nil
}
}
// Pre-compute flat mapping matrix for vectorized speaker activation
// Matrix[speaker][class] = 1.0 if speaker in powerset[class], else 0.0
let speakerToClassMapping: [[Float]] = (0..<speakerCount).map { speaker in
powerset.map { combination in
combination.contains(speaker) ? Float(1.0) : Float(0.0)
}
}
var classHistogram = Array(repeating: 0, count: powerset.count)
var classProbabilitySums = Array(repeating: Float.zero, count: powerset.count)
let chunkCallback = chunkHandler
var chunkEmissionEnabled = chunkCallback != nil
logger.debug(
"Offline segmentation: chunkSize=\(chunkSize), stepSize=\(stepSize), totalSamples=\(totalSamples)"
)
var speechFrameCount = 0
var winningProbabilitySum: Double = 0
var winningProbabilityCount = 0
var winningProbabilityMin: Float = 1
var winningProbabilityMax: Float = 0
var emptyClassProbabilitySum: Double = 0
var emptyClassProbabilityCount = 0
let probabilityThresholds: [Float] = [0.50, 0.70, 0.80, 0.90, 0.95, 0.98, 0.99, 0.995, 0.999]
var probabilityThresholdCounts = Array(repeating: 0, count: probabilityThresholds.count)
let emptyClassIndex = 0
let onsetThreshold = config.speechOnsetThreshold
let batchCapacity = 32
var globalChunkIndex = 0
let clock = ContinuousClock()
var prepareDuration: Duration = .zero
var predictionDuration: Duration = .zero
var preparedWindowCount = 0
var slidingWindow = [Float](repeating: 0, count: chunkSize)
var previousOffset: Int?
let reuseEnabled = stepSize < chunkSize
@Sendable
func performWarmup() async throws {
let warmupShape: [NSNumber] = [1, 1, NSNumber(value: chunkSize)]
let warmupKey = "offline_segmentation_warmup_\(chunkSize)"
let warmupArray = try memoryOptimizer.getPooledBuffer(
key: warmupKey,
shape: warmupShape,
dataType: .float32
)
let warmupPointer = warmupArray.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(warmupPointer, 1, vDSP_Length(chunkSize))
let warmupProvider = ZeroCopyDiarizerFeatureProvider(
features: ["audio": MLFeatureValue(multiArray: warmupArray)]
)
let warmupOptions = MLPredictionOptions()
if #available(macOS 14.0, iOS 17.0, *) {
warmupArray.prefetchToNeuralEngine()
}
_ = try await segmentationModel.prediction(from: warmupProvider, options: warmupOptions)
}
func populateWindow(
destination: UnsafeMutablePointer<Float>,
offset: Int
) throws {
let availableForWindow = max(0, min(chunkSize, totalSamples - offset))
if reuseEnabled,
let lastOffset = previousOffset,
offset == lastOffset + stepSize
{
try slidingWindow.withUnsafeMutableBufferPointer { pointer in
guard let base = pointer.baseAddress else { return }
let reuseCount = max(0, chunkSize - stepSize)
if reuseCount > 0 {
memmove(
base,
base.advanced(by: stepSize),
reuseCount * MemoryLayout<Float>.stride
)
}
let samplesNeeded = chunkSize - reuseCount
if samplesNeeded > 0 {
let tailOffset = offset + reuseCount
let available = max(
0,
min(samplesNeeded, totalSamples - tailOffset)
)
if available > 0 {
try audioSource.copySamples(
into: base.advanced(by: reuseCount),
offset: tailOffset,
count: available
)
}
if available < samplesNeeded {
vDSP_vclr(
base.advanced(by: reuseCount + available),
1,
vDSP_Length(samplesNeeded - available)
)
}
}
}
} else {
try slidingWindow.withUnsafeMutableBufferPointer { pointer in
guard let base = pointer.baseAddress else { return }
if availableForWindow > 0 {
try audioSource.copySamples(
into: base,
offset: offset,
count: availableForWindow
)
}
if availableForWindow < chunkSize {
vDSP_vclr(
base.advanced(by: availableForWindow),
1,
vDSP_Length(chunkSize - availableForWindow)
)
}
}
}
slidingWindow.withUnsafeBufferPointer { pointer in
guard let base = pointer.baseAddress else { return }
destination.update(from: base, count: chunkSize)
}
previousOffset = offset
}
var processedAnyBatch = false
var offsetIterator = stride(from: 0, to: totalSamples, by: stepSize).makeIterator()
var batchOffsets: [Int] = []
batchOffsets.reserveCapacity(batchCapacity)
do {
try await performWarmup()
} catch {
logger.debug("Segmentation warmup skipped due to error: \(error.localizedDescription)")
}
while true {
try Task.checkCancellation()
batchOffsets.removeAll(keepingCapacity: true)
for _ in 0..<batchCapacity {
guard let offset = offsetIterator.next() else { break }
batchOffsets.append(offset)
}
if batchOffsets.isEmpty {
break
}
processedAnyBatch = true
let batchCount = batchOffsets.count
let shape: [NSNumber] = [
NSNumber(value: batchCount),
1,
NSNumber(value: chunkSize),
]
let bufferKey = "offline_segmentation_audio_\(batchCount)_\(chunkSize)"
let audioArray = try memoryOptimizer.getPooledBuffer(
key: bufferKey,
shape: shape,
dataType: .float32
)
let ptr = audioArray.dataPointer.assumingMemoryBound(to: Float.self)
let prepareStart = clock.now
for (localIndex, offset) in batchOffsets.enumerated() {
let destination = ptr.advanced(by: localIndex * chunkSize)
try populateWindow(destination: destination, offset: offset)
}
prepareDuration += prepareStart.duration(to: clock.now)
preparedWindowCount += batchOffsets.count
let provider = ZeroCopyDiarizerFeatureProvider(
features: ["audio": MLFeatureValue(multiArray: audioArray)]
)
let options = MLPredictionOptions()
if #available(macOS 14.0, iOS 17.0, *) {
audioArray.prefetchToNeuralEngine()
}
let predictionState = signposter.beginInterval("Segmentation Model Prediction")
try Task.checkCancellation()
let predictionStart = clock.now
let output = try await segmentationModel.prediction(from: provider, options: options)
predictionDuration += predictionStart.duration(to: clock.now)
signposter.endInterval("Segmentation Model Prediction", predictionState)
let logitsArray: MLMultiArray
if let segments = output.featureValue(for: "segments")?.multiArrayValue {
logitsArray = segments
} else if let logProbs = output.featureValue(for: "log_probs")?.multiArrayValue {
logitsArray = logProbs
} else if let fallback = output.featureNames.compactMap({ name -> MLMultiArray? in
output.featureValue(for: name)?.multiArrayValue
}).first {
logitsArray = fallback
} else {
let available = Array(output.featureNames)
throw OfflineDiarizationError.processingFailed(
"Segmentation model missing expected multiarray output. Available: \(available)"
)
}
let logitsShape = logitsArray.shape.map { $0.intValue }
let (batchSize, frames, classes): (Int, Int, Int)
switch logitsShape.count {
case 3:
batchSize = logitsShape[0]
frames = logitsShape[1]
classes = logitsShape[2]
case 2:
batchSize = 1
frames = logitsShape[0]
classes = logitsShape[1]
default:
throw OfflineDiarizationError.processingFailed(
"Unexpected segmentation output shape \(logitsShape)"
)
}
frameDuration = config.windowDuration / Double(frames)
numFrames = frames
if classes > powerset.count {
logger.error(
"Segmentation model returned \(classes) classes but only \(powerset.count) powerset entries available"
)
}
let logitsPointer = logitsArray.dataPointer.assumingMemoryBound(to: Float.self)
for localIndex in 0..<batchCount {
if localIndex >= batchSize {
break
}
let offset = batchOffsets[localIndex]
chunkOffsets.append(Double(offset) / Double(config.sampleRate))
var chunkLogProbs = Array(
repeating: Array(repeating: Float.zero, count: classes),
count: frames
)
var chunkSpeakerProbs = Array(
repeating: Array(repeating: Float.zero, count: speakerCount),
count: frames
)
let baseIndex = localIndex * frames * classes
var frameLogits = [Float](repeating: 0, count: classes)
var logProbabilityBuffer = [Float](repeating: 0, count: classes)
var probabilityBuffer = [Float](repeating: 0, count: classes)
for frameIndex in 0..<frames {
let start = baseIndex + frameIndex * classes
frameLogits.withUnsafeMutableBufferPointer { destination in
destination.baseAddress!.update(from: logitsPointer.advanced(by: start), count: classes)
}
var bestIndex = 0
var bestValue = -Float.greatestFiniteMagnitude
for cls in 0..<classes {
let value = frameLogits[cls]
if value > bestValue {
bestValue = value
bestIndex = cls
}
}
let logSumExp = VDSPOperations.logSumExp(frameLogits)
var shift = -logSumExp
vDSP_vsadd(
frameLogits,
1,
&shift,
&logProbabilityBuffer,
1,
vDSP_Length(classes)
)
probabilityBuffer = logProbabilityBuffer
probabilityBuffer.withUnsafeMutableBufferPointer { pointer in
var count = Int32(classes)
vvexpf(pointer.baseAddress!, pointer.baseAddress!, &count)
}
chunkLogProbs[frameIndex].withUnsafeMutableBufferPointer { destination in
logProbabilityBuffer.withUnsafeBufferPointer { source in
destination.baseAddress!.update(from: source.baseAddress!, count: classes)
}
}
for cls in 0..<min(classes, classProbabilitySums.count) {
classProbabilitySums[cls] += probabilityBuffer[cls]
}
if bestIndex < classHistogram.count {
classHistogram[bestIndex] += 1
}
let winningClass = min(bestIndex, powerset.count - 1)
let winningSpeakers = powerset[winningClass].filter { $0 < speakerCount }
let winningProbability = probabilityBuffer[winningClass]
let emptyProbability =
emptyClassIndex < probabilityBuffer.count ? probabilityBuffer[emptyClassIndex] : 0
if !winningSpeakers.isEmpty {
winningProbabilitySum += Double(winningProbability)
winningProbabilityCount += 1
if winningProbability < winningProbabilityMin {
winningProbabilityMin = winningProbability
}
if winningProbability > winningProbabilityMax {
winningProbabilityMax = winningProbability
}
emptyClassProbabilitySum += Double(emptyProbability)
emptyClassProbabilityCount += 1
for (index, threshold) in probabilityThresholds.enumerated() {
if winningProbability >= threshold {
probabilityThresholdCounts[index] += 1
}
}
}
// Vectorized speaker activation using matrix-vector multiply
// speakerActivations[speaker] = sum of probabilityBuffer[class] where speaker in powerset[class]
// Handle case where model outputs fewer classes than powerset entries (e.g., 7 vs 8)
let paddedProbabilityBuffer: [Float]
if probabilityBuffer.count < powerset.count {
paddedProbabilityBuffer =
probabilityBuffer
+ [Float](
repeating: 0,
count: powerset.count - probabilityBuffer.count
)
} else {
paddedProbabilityBuffer = Array(probabilityBuffer.prefix(powerset.count))
}
let speakerActivations = VDSPOperations.matrixVectorMultiply(
matrix: speakerToClassMapping,
vector: paddedProbabilityBuffer
).map { min(max($0, 0), 1) }
chunkSpeakerProbs[frameIndex] = speakerActivations
let speechProbability = max(0, min(1, 1 - emptyProbability))
if speechProbability >= onsetThreshold {
speechFrameCount += 1
}
}
var chunkWeights = Array(
repeating: Array(repeating: Float.zero, count: speakerCount),
count: frames
)
// Pyannote community-1 powerset models provide powerset probabilities that we marginalize
// into per-speaker activity weights for each frame (0...1).
for frameIndex in 0..<frames {
chunkWeights[frameIndex] = chunkSpeakerProbs[frameIndex]
}
logProbChunks.append(chunkLogProbs)
weightChunks.append(chunkWeights)
if chunkEmissionEnabled, let chunkCallback {
let chunkOffsetSeconds = chunkOffsets.last ?? Double(offset) / Double(config.sampleRate)
let chunk = SegmentationChunk(
chunkIndex: globalChunkIndex,
chunkOffsetSeconds: chunkOffsetSeconds,
frameDuration: frameDuration,
logProbs: chunkLogProbs,
speakerWeights: chunkWeights
)
if chunkCallback(chunk) == .stop {
chunkEmissionEnabled = false
}
}
if globalChunkIndex == 0 {
let speakerCoverage = chunkSpeakerProbs.reduce(into: Array(repeating: 0, count: speakerCount)) {
counts, frame in
for (index, probability) in frame.enumerated() where probability >= onsetThreshold {
counts[index] += 1
}
}
logger.debug("Chunk 0 speaker frame counts: \(speakerCoverage)")
}
globalChunkIndex += 1
}
}
guard processedAnyBatch else {
throw OfflineDiarizationError.processingFailed("Segmentation produced no analysis windows")
}
let totalFrames = classHistogram.reduce(0, +)
if totalFrames > 0 {
let speechFrames = totalFrames - classHistogram[0]
let speechRatio = Double(speechFrames) / Double(totalFrames)
let nonSpeechProb =
classProbabilitySums[0] / Float(totalFrames == 0 ? 1 : totalFrames)
logger.debug(
"""
Segmentation histogram: speechFrames=\(speechFrames) totalFrames=\(totalFrames) \
speechRatio=\(String(format: "%.3f", speechRatio)) avgNonSpeechProb=\(String(format: "%.3f", nonSpeechProb))
"""
)
}
let totalFramesWithSpeech = speechFrameCount
let totalFramesOverall = numFrames * logProbChunks.count
if totalFramesOverall > 0 {
let ratio = Double(totalFramesWithSpeech) / Double(totalFramesOverall)
let ratioString = String(format: "%.3f", ratio)
let predictedDuration = Double(totalFramesWithSpeech) * frameDuration
let durationString = String(format: "%.1f", predictedDuration)
logger.debug(
"Segmentation mask speech frames = \(totalFramesWithSpeech) / \(totalFramesOverall) (ratio=\(ratioString), speechSeconds≈\(durationString)s)"
)
}
if winningProbabilityCount > 0 {
let averageWinning = winningProbabilitySum / Double(winningProbabilityCount)
logger.debug(
"""
Winning speaker probability stats: count=\(winningProbabilityCount), \
avg=\(String(format: "%.3f", averageWinning)), \
min=\(String(format: "%.3f", winningProbabilityMin)), \
max=\(String(format: "%.3f", winningProbabilityMax))
"""
)
var distribution: [String] = []
for (index, threshold) in probabilityThresholds.enumerated() {
let count = probabilityThresholdCounts[index]
let thresholdString = String(format: "%.3f", threshold)
distribution.append("\(thresholdString):\(count)")
}
let distributionString = distribution.joined(separator: ", ")
logger.debug("Winning probability distribution \(distributionString)")
}
if emptyClassProbabilityCount > 0 {
let averageEmpty = emptyClassProbabilitySum / Double(emptyClassProbabilityCount)
let averageEmptyString = String(format: "%.3f", averageEmpty)
logger.debug(
"Empty-class probability on speech frames: avg=\(averageEmptyString)"
)
}
if preparedWindowCount > 0 {
let prepareMs = Self.milliseconds(from: prepareDuration)
let predictionMs = Self.milliseconds(from: predictionDuration)
let preparePerWindow = prepareMs / Double(preparedWindowCount)
let predictionPerWindow = predictionMs / Double(preparedWindowCount)
let prepareTotalString = String(format: "%.2f", prepareMs)
let prepareWindowString = String(format: "%.4f", preparePerWindow)
let predictionTotalString = String(format: "%.2f", predictionMs)
let predictionWindowString = String(format: "%.4f", predictionPerWindow)
let message =
"""
Segmentation timings: windows=\(preparedWindowCount) \
prepareTotal=\(prepareTotalString)ms (perWindow=\(prepareWindowString)ms) \
predictionTotal=\(predictionTotalString)ms (perWindow=\(predictionWindowString)ms)
"""
logger.debug(message)
Self.emitProfileLog(message)
}
return SegmentationOutput(
logProbs: logProbChunks,
speakerWeights: weightChunks,
numChunks: logProbChunks.count,
numFrames: numFrames,
numSpeakers: speakerCount,
chunkOffsets: chunkOffsets,
frameDuration: frameDuration
)
}
}
extension OfflineSegmentationProcessor {
fileprivate static func milliseconds(from duration: Duration) -> Double {
let components = duration.components
let secondsMs = Double(components.seconds) * 1_000
let attosecondsMs = Double(components.attoseconds) / 1_000_000_000_000_000.0
return secondsMs + attosecondsMs
}
fileprivate static func emitProfileLog(_ message: String) {
let line = "[Profiling] \(message)\n"
if let data = line.data(using: .utf8) {
FileHandle.standardError.write(data)
}
}
}
@@ -0,0 +1,198 @@
import Accelerate
import CoreML
import Foundation
import OSLog
@available(macOS 14.0, iOS 17.0, *)
public struct PLDATransform {
private let pldaRhoModel: MLModel
private let psi: [Double]
private let memoryOptimizer = ANEMemoryOptimizer()
private let logger = AppLogger(category: "OfflinePLDA")
private let embeddingDimension = 256
private let rhoDimension = 128
private let maxBatchSize = 32
public init(pldaRhoModel: MLModel, psi: [Double]) {
self.pldaRhoModel = pldaRhoModel
self.psi = psi
}
public var phiParameters: [Double] { psi }
/// Transform a sequence of 256-dimensional embeddings into 128-dimensional
/// PLDA-space rho features using the Core ML model exported from Pyannote.
public func transform(_ embeddings: [[Float]]) async throws -> [[Double]] {
guard !embeddings.isEmpty else { return [] }
for embedding in embeddings {
guard embedding.count == embeddingDimension else {
throw OfflineDiarizationError.invalidConfiguration(
"Expected \(embeddingDimension)-dim embeddings, got \(embedding.count)"
)
}
}
var results: [[Double]] = []
results.reserveCapacity(embeddings.count)
do {
try await performWarmup()
} catch {
logger.debug("PLDA warmup skipped due to error: \(error.localizedDescription)")
}
var startIndex = 0
while startIndex < embeddings.count {
try Task.checkCancellation()
let endIndex = min(startIndex + maxBatchSize, embeddings.count)
let batch = embeddings[startIndex..<endIndex]
let batchResults = try await transformBatch(batch)
results.append(contentsOf: batchResults)
startIndex = endIndex
}
return results
}
/// Convenience wrapper for single-embedding transform.
public func transform(_ embedding: [Float]) async throws -> [Double] {
guard embedding.count == embeddingDimension else {
throw OfflineDiarizationError.invalidConfiguration(
"Expected \(embeddingDimension)-dim embedding, got \(embedding.count)"
)
}
let transformed = try await transform([embedding])
return transformed.first ?? []
}
/// Cosine similarity score between two rho vectors.
public func score(_ lhs: [Double], _ rhs: [Double]) -> Double {
guard lhs.count == rhoDimension, rhs.count == rhoDimension else {
return 0
}
var dot: Double = 0
var normLhs: Double = 0
var normRhs: Double = 0
lhs.withUnsafeBufferPointer { lhsPointer in
rhs.withUnsafeBufferPointer { rhsPointer in
vDSP_dotprD(
lhsPointer.baseAddress!,
1,
rhsPointer.baseAddress!,
1,
&dot,
vDSP_Length(rhoDimension)
)
vDSP_dotprD(
lhsPointer.baseAddress!,
1,
lhsPointer.baseAddress!,
1,
&normLhs,
vDSP_Length(rhoDimension)
)
vDSP_dotprD(
rhsPointer.baseAddress!,
1,
rhsPointer.baseAddress!,
1,
&normRhs,
vDSP_Length(rhoDimension)
)
}
}
let magnitude = sqrt(normLhs) * sqrt(normRhs)
if magnitude <= 0 {
return 0
}
return dot / magnitude
}
private func transformBatch(_ embeddings: ArraySlice<[Float]>) async throws -> [[Double]] {
guard !embeddings.isEmpty else { return [] }
guard embeddings.count <= maxBatchSize else {
throw OfflineDiarizationError.invalidBatchSize(
"PldaRho batch size must be <= \(maxBatchSize), got \(embeddings.count)"
)
}
let shape: [NSNumber] = [NSNumber(value: embeddings.count), NSNumber(value: embeddingDimension)]
let inputArray = try memoryOptimizer.createAlignedArray(shape: shape, dataType: .float32)
let pointer = inputArray.dataPointer.assumingMemoryBound(to: Float.self)
for (batchIndex, embedding) in embeddings.enumerated() {
let base = batchIndex * embeddingDimension
embedding.withUnsafeBufferPointer { buffer in
vDSP_mmov(
buffer.baseAddress!,
pointer.advanced(by: base),
vDSP_Length(embeddingDimension),
1,
vDSP_Length(embeddingDimension),
1
)
}
}
let provider = ZeroCopyDiarizerFeatureProvider(
features: ["embeddings": MLFeatureValue(multiArray: inputArray)]
)
let options = MLPredictionOptions()
inputArray.prefetchToNeuralEngine()
let output = try await pldaRhoModel.prediction(from: provider, options: options)
guard let rhoArray = output.featureValue(for: "rho")?.multiArrayValue else {
throw OfflineDiarizationError.processingFailed("PldaRho model did not produce rho output")
}
let rhoPointer = rhoArray.dataPointer.assumingMemoryBound(to: Float.self)
var results: [[Double]] = []
results.reserveCapacity(embeddings.count)
let totalRhoCount = embeddings.count * rhoDimension
var rhoScratch = [Double](repeating: 0, count: totalRhoCount)
let floatPointer = UnsafePointer<Float>(rhoPointer)
let sourceBuffer = UnsafeBufferPointer(start: floatPointer, count: totalRhoCount)
rhoScratch.withUnsafeMutableBufferPointer { dest in
guard let destBase = dest.baseAddress else { return }
var destinationBuffer = UnsafeMutableBufferPointer(start: destBase, count: totalRhoCount)
vDSP.convertElements(of: sourceBuffer, to: &destinationBuffer)
}
for batchIndex in 0..<embeddings.count {
let start = batchIndex * rhoDimension
let end = start + rhoDimension
let rhoSlice = Array(rhoScratch[start..<end])
results.append(rhoSlice)
}
return results
}
private func performWarmup() async throws {
let warmupShape: [NSNumber] = [1, NSNumber(value: embeddingDimension)]
let warmupKey = "offline_plda_warmup_embedding_\(embeddingDimension)"
let warmupArray = try memoryOptimizer.getPooledBuffer(
key: warmupKey,
shape: warmupShape,
dataType: .float32
)
let pointer = warmupArray.dataPointer.assumingMemoryBound(to: Float.self)
vDSP_vclr(pointer, 1, vDSP_Length(warmupArray.count))
let provider = ZeroCopyDiarizerFeatureProvider(
features: ["embeddings": MLFeatureValue(multiArray: warmupArray)]
)
let options = MLPredictionOptions()
warmupArray.prefetchToNeuralEngine()
_ = try await pldaRhoModel.prediction(from: provider, options: options)
}
}
@@ -0,0 +1,675 @@
import Accelerate
import Foundation
import OSLog
import os.signpost
/// Variational Bayes clustering (VBx) for speaker diarization.
///
/// This implementation is based on the VBx algorithm from BUT Speech@FIT
/// (Brno University of Technology Speech@FIT group).
///
/// Reference:
/// - Original paper: "Improved Speaker Diarization Using a Deep Learning-based Approach"
/// - GitHub repository: https://github.com/BUTSpeechFIT/VBx
/// - License: Apache License 2.0
/// - Copyright 2021-2024 BUT Speech@FIT
///
/// The algorithm uses variational inference to cluster speaker embeddings with:
/// - Probabilistic Linear Discriminant Analysis (PLDA) transformation
/// - Expectation-Maximization (EM) iterations with convergence monitoring
/// - Evidence Lower Bound (ELBO) tracking for convergence
/// - Speaker mixture weight estimation (pi parameters)
///
/// The implementation includes warm-start initialization from initial hard cluster
/// assignments and supports PLDA whitening transformation of input features.
struct VBxClustering {
private let config: OfflineDiarizerConfig
private let pldaTransform: PLDATransform
private let logger = AppLogger(category: "OfflineVBx")
private let signposter = OSSignposter(
subsystem: "com.fluidaudio.diarization",
category: .pointsOfInterest
)
init(config: OfflineDiarizerConfig, pldaTransform: PLDATransform) {
self.config = config
self.pldaTransform = pldaTransform
}
// MARK: - VBx Clustering Algorithm
func refine(
rhoFeatures: [[Double]],
initialClusters: [Int]
) -> VBxOutput {
guard !rhoFeatures.isEmpty else {
return VBxOutput(
gamma: [],
pi: [],
hardClusters: [],
centroids: [],
numClusters: 0,
elbos: []
)
}
let frameCount = rhoFeatures.count
guard let dimension = rhoFeatures.first?.count, dimension > 0 else {
logger.error("VBx received empty feature vectors")
return VBxOutput(
gamma: [],
pi: [],
hardClusters: [],
centroids: [],
numClusters: 0,
elbos: []
)
}
let vbxState = signposter.beginInterval("VBx Clustering Algorithm")
var phi = pldaTransform.phiParameters
if phi.count != dimension {
logger.warning(
"PLDA psi dimension (\(phi.count)) mismatches rho dimension (\(dimension)); falling back to identity")
phi = Array(repeating: 1.0, count: dimension)
}
let speakerCount = max(1, Set(initialClusters).count)
let histogram = initialClusters.reduce(into: [:]) { partialResult, value in
partialResult[value, default: 0] += 1
}
logger.debug("VBx warm start clusters: \(speakerCount) histogram: \(histogram)")
var featureBuffer = [Double](repeating: 0, count: frameCount * dimension)
featureBuffer.withUnsafeMutableBufferPointer { bufferPtr in
guard let baseAddress = bufferPtr.baseAddress else { return }
for (index, frame) in rhoFeatures.enumerated() {
let destination = baseAddress.advanced(by: index * dimension)
frame.withUnsafeBufferPointer { source in
guard let sourceBase = source.baseAddress else { return }
memcpy(
destination,
sourceBase,
dimension * MemoryLayout<Double>.size
)
}
}
}
var initialGamma = [Double](repeating: 0, count: frameCount * speakerCount)
if !initialClusters.isEmpty {
for (index, cluster) in initialClusters.enumerated() {
let speaker = max(0, min(cluster, speakerCount - 1))
initialGamma[index * speakerCount + speaker] = 1.0
}
} else {
let uniform = 1.0 / Double(speakerCount)
for index in 0..<frameCount {
for speaker in 0..<speakerCount {
initialGamma[index * speakerCount + speaker] = uniform
}
}
}
let gammaSource: [Double]
let piSource: [Double]
let elboHistory: [Double]
do {
let result = try runVBx(
features: featureBuffer,
frameCount: frameCount,
dimension: dimension,
phi: phi,
initialGamma: initialGamma,
speakerCount: speakerCount,
maxIterations: config.vbx.maxIterations,
epsilon: config.vbx.convergenceTolerance,
Fa: config.clustering.warmStartFa,
Fb: config.clustering.warmStartFb,
initSmoothing: 7.0
)
gammaSource = result.gamma
piSource = result.pi
elboHistory = result.elbos
} catch {
logger.error("VBx failed to prepare BLAS arguments: \(error.localizedDescription)")
gammaSource = initialGamma
piSource = Array(repeating: 1.0 / Double(speakerCount), count: speakerCount)
elboHistory = []
}
let gammaMatrix = reshapeGamma(gammaSource, frameCount: frameCount, speakerCount: speakerCount)
let hardAssignments = gammaMatrix.map { row -> Int in
row.enumerated().max(by: { $0.element < $1.element })?.offset ?? 0
}
if let maxPi = piSource.max(), let minPi = piSource.min() {
logger.debug("VBx mixture weights min: \(minPi), max: \(maxPi), count: \(piSource.count)")
} else {
logger.debug("VBx mixture weights unavailable")
}
let output = VBxOutput(
gamma: gammaMatrix,
pi: piSource,
hardClusters: [hardAssignments],
centroids: [],
numClusters: speakerCount,
elbos: elboHistory
)
signposter.endInterval("VBx Clustering Algorithm", vbxState)
return output
}
private func runVBx(
features: [Double],
frameCount: Int,
dimension: Int,
phi: [Double],
initialGamma: [Double],
speakerCount: Int,
maxIterations: Int,
epsilon: Double,
Fa: Double,
Fb: Double,
initSmoothing: Double
) throws -> (gamma: [Double], pi: [Double], elbos: [Double]) {
var gamma = initialGamma
let frameCountBlas = try makeBlasIndex(frameCount, label: "VBx frame count")
let speakerCountBlas = try makeBlasIndex(speakerCount, label: "VBx speaker count")
let dimensionBlas = try makeBlasIndex(dimension, label: "VBx feature dimension")
let speakerLength = vDSP_Length(speakerCount)
let dimensionLength = vDSP_Length(dimension)
let onesFrame = [Double](repeating: 1.0, count: frameCount)
var rowBuffer = [Double](repeating: 0, count: speakerCount)
if initSmoothing >= 0.0 {
gamma.withUnsafeMutableBufferPointer { gammaPtr in
rowBuffer.withUnsafeMutableBufferPointer { bufferPtr in
guard
let gammaBase = gammaPtr.baseAddress,
let scratch = bufferPtr.baseAddress
else { return }
for t in 0..<frameCount {
let row = gammaBase.advanced(by: t * speakerCount)
var multiplier = initSmoothing
vDSP_vsmulD(row, 1, &multiplier, scratch, 1, speakerLength)
var maxValue = -Double.greatestFiniteMagnitude
vDSP_maxvD(scratch, 1, &maxValue, speakerLength)
var shift = -maxValue
vDSP_vsaddD(scratch, 1, &shift, scratch, 1, speakerLength)
var count = Int32(speakerCount)
vvexp(scratch, scratch, &count)
var sumExp = 0.0
vDSP_sveD(scratch, 1, &sumExp, speakerLength)
if sumExp <= 0.0 || !sumExp.isFinite {
var uniform = 1.0 / Double(speakerCount)
vDSP_vfillD(&uniform, row, 1, speakerLength)
} else {
var invSum = 1.0 / sumExp
vDSP_vsmulD(scratch, 1, &invSum, row, 1, speakerLength)
}
}
}
}
}
gamma.withUnsafeMutableBufferPointer { gammaPtr in
guard let gammaBase = gammaPtr.baseAddress else { return }
for t in 0..<frameCount {
let row = gammaBase.advanced(by: t * speakerCount)
var sum = 0.0
vDSP_sveD(row, 1, &sum, speakerLength)
if sum <= 0.0 || !sum.isFinite {
var uniform = 1.0 / Double(speakerCount)
vDSP_vfillD(&uniform, row, 1, speakerLength)
} else {
var inv = 1.0 / sum
vDSP_vsmulD(row, 1, &inv, row, 1, speakerLength)
}
}
}
var pi = [Double](repeating: 1.0 / Double(speakerCount), count: speakerCount)
let phiClamped = phi.map { max($0, 1e-12) }
let sqrtPhi = phiClamped.map { sqrt($0) }
var rho = [Double](repeating: 0, count: features.count)
features.withUnsafeBufferPointer { featurePtr in
rho.withUnsafeMutableBufferPointer { rhoPtr in
sqrtPhi.withUnsafeBufferPointer { sqrtPtr in
guard
let featureBase = featurePtr.baseAddress,
let rhoBase = rhoPtr.baseAddress,
let sqrtBase = sqrtPtr.baseAddress
else { return }
for t in 0..<frameCount {
let offset = t * dimension
vDSP_vmulD(
featureBase.advanced(by: offset),
1,
sqrtBase,
1,
rhoBase.advanced(by: offset),
1,
dimensionLength
)
}
}
}
}
var G = [Double](repeating: 0, count: frameCount)
let logConstant = Double(dimension) * log(2.0 * Double.pi)
features.withUnsafeBufferPointer { featurePtr in
guard let featureBase = featurePtr.baseAddress else { return }
for t in 0..<frameCount {
let offset = t * dimension
var sumSq: Double = 0
vDSP_svesqD(
featureBase.advanced(by: offset),
1,
&sumSq,
dimensionLength
)
G[t] = -0.5 * (sumSq + logConstant)
}
}
let ratio = Fa / Fb
var invL = [Double](repeating: 0, count: speakerCount * dimension)
var alpha = [Double](repeating: 0, count: speakerCount * dimension)
var temp = [Double](repeating: 0, count: speakerCount * dimension)
var phiTerms = [Double](repeating: 0, count: speakerCount)
var gammaSum = [Double](repeating: 0, count: speakerCount)
var logP = [Double](repeating: 0, count: frameCount * speakerCount)
var phiScratch = [Double](repeating: 0, count: dimension)
var phiOffset = [Double](repeating: 0, count: speakerCount)
var logInv = [Double](repeating: 0, count: speakerCount * dimension)
var elbos = [Double](repeating: 0, count: max(maxIterations, 1))
let alphaCount = alpha.count
let invLCount = invL.count
var previousElbo = -Double.greatestFiniteMagnitude
var iterations = 0
for iteration in 0..<maxIterations {
iterations = iteration + 1
gamma.withUnsafeBufferPointer { gammaPtr in
onesFrame.withUnsafeBufferPointer { onesPtr in
gammaSum.withUnsafeMutableBufferPointer { sumPtr in
guard
let gammaBase = gammaPtr.baseAddress,
let onesBase = onesPtr.baseAddress,
let sumBase = sumPtr.baseAddress
else { return }
cblas_dgemv(
CblasRowMajor,
CblasTrans,
frameCountBlas,
speakerCountBlas,
1.0,
gammaBase,
speakerCountBlas,
onesBase,
1,
0.0,
sumBase,
1
)
}
}
}
for s in 0..<speakerCount {
let weight = ratio * gammaSum[s]
for d in 0..<dimension {
let idx = s * dimension + d
let denom = 1.0 + weight * phiClamped[d]
invL[idx] = 1.0 / max(denom, 1e-12)
}
}
gamma.withUnsafeBufferPointer { gammaPtr in
rho.withUnsafeBufferPointer { rhoPtr in
temp.withUnsafeMutableBufferPointer { tempPtr in
cblas_dgemm(
CblasRowMajor,
CblasTrans,
CblasNoTrans,
speakerCountBlas,
dimensionBlas,
frameCountBlas,
1.0,
gammaPtr.baseAddress!,
speakerCountBlas,
rhoPtr.baseAddress!,
dimensionBlas,
0.0,
tempPtr.baseAddress!,
dimensionBlas
)
}
}
}
alpha.withUnsafeMutableBufferPointer { alphaPtr in
invL.withUnsafeBufferPointer { invPtr in
temp.withUnsafeBufferPointer { tempPtr in
guard
let alphaBase = alphaPtr.baseAddress,
let invBase = invPtr.baseAddress,
let tempBase = tempPtr.baseAddress
else { return }
vDSP_vmulD(
tempBase,
1,
invBase,
1,
alphaBase,
1,
vDSP_Length(alphaCount)
)
var ratioScalar = ratio
vDSP_vsmulD(
alphaBase,
1,
&ratioScalar,
alphaBase,
1,
vDSP_Length(alphaCount)
)
}
}
}
alpha.withUnsafeBufferPointer { alphaPtr in
invL.withUnsafeBufferPointer { invPtr in
phiClamped.withUnsafeBufferPointer { phiPtr in
phiScratch.withUnsafeMutableBufferPointer { scratchPtr in
guard
let alphaBase = alphaPtr.baseAddress,
let invBase = invPtr.baseAddress,
let phiBase = phiPtr.baseAddress,
let scratchBase = scratchPtr.baseAddress
else { return }
for s in 0..<speakerCount {
let offset = s * dimension
vDSP_vsqD(
alphaBase.advanced(by: offset),
1,
scratchBase,
1,
dimensionLength
)
vDSP_vaddD(
scratchBase,
1,
invBase.advanced(by: offset),
1,
scratchBase,
1,
dimensionLength
)
vDSP_vmulD(
scratchBase,
1,
phiBase,
1,
scratchBase,
1,
dimensionLength
)
var sum: Double = 0
vDSP_sveD(scratchBase, 1, &sum, dimensionLength)
phiTerms[s] = sum
}
}
}
}
}
rho.withUnsafeBufferPointer { rhoPtr in
alpha.withUnsafeBufferPointer { alphaPtr in
logP.withUnsafeMutableBufferPointer { logPtr in
cblas_dgemm(
CblasRowMajor,
CblasNoTrans,
CblasTrans,
frameCountBlas,
speakerCountBlas,
dimensionBlas,
1.0,
rhoPtr.baseAddress!,
dimensionBlas,
alphaPtr.baseAddress!,
dimensionBlas,
0.0,
logPtr.baseAddress!,
speakerCountBlas
)
}
}
}
phiTerms.withUnsafeBufferPointer { phiPtr in
phiOffset.withUnsafeMutableBufferPointer { offsetPtr in
guard
let phiBase = phiPtr.baseAddress,
let offsetBase = offsetPtr.baseAddress
else { return }
var negativeHalf: Double = -0.5
vDSP_vsmulD(
phiBase,
1,
&negativeHalf,
offsetBase,
1,
speakerLength
)
}
}
phiOffset.withUnsafeBufferPointer { offsetPtr in
logP.withUnsafeMutableBufferPointer { logPtr in
guard
let offsetBase = offsetPtr.baseAddress,
let logBase = logPtr.baseAddress
else { return }
for t in 0..<frameCount {
let row = logBase.advanced(by: t * speakerCount)
vDSP_vaddD(row, 1, offsetBase, 1, row, 1, speakerLength)
var g = G[t]
vDSP_vsaddD(row, 1, &g, row, 1, speakerLength)
var faScale = Fa
vDSP_vsmulD(row, 1, &faScale, row, 1, speakerLength)
}
}
}
var logLikelihood = 0.0
var logPi = [Double](repeating: 0, count: speakerCount)
pi.withUnsafeBufferPointer { piPtr in
logPi.withUnsafeMutableBufferPointer { logPtr in
guard
let piBase = piPtr.baseAddress,
let logBase = logPtr.baseAddress
else { return }
var threshold = 1e-8
vDSP_vthrD(
piBase,
1,
&threshold,
logBase,
1,
speakerLength
)
var count = Int32(speakerCount)
vvlog(logBase, logBase, &count)
}
}
rowBuffer.withUnsafeMutableBufferPointer { bufferPtr in
gamma.withUnsafeMutableBufferPointer { gammaPtr in
logP.withUnsafeBufferPointer { logPtr in
logPi.withUnsafeBufferPointer { logPiPtr in
guard
let scratch = bufferPtr.baseAddress,
let gammaBase = gammaPtr.baseAddress,
let logBase = logPtr.baseAddress,
let logPiBase = logPiPtr.baseAddress
else { return }
for t in 0..<frameCount {
let rowOffset = t * speakerCount
let gammaRow = gammaBase.advanced(by: rowOffset)
vDSP_mmovD(
logBase.advanced(by: rowOffset),
scratch,
speakerLength,
1,
speakerLength,
1
)
vDSP_vaddD(
scratch,
1,
logPiBase,
1,
scratch,
1,
speakerLength
)
var rowMax = -Double.greatestFiniteMagnitude
vDSP_maxvD(scratch, 1, &rowMax, speakerLength)
var shift = -rowMax
vDSP_vsaddD(scratch, 1, &shift, scratch, 1, speakerLength)
var count = Int32(speakerCount)
vvexp(scratch, scratch, &count)
var sumExp = 0.0
vDSP_sveD(scratch, 1, &sumExp, speakerLength)
if sumExp <= 0.0 || !sumExp.isFinite {
var uniform = 1.0 / Double(speakerCount)
vDSP_vfillD(&uniform, gammaRow, 1, speakerLength)
logLikelihood += rowMax
} else {
var invSum = 1.0 / sumExp
vDSP_vsmulD(
scratch,
1,
&invSum,
gammaRow,
1,
speakerLength
)
logLikelihood += rowMax + log(sumExp)
}
}
}
}
}
}
gamma.withUnsafeBufferPointer { gammaPtr in
onesFrame.withUnsafeBufferPointer { onesPtr in
pi.withUnsafeMutableBufferPointer { piPtr in
guard
let gammaBase = gammaPtr.baseAddress,
let onesBase = onesPtr.baseAddress,
let piBase = piPtr.baseAddress
else { return }
cblas_dgemv(
CblasRowMajor,
CblasTrans,
frameCountBlas,
speakerCountBlas,
1.0,
gammaBase,
speakerCountBlas,
onesBase,
1,
0.0,
piBase,
1
)
}
}
}
var piSum = 0.0
pi.withUnsafeBufferPointer { piPtr in
guard let piBase = piPtr.baseAddress else { return }
vDSP_sveD(piBase, 1, &piSum, speakerLength)
}
if piSum > 0.0 && piSum.isFinite {
var inv = 1.0 / piSum
pi.withUnsafeMutableBufferPointer { piPtr in
guard let piBase = piPtr.baseAddress else { return }
vDSP_vsmulD(piBase, 1, &inv, piBase, 1, speakerLength)
}
} else {
var uniform = 1.0 / Double(speakerCount)
pi.withUnsafeMutableBufferPointer { piPtr in
guard let piBase = piPtr.baseAddress else { return }
vDSP_vfillD(&uniform, piBase, 1, speakerLength)
}
}
var sumLogInv = 0.0
var sumInv = 0.0
var sumAlphaSq = 0.0
invL.withUnsafeBufferPointer { invPtr in
logInv.withUnsafeMutableBufferPointer { logPtr in
guard
let invBase = invPtr.baseAddress,
let logBase = logPtr.baseAddress
else { return }
logBase.update(from: invBase, count: invLCount)
var count = Int32(invLCount)
vvlog(logBase, logBase, &count)
vDSP_sveD(logBase, 1, &sumLogInv, vDSP_Length(invLCount))
vDSP_sveD(invBase, 1, &sumInv, vDSP_Length(invLCount))
}
}
alpha.withUnsafeBufferPointer { alphaPtr in
guard let alphaBase = alphaPtr.baseAddress else { return }
vDSP_svesqD(alphaBase, 1, &sumAlphaSq, vDSP_Length(alphaCount))
}
var elbo = logLikelihood
elbo += Fb * 0.5 * (sumLogInv - sumInv - sumAlphaSq + Double(invLCount))
if iteration < elbos.count {
elbos[iteration] = elbo
}
if iteration > 0 {
let improvement = elbo - previousElbo
if abs(improvement) < epsilon {
previousElbo = elbo
break
}
}
previousElbo = elbo
}
return (gamma, pi, Array(elbos.prefix(iterations)))
}
private func reshapeGamma(_ buffer: [Double], frameCount: Int, speakerCount: Int) -> [[Double]] {
var result: [[Double]] = []
result.reserveCapacity(frameCount)
for frame in 0..<frameCount {
let start = frame * speakerCount
let row = Array(buffer[start..<(start + speakerCount)])
result.append(row)
}
return result
}
}
@@ -0,0 +1,356 @@
import Accelerate
import Foundation
/// Thin wrapper around common vDSP routines used by the offline diarization
/// pipeline. Centralising this logic keeps the clustering implementation more
/// readable while guaranteeing we stay away from unsafe pointer juggling in
/// the hot path.
enum VDSPOperations {
private static let epsilon: Float = 1e-12
static func l2Normalize(_ input: [Float]) -> [Float] {
guard !input.isEmpty else { return input }
var dot: Float = 0
vDSP_dotpr(input, 1, input, 1, &dot, vDSP_Length(input.count))
let norm = max(sqrt(dot), epsilon)
var scale = 1 / norm
var output = [Float](repeating: 0, count: input.count)
vDSP_vsmul(input, 1, &scale, &output, 1, vDSP_Length(input.count))
return output
}
static func dotProduct(_ lhs: [Float], _ rhs: [Float]) -> Float {
precondition(lhs.count == rhs.count, "Vectors must have the same dimension")
var dot: Float = 0
vDSP_dotpr(lhs, 1, rhs, 1, &dot, vDSP_Length(lhs.count))
return dot
}
static func matrixVectorMultiply(matrix: [[Float]], vector: [Float]) -> [Float] {
guard let columns = matrix.first?.count, !matrix.isEmpty else { return [] }
precondition(columns == vector.count, "Dimension mismatch")
if columns == 0 {
return [Float](repeating: 0, count: matrix.count)
}
let flatMatrix = matrix.flatMap { row in
precondition(row.count == columns, "Jagged matrix not supported")
return row
}
let rowCount = makeBlasIndexOrFatal(matrix.count, label: "matrix row count")
let columnCount = makeBlasIndexOrFatal(columns, label: "matrix column count")
let unitStride = BlasIndex(1)
var result = [Float](repeating: 0, count: matrix.count)
flatMatrix.withUnsafeBufferPointer { matrixPointer in
vector.withUnsafeBufferPointer { vectorPointer in
result.withUnsafeMutableBufferPointer { resultPointer in
cblas_sgemv(
CblasRowMajor,
CblasNoTrans,
rowCount,
columnCount,
1.0,
matrixPointer.baseAddress!,
columnCount,
vectorPointer.baseAddress!,
unitStride,
0.0,
resultPointer.baseAddress!,
unitStride
)
}
}
}
return result
}
static func matrixMultiply(a: [[Float]], b: [[Float]]) -> [[Float]] {
guard
let aColumns = a.first?.count,
!a.isEmpty,
!b.isEmpty
else {
return []
}
precondition(
aColumns == b.count,
"Inner dimensions must match for matrix multiplication"
)
if aColumns == 0 || b.first?.isEmpty == true {
return Array(
repeating: Array(repeating: 0 as Float, count: b.first?.count ?? 0),
count: a.count
)
}
let rowsA = a.count
let columnsB = b.first!.count
let flatA = a.flatMap { row in
precondition(row.count == aColumns, "Jagged matrix not supported")
return row
}
let flatB = b.flatMap { row in
precondition(row.count == columnsB, "Jagged matrix not supported")
return row
}
let rowsAIndex = makeBlasIndexOrFatal(rowsA, label: "matrixMultiply rowsA")
let columnsBIndex = makeBlasIndexOrFatal(columnsB, label: "matrixMultiply columnsB")
let aColumnsIndex = makeBlasIndexOrFatal(aColumns, label: "matrixMultiply columnsA")
var flatResult = [Float](repeating: 0, count: rowsA * columnsB)
flatA.withUnsafeBufferPointer { aPointer in
flatB.withUnsafeBufferPointer { bPointer in
flatResult.withUnsafeMutableBufferPointer { resultPointer in
cblas_sgemm(
CblasRowMajor,
CblasNoTrans,
CblasNoTrans,
rowsAIndex,
columnsBIndex,
aColumnsIndex,
1.0,
aPointer.baseAddress!,
aColumnsIndex,
bPointer.baseAddress!,
columnsBIndex,
0.0,
resultPointer.baseAddress!,
columnsBIndex
)
}
}
}
var result = Array(
repeating: Array(repeating: 0 as Float, count: columnsB),
count: rowsA
)
for rowIndex in 0..<rowsA {
let base = rowIndex * columnsB
for columnIndex in 0..<columnsB {
result[rowIndex][columnIndex] = flatResult[base + columnIndex]
}
}
return result
}
static func logSumExp(_ input: [Float]) -> Float {
guard let maxElement = input.max() else { return -Float.infinity }
var sum: Float = 0
let shift = -maxElement
var mutableShift = shift
var shifted = [Float](repeating: 0, count: input.count)
vDSP_vsadd(input, 1, &mutableShift, &shifted, 1, vDSP_Length(input.count))
var count = Int32(input.count)
vvexpf(&shifted, shifted, &count)
vDSP_sve(shifted, 1, &sum, vDSP_Length(input.count))
return log(sum) + maxElement
}
static func softmax(_ input: [Float]) -> [Float] {
guard let maxElement = input.max() else { return [] }
let shift = -maxElement
var mutableShift = shift
var shifted = [Float](repeating: 0, count: input.count)
var sum: Float = 0
vDSP_vsadd(input, 1, &mutableShift, &shifted, 1, vDSP_Length(input.count))
var count = Int32(input.count)
vvexpf(&shifted, shifted, &count)
vDSP_sve(shifted, 1, &sum, vDSP_Length(input.count))
guard sum > 0 else {
return Array(repeating: 1.0 / Float(input.count), count: input.count)
}
var scale = 1 / sum
vDSP_vsmul(shifted, 1, &scale, &shifted, 1, vDSP_Length(input.count))
return shifted
}
static func sum(_ input: [Float]) -> Float {
guard !input.isEmpty else { return 0 }
var total: Float = 0
vDSP_sve(input, 1, &total, vDSP_Length(input.count))
return total
}
static func sum(_ input: [Double]) -> Double {
guard !input.isEmpty else { return 0 }
var total: Double = 0
vDSP_sveD(input, 1, &total, vDSP_Length(input.count))
return total
}
static func pairwiseEuclideanDistances(a: [[Float]], b: [[Float]]) -> [[Float]] {
guard let dimension = a.first?.count, dimension == b.first?.count else {
return []
}
let rowsA = a.count
let rowsB = b.count
if rowsA == 0 || rowsB == 0 || dimension == 0 {
return Array(
repeating: Array(repeating: 0 as Float, count: rowsB),
count: rowsA
)
}
let flatA = a.flatMap { row in
precondition(row.count == dimension, "Jagged matrix not supported")
return row
}
let flatB = b.flatMap { row in
precondition(row.count == dimension, "Jagged matrix not supported")
return row
}
var normsA = [Float](repeating: 0, count: rowsA)
var normsB = [Float](repeating: 0, count: rowsB)
flatA.withUnsafeBufferPointer { pointer in
guard let base = pointer.baseAddress else { return }
for row in 0..<rowsA {
vDSP_svesq(
base.advanced(by: row * dimension),
1,
&normsA[row],
vDSP_Length(dimension)
)
}
}
flatB.withUnsafeBufferPointer { pointer in
guard let base = pointer.baseAddress else { return }
for row in 0..<rowsB {
vDSP_svesq(
base.advanced(by: row * dimension),
1,
&normsB[row],
vDSP_Length(dimension)
)
}
}
let rowsAIndex = makeBlasIndexOrFatal(rowsA, label: "distance rowsA")
let rowsBIndex = makeBlasIndexOrFatal(rowsB, label: "distance rowsB")
let dimensionIndex = makeBlasIndexOrFatal(dimension, label: "distance dimension")
var dotProducts = [Float](repeating: 0, count: rowsA * rowsB)
flatA.withUnsafeBufferPointer { aPointer in
flatB.withUnsafeBufferPointer { bPointer in
dotProducts.withUnsafeMutableBufferPointer { resultPointer in
cblas_sgemm(
CblasRowMajor,
CblasNoTrans,
CblasTrans,
rowsAIndex,
rowsBIndex,
dimensionIndex,
1.0,
aPointer.baseAddress!,
dimensionIndex,
bPointer.baseAddress!,
dimensionIndex,
0.0,
resultPointer.baseAddress!,
rowsBIndex
)
}
}
}
var negativeTwo: Float = -2
let dotElementCount = dotProducts.count
dotProducts.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
vDSP_vsmul(
baseAddress,
1,
&negativeTwo,
baseAddress,
1,
vDSP_Length(dotElementCount)
)
}
normsB.withUnsafeBufferPointer { normsBPointer in
guard let normsBBase = normsBPointer.baseAddress else { return }
dotProducts.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
for rowIndex in 0..<rowsA {
let rowPointer = baseAddress.advanced(by: rowIndex * rowsB)
var normA = normsA[rowIndex]
vDSP_vsadd(
rowPointer,
1,
&normA,
rowPointer,
1,
vDSP_Length(rowsB)
)
vDSP_vadd(
rowPointer,
1,
normsBBase,
1,
rowPointer,
1,
vDSP_Length(rowsB)
)
}
}
}
var zero: Float = 0
dotProducts.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
vDSP_vthres(
baseAddress,
1,
&zero,
baseAddress,
1,
vDSP_Length(dotElementCount)
)
}
dotProducts.withUnsafeMutableBufferPointer { pointer in
guard let baseAddress = pointer.baseAddress else { return }
var elementCount = Int32(pointer.count)
vvsqrtf(baseAddress, baseAddress, &elementCount)
}
var result = Array(
repeating: Array(repeating: 0 as Float, count: rowsB),
count: rowsA
)
for rowIndex in 0..<rowsA {
let base = rowIndex * rowsB
for columnIndex in 0..<rowsB {
result[rowIndex][columnIndex] = dotProducts[base + columnIndex]
}
}
return result
}
}
@@ -0,0 +1,147 @@
import Foundation
/// Utility helpers to resample soft VAD weights with the same half-pixel offset
/// mapping used by `scipy.ndimage.zoom`. The offline diarization pipeline relies
/// on these helpers to match the reference implementation produced by the
/// Pyannote/Core ML exporters when interpolating segmentation masks.
enum WeightInterpolation {
/// Pre-computed interpolation metadata that allows repeated resampling with
/// minimal per-call overhead.
struct InterpolationCoefficients {
private let leftIndices: [Int]
private let rightIndices: [Int]
private let leftWeights: [Float]
private let rightWeights: [Float]
private let maxInputIndex: Int
private let outputLength: Int
init(inputLength: Int, outputLength: Int) {
precondition(inputLength > 0 && outputLength > 0, "Lengths must be positive")
var left: [Int] = []
var right: [Int] = []
var lWeights: [Float] = []
var rWeights: [Float] = []
var maxIndex = 0
let scale = Float(outputLength) / Float(inputLength)
left.reserveCapacity(outputLength)
right.reserveCapacity(outputLength)
lWeights.reserveCapacity(outputLength)
rWeights.reserveCapacity(outputLength)
for index in 0..<outputLength {
// Half-pixel offset mapping to match scipy.ndimage.zoom(order=1)
let position = (Float(index) + 0.5) / scale - 0.5
let clamped = min(max(position, 0), Float(inputLength - 1))
let leftIndex = Int(floor(clamped))
let rightIndex = min(leftIndex + 1, inputLength - 1)
let weightRight = clamped - Float(leftIndex)
let weightLeft = 1 - weightRight
maxIndex = max(maxIndex, rightIndex)
left.append(leftIndex)
right.append(rightIndex)
lWeights.append(weightLeft)
rWeights.append(weightRight)
}
self.leftIndices = left
self.rightIndices = right
self.leftWeights = lWeights
self.rightWeights = rWeights
self.maxInputIndex = maxIndex
self.outputLength = outputLength
}
func interpolate(_ input: [Float]) -> [Float] {
guard !input.isEmpty else { return [] }
precondition(input.count > maxInputIndex, "Input shorter than interpolation map")
var output = [Float](repeating: 0, count: outputLength)
for index in 0..<outputLength {
let leftValue = input[leftIndices[index]]
let rightValue = input[rightIndices[index]]
output[index] = leftValue * leftWeights[index] + rightValue * rightWeights[index]
}
return output
}
func interpolate(_ input: [Float], into output: inout [Float]) {
guard !input.isEmpty else {
if output.count == outputLength {
for index in 0..<outputLength {
output[index] = 0
}
} else {
output = [Float](repeating: 0, count: outputLength)
}
return
}
precondition(input.count > maxInputIndex, "Input shorter than interpolation map")
if output.count != outputLength {
output = [Float](repeating: 0, count: outputLength)
}
for index in 0..<outputLength {
let leftValue = input[leftIndices[index]]
let rightValue = input[rightIndices[index]]
output[index] = leftValue * leftWeights[index] + rightValue * rightWeights[index]
}
}
}
/// Resample a 1-dimensional array to the requested length.
static func resample(_ input: [Float], to outputLength: Int) -> [Float] {
guard !input.isEmpty, outputLength > 0 else {
return []
}
if input.count == outputLength {
return input
}
let coefficients = InterpolationCoefficients(
inputLength: input.count,
outputLength: outputLength
)
return coefficients.interpolate(input)
}
/// Resample each row independently using the same interpolation map.
static func resample2D(_ input: [[Float]], to outputLength: Int) -> [[Float]] {
guard let firstRow = input.first, !firstRow.isEmpty, outputLength > 0 else {
return []
}
let coefficients = InterpolationCoefficients(
inputLength: firstRow.count,
outputLength: outputLength
)
return input.map { row in
if row.count == firstRow.count {
return coefficients.interpolate(row)
} else {
// Fallback to per-row computation if the layout differs
return resample(row, to: outputLength)
}
}
}
/// Convenience helper mirroring `scipy.ndimage.zoom` for linear interpolation.
static func zoom(_ input: [Float], factor: Float) -> [Float] {
guard !input.isEmpty, factor > 0 else {
return []
}
let outputLength = max(1, Int(round(Float(input.count) * factor)))
return resample(input, to: outputLength)
}
}
@@ -1,3 +1,4 @@
import Accelerate
import Foundation import Foundation
import OSLog import OSLog
@@ -76,13 +77,15 @@ public class SpeakerManager {
return nil return nil
} }
let normalizedEmbedding = VDSPOperations.l2Normalize(embedding)
return queue.sync(flags: .barrier) { return queue.sync(flags: .barrier) {
let (closestSpeaker, distance) = findClosestSpeaker(to: embedding) let (closestSpeaker, distance) = findClosestSpeaker(to: normalizedEmbedding)
if let speakerId = closestSpeaker, distance < speakerThreshold { if let speakerId = closestSpeaker, distance < speakerThreshold {
updateExistingSpeaker( updateExistingSpeaker(
speakerId: speakerId, speakerId: speakerId,
embedding: embedding, embedding: normalizedEmbedding,
duration: speechDuration, duration: speechDuration,
distance: distance distance: distance
) )
@@ -96,7 +99,7 @@ public class SpeakerManager {
// Step 3: Create new speaker if duration is sufficient // Step 3: Create new speaker if duration is sufficient
if speechDuration >= minSpeechDuration { if speechDuration >= minSpeechDuration {
let newSpeakerId = createNewSpeaker( let newSpeakerId = createNewSpeaker(
embedding: embedding, embedding: normalizedEmbedding,
duration: speechDuration, duration: speechDuration,
distanceToClosest: distance distanceToClosest: distance
) )
@@ -142,8 +145,9 @@ public class SpeakerManager {
// Update embedding if quality is good // Update embedding if quality is good
if distance < embeddingThreshold { if distance < embeddingThreshold {
let embeddingMagnitude = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) var sumSquares: Float = 0
if embeddingMagnitude > 0.1 { vDSP_svesq(embedding, 1, &sumSquares, vDSP_Length(embedding.count))
if sumSquares > 0.01 {
speaker.updateMainEmbedding( speaker.updateMainEmbedding(
duration: duration, duration: duration,
embedding: embedding, embedding: embedding,
@@ -165,6 +169,7 @@ public class SpeakerManager {
duration: Float, duration: Float,
distanceToClosest: Float distanceToClosest: Float
) -> String { ) -> String {
let normalizedEmbedding = VDSPOperations.l2Normalize(embedding)
let newSpeakerId = String(nextSpeakerId) let newSpeakerId = String(nextSpeakerId)
nextSpeakerId += 1 nextSpeakerId += 1
highestSpeakerId = max(highestSpeakerId, nextSpeakerId - 1) highestSpeakerId = max(highestSpeakerId, nextSpeakerId - 1)
@@ -173,12 +178,12 @@ public class SpeakerManager {
let newSpeaker = Speaker( let newSpeaker = Speaker(
id: newSpeakerId, id: newSpeakerId,
name: "Speaker \(newSpeakerId)", // Default name with number name: "Speaker \(newSpeakerId)", // Default name with number
currentEmbedding: embedding, currentEmbedding: normalizedEmbedding,
duration: duration duration: duration
) )
// Add initial raw embedding // Add initial raw embedding
let initialRaw = RawEmbedding(segmentId: UUID(), embedding: embedding, timestamp: Date()) let initialRaw = RawEmbedding(segmentId: UUID(), embedding: normalizedEmbedding, timestamp: Date())
newSpeaker.addRawEmbedding(initialRaw) newSpeaker.addRawEmbedding(initialRaw)
speakerDatabase[newSpeakerId] = newSpeaker speakerDatabase[newSpeakerId] = newSpeaker
@@ -1,3 +1,4 @@
import Accelerate
import Foundation import Foundation
import OSLog import OSLog
@@ -7,6 +8,7 @@ import OSLog
public enum SpeakerUtilities { public enum SpeakerUtilities {
private static let logger = AppLogger(category: "SpeakerUtilities") private static let logger = AppLogger(category: "SpeakerUtilities")
private static let normalizationTolerance: Float = 1e-3
// MARK: - Configuration // MARK: - Configuration
@@ -64,25 +66,38 @@ public enum SpeakerUtilities {
} }
var dotProduct: Float = 0 var dotProduct: Float = 0
var magnitudeA: Float = 0 vDSP_dotpr(a, 1, b, 1, &dotProduct, vDSP_Length(a.count))
var magnitudeB: Float = 0
for i in 0..<a.count { var sumSquaresA: Float = 0
dotProduct += a[i] * b[i] var sumSquaresB: Float = 0
magnitudeA += a[i] * a[i] vDSP_svesq(a, 1, &sumSquaresA, vDSP_Length(a.count))
magnitudeB += b[i] * b[i] vDSP_svesq(b, 1, &sumSquaresB, vDSP_Length(b.count))
}
magnitudeA = sqrt(magnitudeA) guard sumSquaresA > 0 && sumSquaresB > 0 else {
magnitudeB = sqrt(magnitudeB)
guard magnitudeA > 0 && magnitudeB > 0 else {
logger.warning("Zero magnitude embedding detected") logger.warning("Zero magnitude embedding detected")
return Float.infinity return Float.infinity
} }
let similarity = dotProduct / (magnitudeA * magnitudeB) let isUnitA = abs(sumSquaresA - 1.0) <= normalizationTolerance
return 1 - similarity let isUnitB = abs(sumSquaresB - 1.0) <= normalizationTolerance
let similarity: Float
if isUnitA && isUnitB {
similarity = dotProduct
} else {
let magnitudeA = sumSquaresA.squareRoot()
let magnitudeB = sumSquaresB.squareRoot()
guard magnitudeA > 0 && magnitudeB > 0 else {
logger.warning("Zero magnitude after normalization guard")
return Float.infinity
}
similarity = dotProduct / (magnitudeA * magnitudeB)
}
let clampedSimilarity = min(max(similarity, -1.0), 1.0)
return 1 - clampedSimilarity
} }
// MARK: - Embedding Validation // MARK: - Embedding Validation
@@ -94,7 +109,9 @@ public enum SpeakerUtilities {
return false return false
} }
let magnitude = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) var sumSquares: Float = 0
vDSP_svesq(embedding, 1, &sumSquares, vDSP_Length(embedding.count))
let magnitude = sumSquares.squareRoot()
guard magnitude > minMagnitude else { guard magnitude > minMagnitude else {
logger.warning("Low magnitude embedding: \(magnitude)") logger.warning("Low magnitude embedding: \(magnitude)")
return false return false
@@ -232,11 +249,12 @@ public enum SpeakerUtilities {
} }
// Create validated parameters // Create validated parameters
let normalizedEmbedding = VDSPOperations.l2Normalize(embedding)
let params = SpeakerCreationParams( let params = SpeakerCreationParams(
id: id, id: id,
name: name, name: name,
duration: duration, duration: duration,
embedding: embedding embedding: normalizedEmbedding
) )
return .success(params) return .success(params)
@@ -293,12 +311,15 @@ public enum SpeakerUtilities {
} }
// Calculate exponential moving average // Calculate exponential moving average
var updated = [Float](repeating: 0, count: current.count) let normalizedCurrent = VDSPOperations.l2Normalize(current)
for i in 0..<current.count { let normalizedNew = VDSPOperations.l2Normalize(new)
updated[i] = alpha * current[i] + (1 - alpha) * new[i]
var updated = [Float](repeating: 0, count: normalizedCurrent.count)
for i in 0..<normalizedCurrent.count {
updated[i] = alpha * normalizedCurrent[i] + (1 - alpha) * normalizedNew[i]
} }
return updated return VDSPOperations.l2Normalize(updated)
} }
// MARK: - Raw Embedding Management // MARK: - Raw Embedding Management
@@ -319,9 +340,10 @@ public enum SpeakerUtilities {
} }
// Create the new raw embedding // Create the new raw embedding
let normalizedEmbedding = VDSPOperations.l2Normalize(embedding)
let newEmbedding = RawEmbedding( let newEmbedding = RawEmbedding(
segmentId: segmentId, segmentId: segmentId,
embedding: embedding, embedding: normalizedEmbedding,
timestamp: timestamp timestamp: timestamp
) )
@@ -385,12 +407,14 @@ public enum SpeakerUtilities {
return nil return nil
} }
let normalizedEmbedding = VDSPOperations.l2Normalize(segmentEmbedding)
// Add to raw embeddings // Add to raw embeddings
guard guard
let (updatedRaw, shouldRecalc) = addRawEmbedding( let (updatedRaw, shouldRecalc) = addRawEmbedding(
to: currentRawEmbeddings, to: currentRawEmbeddings,
segmentId: segmentId, segmentId: segmentId,
embedding: segmentEmbedding, embedding: normalizedEmbedding,
timestamp: Date() timestamp: Date()
) )
else { else {
@@ -400,7 +424,7 @@ public enum SpeakerUtilities {
// Update main embedding using exponential moving average // Update main embedding using exponential moving average
let updatedMain = updateEmbedding( let updatedMain = updateEmbedding(
current: currentMainEmbedding, current: currentMainEmbedding,
new: segmentEmbedding, new: normalizedEmbedding,
alpha: alpha alpha: alpha
) )
@@ -471,7 +495,7 @@ public enum SpeakerUtilities {
average[i] /= Float(validCount) average[i] /= Float(validCount)
} }
return average return VDSPOperations.l2Normalize(average)
} }
} }
+16 -10
View File
@@ -1,3 +1,4 @@
import Accelerate
import Foundation import Foundation
/// Speaker profile representation for tracking speakers across audio /// Speaker profile representation for tracking speakers across audio
@@ -23,7 +24,7 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable {
let now = Date() let now = Date()
self.id = id ?? UUID().uuidString self.id = id ?? UUID().uuidString
self.name = name ?? self.id self.name = name ?? self.id
self.currentEmbedding = currentEmbedding self.currentEmbedding = VDSPOperations.l2Normalize(currentEmbedding)
self.duration = duration self.duration = duration
self.createdAt = createdAt ?? now self.createdAt = createdAt ?? now
self.updatedAt = updatedAt ?? now self.updatedAt = updatedAt ?? now
@@ -45,22 +46,26 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable {
) { ) {
// Validate embedding quality // Validate embedding quality
let embeddingMagnitude = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) var sumSquares: Float = 0
guard embeddingMagnitude > 0.1 else { return } vDSP_svesq(embedding, 1, &sumSquares, vDSP_Length(embedding.count))
guard sumSquares > 0.01 else { return }
let normalizedEmbedding = VDSPOperations.l2Normalize(embedding)
// Add to raw embeddings // Add to raw embeddings
let rawEmbedding = RawEmbedding( let rawEmbedding = RawEmbedding(
segmentId: segmentId, segmentId: segmentId,
embedding: embedding, embedding: normalizedEmbedding,
timestamp: Date() timestamp: Date()
) )
addRawEmbedding(rawEmbedding) addRawEmbedding(rawEmbedding)
// Update main embedding using exponential moving average // Update main embedding using exponential moving average
if currentEmbedding.count == embedding.count { if currentEmbedding.count == normalizedEmbedding.count {
for i in 0..<currentEmbedding.count { for i in 0..<currentEmbedding.count {
currentEmbedding[i] = alpha * currentEmbedding[i] + (1 - alpha) * embedding[i] currentEmbedding[i] = alpha * currentEmbedding[i] + (1 - alpha) * normalizedEmbedding[i]
} }
currentEmbedding = VDSPOperations.l2Normalize(currentEmbedding)
} }
// Update metadata // Update metadata
@@ -72,8 +77,9 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable {
/// Add a raw embedding with FIFO queue management /// Add a raw embedding with FIFO queue management
public func addRawEmbedding(_ embedding: RawEmbedding) { public func addRawEmbedding(_ embedding: RawEmbedding) {
// Validate embedding quality // Validate embedding quality
let embeddingMagnitude = sqrt(embedding.embedding.map { $0 * $0 }.reduce(0, +)) var sumSquares: Float = 0
guard embeddingMagnitude > 0.1 else { return } vDSP_svesq(embedding.embedding, 1, &sumSquares, vDSP_Length(embedding.embedding.count))
guard sumSquares > 0.01 else { return }
// Maintain max of 50 raw embeddings (FIFO) // Maintain max of 50 raw embeddings (FIFO)
if rawEmbeddings.count >= 50 { if rawEmbeddings.count >= 50 {
@@ -124,7 +130,7 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable {
averageEmbedding[i] /= count averageEmbedding[i] /= count
} }
self.currentEmbedding = averageEmbedding self.currentEmbedding = VDSPOperations.l2Normalize(averageEmbedding)
self.updatedAt = Date() self.updatedAt = Date()
} }
} }
@@ -177,7 +183,7 @@ public struct RawEmbedding: Codable, Sendable {
public init(segmentId: UUID = UUID(), embedding: [Float], timestamp: Date = Date()) { public init(segmentId: UUID = UUID(), embedding: [Float], timestamp: Date = Date()) {
self.segmentId = segmentId self.segmentId = segmentId
self.embedding = embedding self.embedding = VDSPOperations.l2Normalize(embedding)
self.timestamp = timestamp self.timestamp = timestamp
} }
} }
+31 -11
View File
@@ -100,7 +100,8 @@ public class DownloadUtils {
_ repo: Repo, _ repo: Repo,
modelNames: [String], modelNames: [String],
directory: URL, directory: URL,
computeUnits: MLComputeUnits = .cpuAndNeuralEngine computeUnits: MLComputeUnits = .cpuAndNeuralEngine,
variant: String? = nil
) async throws -> [String: MLModel] { ) async throws -> [String: MLModel] {
// Ensure host environment is logged for debugging (once per process) // Ensure host environment is logged for debugging (once per process)
await SystemInfo.logOnce(using: logger) await SystemInfo.logOnce(using: logger)
@@ -108,7 +109,7 @@ public class DownloadUtils {
// 1st attempt: normal load // 1st attempt: normal load
return try await loadModelsOnce( return try await loadModelsOnce(
repo, modelNames: modelNames, repo, modelNames: modelNames,
directory: directory, computeUnits: computeUnits) directory: directory, computeUnits: computeUnits, variant: variant)
} catch { } catch {
// 1st attempt failed wipe cache to signal redownload // 1st attempt failed wipe cache to signal redownload
logger.warning("First load failed: \(error.localizedDescription)") logger.warning("First load failed: \(error.localizedDescription)")
@@ -119,7 +120,7 @@ public class DownloadUtils {
// 2nd attempt after fresh download // 2nd attempt after fresh download
return try await loadModelsOnce( return try await loadModelsOnce(
repo, modelNames: modelNames, repo, modelNames: modelNames,
directory: directory, computeUnits: computeUnits) directory: directory, computeUnits: computeUnits, variant: variant)
} }
} }
@@ -134,7 +135,8 @@ public class DownloadUtils {
_ repo: Repo, _ repo: Repo,
modelNames: [String], modelNames: [String],
directory: URL, directory: URL,
computeUnits: MLComputeUnits = .cpuAndNeuralEngine computeUnits: MLComputeUnits = .cpuAndNeuralEngine,
variant: String? = nil
) async throws -> [String: MLModel] { ) async throws -> [String: MLModel] {
// Ensure host environment is logged for debugging (once per process) // Ensure host environment is logged for debugging (once per process)
await SystemInfo.logOnce(using: logger) await SystemInfo.logOnce(using: logger)
@@ -145,7 +147,7 @@ public class DownloadUtils {
let repoPath = directory.appendingPathComponent(repo.folderName) let repoPath = directory.appendingPathComponent(repo.folderName)
if !FileManager.default.fileExists(atPath: repoPath.path) { if !FileManager.default.fileExists(atPath: repoPath.path) {
logger.info("Models not found in cache at \(repoPath.path)") logger.info("Models not found in cache at \(repoPath.path)")
try await downloadRepo(repo, to: directory) try await downloadRepo(repo, to: directory, variant: variant)
} else { } else {
logger.info("Found \(repo.folderName) locally, no download needed") logger.info("Found \(repo.folderName) locally, no download needed")
} }
@@ -237,14 +239,14 @@ public class DownloadUtils {
} }
/// Download a HuggingFace repository /// Download a HuggingFace repository
private static func downloadRepo(_ repo: Repo, to directory: URL) async throws { private static func downloadRepo(_ repo: Repo, to directory: URL, variant: String? = nil) async throws {
logger.info("Downloading \(repo.folderName) from HuggingFace...") logger.info("Downloading \(repo.folderName) from HuggingFace...")
let repoPath = directory.appendingPathComponent(repo.folderName) let repoPath = directory.appendingPathComponent(repo.folderName)
try FileManager.default.createDirectory(at: repoPath, withIntermediateDirectories: true) try FileManager.default.createDirectory(at: repoPath, withIntermediateDirectories: true)
// Get the required model names for this repo // Get the required model names for this repo from the appropriate manager
let requiredModels = getRequiredModelNames(for: repo) let requiredModels = ModelNames.getRequiredModelNames(for: repo, variant: variant)
// Download all repository contents // Download all repository contents
let files = try await listRepoFiles(repo) let files = try await listRepoFiles(repo)
@@ -252,10 +254,28 @@ public class DownloadUtils {
for file in files { for file in files {
switch file.type { switch file.type {
case "directory" where file.path.hasSuffix(".mlmodelc"): case "directory" where file.path.hasSuffix(".mlmodelc"):
// Only download if this model is in our required list // Check if this model is required (with or without subfolder prefix)
if requiredModels.contains(file.path) { let isRequired =
requiredModels.contains(file.path) || requiredModels.contains { $0.hasSuffix("/" + file.path) }
if isRequired {
logger.info("Downloading required model: \(file.path)") logger.info("Downloading required model: \(file.path)")
try await downloadModelDirectory(repo: repo, dirPath: file.path, to: repoPath)
// Find if this should go in a subfolder
if let fullPath = requiredModels.first(where: { $0.hasSuffix("/" + file.path) }),
fullPath.contains("/")
{
// Extract subfolder (e.g., "speaker-diarization-offline/Segmentation.mlmodelc" -> "speaker-diarization-offline")
let subfolder = String(fullPath.split(separator: "/").first!)
let subfolderPath = repoPath.appendingPathComponent(subfolder)
try FileManager.default.createDirectory(at: subfolderPath, withIntermediateDirectories: true)
// Download to subfolder
try await downloadModelDirectory(repo: repo, dirPath: file.path, to: subfolderPath)
} else {
// Download to root of repo
try await downloadModelDirectory(repo: repo, dirPath: file.path, to: repoPath)
}
} else { } else {
logger.info("Skipping unrequired model: \(file.path)") logger.info("Skipping unrequired model: \(file.path)")
} }
+31 -1
View File
@@ -57,6 +57,33 @@ public enum ModelNames {
] ]
} }
/// Offline diarizer model names (VBx-based clustering)
public enum OfflineDiarizer {
public static let subfolder = "speaker-diarization-offline"
public static let segmentation = "Segmentation"
public static let fbank = "FBank"
public static let embedding = "Embedding"
public static let pldaRho = "PldaRho"
public static let segmentationFile = segmentation + ".mlmodelc"
public static let fbankFile = fbank + ".mlmodelc"
public static let embeddingFile = embedding + ".mlmodelc"
public static let pldaRhoFile = pldaRho + ".mlmodelc"
// Full paths including subfolder (for DownloadUtils)
public static let segmentationPath = subfolder + "/" + segmentationFile
public static let fbankPath = subfolder + "/" + fbankFile
public static let embeddingPath = subfolder + "/" + embeddingFile
public static let pldaRhoPath = subfolder + "/" + pldaRhoFile
public static let requiredModels: Set<String> = [
segmentationPath,
fbankPath,
embeddingPath,
pldaRhoPath,
]
}
/// ASR model names /// ASR model names
public enum ASR { public enum ASR {
public static let preprocessor = "Preprocessor" public static let preprocessor = "Preprocessor"
@@ -144,13 +171,16 @@ public enum ModelNames {
} }
} }
static func getRequiredModelNames(for repo: Repo) -> Set<String> { static func getRequiredModelNames(for repo: Repo, variant: String?) -> Set<String> {
switch repo { switch repo {
case .vad: case .vad:
return ModelNames.VAD.requiredModels return ModelNames.VAD.requiredModels
case .parakeet, .parakeetV2: case .parakeet, .parakeetV2:
return ModelNames.ASR.requiredModels return ModelNames.ASR.requiredModels
case .diarizer: case .diarizer:
if variant == "offline" {
return ModelNames.OfflineDiarizer.requiredModels
}
return ModelNames.Diarizer.requiredModels return ModelNames.Diarizer.requiredModels
case .kokoro: case .kokoro:
return ModelNames.TTS.requiredModels return ModelNames.TTS.requiredModels
+153
View File
@@ -0,0 +1,153 @@
import Accelerate
import CoreML
import Foundation
/// Lightweight helpers that exercise Core ML models with zero-valued inputs to
/// prime memory allocations before running the offline diarization pipeline.
enum ModelWarmup {
/// Performs a warmup loop for a model with a single MLMultiArray input.
///
/// - Parameters:
/// - model: Model to warm up.
/// - inputName: Feature name expected by the model.
/// - inputShape: Shape of the MLMultiArray (e.g. `[1, 1, 160_000]`).
/// - iterations: Number of times to execute `prediction`.
/// - Returns: Total elapsed duration in seconds.
static func warmup(
model: MLModel,
inputName: String,
inputShape: [Int],
iterations: Int = 1
) throws -> TimeInterval {
precondition(iterations > 0, "Warmup iterations must be positive")
precondition(!inputShape.isEmpty, "Input shape must not be empty")
let array = try MLMultiArray(
shape: inputShape.map { NSNumber(value: $0) },
dataType: .float32
)
array.resetToZeros()
let features = try MLDictionaryFeatureProvider(dictionary: [
inputName: MLFeatureValue(multiArray: array)
])
let start = Date()
for _ in 0..<iterations {
_ = try model.prediction(from: features)
}
return Date().timeIntervalSince(start)
}
/// Warm up the embedding extractor with representative audio + weight inputs.
///
/// We reproduce the exact shapes used during inference to make sure Core ML
/// allocates and caches buffers on the correct compute units (ANE/GPU).
static func warmupEmbeddingModel(
_ model: MLModel,
weightFrames: Int,
audioSamples: Int = 160_000
) throws {
precondition(weightFrames > 0, "weightFrames must be positive")
do {
let inputs = model.modelDescription.inputDescriptionsByName
let featureShape: [Int]
if let fbank = inputs.first(where: { $0.key.caseInsensitiveCompare("fbank_features") == .orderedSame })?
.value.multiArrayConstraint?.shape
{
let mapped = fbank.map { $0.intValue }
if !mapped.isEmpty, mapped.allSatisfy({ $0 > 0 }) {
featureShape = mapped
} else {
featureShape = [1, 1, 80, 998]
}
} else {
featureShape = [1, 1, 80, 998]
}
let weightsShape: [Int]
if let weights = inputs.first(where: { $0.key.caseInsensitiveCompare("weights") == .orderedSame })?
.value.multiArrayConstraint?.shape
{
let mapped = weights.map { $0.intValue }
if !mapped.isEmpty, mapped.allSatisfy({ $0 > 0 }) {
weightsShape = mapped
} else {
weightsShape = [1, weightFrames]
}
} else {
weightsShape = [1, weightFrames]
}
let featureArray = try MLMultiArray(
shape: featureShape.map { NSNumber(value: $0) },
dataType: .float32
)
featureArray.resetToZeros()
let weightArray = try MLMultiArray(
shape: weightsShape.map { NSNumber(value: $0) },
dataType: .float32
)
weightArray.resetToZeros()
let provider = try MLDictionaryFeatureProvider(dictionary: [
"fbank_features": MLFeatureValue(multiArray: featureArray),
"weights": MLFeatureValue(multiArray: weightArray),
])
_ = try model.prediction(from: provider)
return
} catch {
// Fall back to combined legacy interface.
}
let totalElements = audioSamples + weightFrames
do {
let combinedArray = try MLMultiArray(
shape: [1, 1, 1, NSNumber(value: totalElements)],
dataType: .float32
)
combinedArray.resetToZeros()
let provider = try MLDictionaryFeatureProvider(dictionary: [
"audio_and_weights": MLFeatureValue(multiArray: combinedArray)
])
_ = try model.prediction(from: provider)
return
} catch {
// Fall through to legacy dual-input warmup for older embedding models.
}
let audioArray = try MLMultiArray(
shape: [1, 1, NSNumber(value: audioSamples)],
dataType: .float32
)
audioArray.resetToZeros()
let weightArray = try MLMultiArray(
shape: [1, NSNumber(value: weightFrames)],
dataType: .float32
)
weightArray.resetToZeros()
let provider = try MLDictionaryFeatureProvider(dictionary: [
"audio": MLFeatureValue(multiArray: audioArray),
"weights": MLFeatureValue(multiArray: weightArray),
])
_ = try model.prediction(from: provider)
}
}
extension MLMultiArray {
fileprivate func resetToZeros() {
let pointer = dataPointer.assumingMemoryBound(to: Float.self)
let count = self.count
var zero: Float = 0
vDSP_vfill(&zero, pointer, 1, vDSP_Length(count))
}
}
@@ -0,0 +1,81 @@
import Foundation
public protocol StreamingAudioSampleSource: Sendable {
var sampleCount: Int { get }
func copySamples(
into destination: UnsafeMutablePointer<Float>,
offset: Int,
count: Int
) throws
}
public struct ArrayAudioSampleSource: StreamingAudioSampleSource {
private let samples: [Float]
public init(samples: [Float]) {
self.samples = samples
}
public var sampleCount: Int {
samples.count
}
public func copySamples(
into destination: UnsafeMutablePointer<Float>,
offset: Int,
count: Int
) throws {
guard count > 0 else { return }
guard !samples.isEmpty else { return }
let clampedOffset = max(0, offset)
guard clampedOffset < samples.count else { return }
let available = min(samples.count - clampedOffset, count)
samples.withUnsafeBufferPointer { pointer in
destination.update(
from: pointer.baseAddress!.advanced(by: clampedOffset),
count: available
)
}
}
}
public struct DiskBackedAudioSampleSource: StreamingAudioSampleSource {
private let mappedData: Data
private let floatStride = MemoryLayout<Float>.stride
private let fileURL: URL
public let sampleCount: Int
init(mappedData: Data, fileURL: URL) {
self.mappedData = mappedData
self.fileURL = fileURL
self.sampleCount = mappedData.count / floatStride
}
public func copySamples(
into destination: UnsafeMutablePointer<Float>,
offset: Int,
count: Int
) throws {
guard count > 0 else { return }
guard sampleCount > 0 else { return }
let clampedOffset = max(0, offset)
guard clampedOffset < sampleCount else { return }
let available = min(sampleCount - clampedOffset, count)
mappedData.withUnsafeBytes { rawBuffer in
let floatBuffer = rawBuffer.bindMemory(to: Float.self)
destination.update(
from: floatBuffer.baseAddress!.advanced(by: clampedOffset),
count: available
)
}
}
public func cleanup() {
do {
try FileManager.default.removeItem(at: fileURL)
} catch {
// Silently ignore cleanup failures; temporary directory will be purged eventually.
}
}
}
@@ -0,0 +1,213 @@
import AVFoundation
import Foundation
import OSLog
public struct StreamingAudioSourceFactory {
private let logger = AppLogger(category: "StreamingAudioSourceFactory")
public init() {}
public func makeDiskBackedSource(
from url: URL,
targetSampleRate: Int
) throws -> (source: DiskBackedAudioSampleSource, loadDuration: TimeInterval) {
do {
let startTime = Date()
let audioFile = try AVAudioFile(forReading: url)
let inputFormat = audioFile.processingFormat
let targetFormat = AVAudioFormat(
commonFormat: .pcmFormatFloat32,
sampleRate: Double(targetSampleRate),
channels: 1,
interleaved: false
)!
let tempURL = try makeTemporaryURL()
guard FileManager.default.createFile(atPath: tempURL.path, contents: nil) else {
throw StreamingAudioError.processingFailed("Failed to create temporary audio buffer at \(tempURL.path)")
}
let handle = try FileHandle(forWritingTo: tempURL)
defer {
try? handle.close()
}
guard let converter = AVAudioConverter(from: inputFormat, to: targetFormat) else {
throw StreamingAudioError.processingFailed(
"Unsupported audio format \(inputFormat); failed to create converter")
}
logger.debug(
"Streaming conversion \(inputFormat.sampleRate)Hz×\(inputFormat.channelCount)ch → \(targetFormat.sampleRate)Hz×\(targetFormat.channelCount)ch"
)
let totalSamples: Int
do {
totalSamples = try streamConvert(
audioFile: audioFile,
converter: converter,
handle: handle
)
} catch {
logger.error("Streaming conversion failed before file mapping: \(error.localizedDescription)")
throw error
}
try handle.synchronize()
try handle.close()
let attributes = try FileManager.default.attributesOfItem(atPath: tempURL.path)
if let fileSize = attributes[.size] as? NSNumber {
logger.debug("Streaming audio temp file size=\(fileSize.intValue) bytes")
}
logger.debug("Streaming audio total samples=\(totalSamples)")
let mappedData = try Data(contentsOf: tempURL, options: [.mappedIfSafe])
let source = DiskBackedAudioSampleSource(mappedData: mappedData, fileURL: tempURL)
if source.sampleCount != totalSamples {
logger.warning(
"Mapped sample count mismatch (reported=\(source.sampleCount), tracked=\(totalSamples)); continuing"
)
}
let duration = Date().timeIntervalSince(startTime)
return (source, duration)
} catch let streamingError as StreamingAudioError {
throw streamingError
} catch {
logger.error("Streaming audio source creation failed: \(error.localizedDescription)")
throw StreamingAudioError.processingFailed(
"Streaming audio source creation failed: \(error.localizedDescription)"
)
}
}
private func makeTemporaryURL() throws -> URL {
let tempDirectory = FileManager.default.temporaryDirectory
let identifier = UUID().uuidString
return tempDirectory.appendingPathComponent("fluidaudio-streaming-\(identifier).raw")
}
private func streamConvert(
audioFile: AVAudioFile,
converter: AVAudioConverter,
handle: FileHandle
) throws -> Int {
let inputFormat = audioFile.processingFormat
let targetFormat = converter.outputFormat
let inputCapacity: AVAudioFrameCount = 16_384
guard
let inputBuffer = AVAudioPCMBuffer(
pcmFormat: inputFormat,
frameCapacity: inputCapacity
)
else {
throw StreamingAudioError.failedToAllocateBuffer("Input", requestedFrames: Int(inputCapacity))
}
let estimatedOutputFrames = AVAudioFrameCount(
(Double(inputCapacity) * targetFormat.sampleRate / inputFormat.sampleRate).rounded(.up)
)
guard
let outputBuffer = AVAudioPCMBuffer(
pcmFormat: targetFormat,
frameCapacity: max(1024, estimatedOutputFrames)
)
else {
throw StreamingAudioError.failedToAllocateBuffer("Output", requestedFrames: Int(estimatedOutputFrames))
}
var totalSamples = 0
var inputComplete = false
var readError: Error?
let inputBlock: AVAudioConverterInputBlock = { _, status in
if inputComplete {
status.pointee = .endOfStream
return nil
}
do {
let remainingFrames = AVAudioFrameCount(audioFile.length - audioFile.framePosition)
let framesToRead = min(inputCapacity, remainingFrames)
if framesToRead > 0 {
try audioFile.read(into: inputBuffer, frameCount: framesToRead)
} else {
inputBuffer.frameLength = 0
}
} catch {
readError = error
inputBuffer.frameLength = 0
}
guard inputBuffer.frameLength > 0 else {
inputComplete = true
status.pointee = .endOfStream
return nil
}
status.pointee = .haveData
return inputBuffer
}
while true {
outputBuffer.frameLength = 0
var conversionError: NSError?
let status = converter.convert(
to: outputBuffer,
error: &conversionError,
withInputFrom: inputBlock
)
if let conversionError {
throw StreamingAudioError.processingFailed(
"Audio conversion failed: \(conversionError.localizedDescription)"
)
}
if let readError {
throw StreamingAudioError.processingFailed(
"Failed while reading audio: \(readError.localizedDescription)"
)
}
let producedFrames = Int(outputBuffer.frameLength)
if producedFrames > 0 {
guard let channelData = outputBuffer.floatChannelData?.pointee else {
throw StreamingAudioError.processingFailed("Missing channel data during conversion")
}
let byteCount = producedFrames * MemoryLayout<Float>.stride
let baseAddress = UnsafeRawPointer(channelData)
let data = Data(bytes: baseAddress, count: byteCount)
try handle.write(contentsOf: data)
totalSamples += producedFrames
}
if status == .endOfStream {
break
}
}
return totalSamples
}
}
public enum StreamingAudioError: Error, LocalizedError {
case processingFailed(String)
public var errorDescription: String? {
switch self {
case .processingFailed(let message):
return "Processing failed: \(message)"
}
}
}
extension StreamingAudioError {
fileprivate static func failedToAllocateBuffer(_ name: String, requestedFrames: Int) -> StreamingAudioError {
.processingFailed("Failed to allocate \(name.lowercased()) buffer (\(requestedFrames) frames)")
}
}
@@ -36,22 +36,22 @@ enum StreamDiarizationBenchmark {
static func printUsage() { static func printUsage() {
logger.info( logger.info(
""" """
Stream Diarization Benchmark Command Diarization Benchmark Command
Evaluates streaming speaker diarization WITHOUT retroactive speaker remapping. Evaluates speaker diarization in either streaming (online) or offline (VBx) mode.
This measures true real-time performance as seen in production systems.
Usage: fluidaudio diarization-benchmark [options] Usage: fluidaudio diarization-benchmark [options]
Options: Options:
--mode <streaming|offline> Diarization mode (default: streaming)
--dataset <name> Dataset to benchmark (default: ami-sdm) --dataset <name> Dataset to benchmark (default: ami-sdm)
--single-file <name> Process a specific meeting (e.g., ES2004a) --single-file <name> Process a specific meeting (e.g., ES2004a)
--max-files <n> Maximum number of files to process --max-files <n> Maximum number of files to process
--chunk-seconds <sec> Chunk duration for streaming (default: 10.0) --chunk-seconds <sec> Chunk duration for streaming (default: 10.0, streaming only)
--overlap-seconds <sec> Overlap between chunks (default: 0.0) --overlap-seconds <sec> Overlap between chunks (default: 0.0, streaming only)
--threshold <value> Clustering threshold (default: 0.7) --threshold <value> Clustering threshold (default: 0.7)
--assignment-threshold Threshold for assigning to existing speakers (default: 0.84) --assignment-threshold Threshold for assigning to existing speakers (default: 0.84, streaming only)
--update-threshold Threshold for updating speaker embeddings (default: 0.56) --update-threshold Threshold for updating speaker embeddings (default: 0.56, streaming only)
--output <file> Output JSON file for results --output <file> Output JSON file for results
--csv <file> Output CSV file for summary --csv <file> Output CSV file for summary
--verbose Enable verbose output --verbose Enable verbose output
@@ -60,6 +60,10 @@ enum StreamDiarizationBenchmark {
--iterations <n> Number of iterations per file (default: 1) --iterations <n> Number of iterations per file (default: 1)
--help Show this help message --help Show this help message
Modes:
streaming Online diarization with chunk-based processing (first-occurrence speaker mapping)
offline Batch diarization with VBx clustering (optimal speaker mapping with Hungarian algorithm)
Streaming Modes (via chunk/overlap settings): Streaming Modes (via chunk/overlap settings):
Real-time: --chunk-seconds 3 --overlap-seconds 2 (~15-30x RTFx) Real-time: --chunk-seconds 3 --overlap-seconds 2 (~15-30x RTFx)
Balanced: --chunk-seconds 10 --overlap-seconds 5 (~70x RTFx) Balanced: --chunk-seconds 10 --overlap-seconds 5 (~70x RTFx)
@@ -67,24 +71,27 @@ enum StreamDiarizationBenchmark {
Performance Targets: Performance Targets:
DER < 30% (competitive with research systems) DER < 30% (competitive with research systems)
RTFx > 1x (real-time capable) RTFx > 1x (real-time capable, streaming mode)
Examples: Examples:
# Benchmark single file with real-time settings # Offline VBx clustering (research-grade accuracy)
fluidaudio diarization-benchmark --single-file ES2004a \\ fluidaudio diarization-benchmark --mode offline --single-file ES2004a
# Streaming mode with real-time settings
fluidaudio diarization-benchmark --mode streaming --single-file ES2004a \\
--chunk-seconds 3 --overlap-seconds 2 --chunk-seconds 3 --overlap-seconds 2
# Full AMI benchmark with balanced settings # Full AMI benchmark in offline mode
fluidaudio diarization-benchmark --dataset ami-sdm \\ fluidaudio diarization-benchmark --mode offline --dataset ami-sdm --csv results.csv
--chunk-seconds 10 --overlap-seconds 5 --csv results.csv
# Quick test on 5 files # Quick test on 5 files (offline)
fluidaudio diarization-benchmark --max-files 5 --verbose fluidaudio diarization-benchmark --mode offline --max-files 5 --verbose
""") """)
} }
static func run(arguments: [String]) async { static func run(arguments: [String]) async {
// Parse arguments // Parse arguments
var mode = "streaming" // Default to streaming mode
var dataset = "ami-sdm" var dataset = "ami-sdm"
var singleFile: String? var singleFile: String?
var maxFiles: Int? var maxFiles: Int?
@@ -103,6 +110,11 @@ enum StreamDiarizationBenchmark {
var i = 0 var i = 0
while i < arguments.count { while i < arguments.count {
switch arguments[i] { switch arguments[i] {
case "--mode":
if i + 1 < arguments.count {
mode = arguments[i + 1]
i += 1
}
case "--dataset": case "--dataset":
if i + 1 < arguments.count { if i + 1 < arguments.count {
dataset = arguments[i + 1] dataset = arguments[i + 1]
@@ -175,29 +187,43 @@ enum StreamDiarizationBenchmark {
i += 1 i += 1
} }
// Validate settings // Validate mode
let hopSize = max(chunkSeconds - overlapSeconds, 1.0) guard mode == "streaming" || mode == "offline" else {
let overlapRatio = overlapSeconds / chunkSeconds logger.error("Invalid mode: \(mode). Must be 'streaming' or 'offline'")
printUsage()
logger.info("🚀 Starting Stream Diarization Benchmark") return
logger.info(" Dataset: \(dataset)")
logger.info(" Chunk size: \(chunkSeconds)s")
logger.info(" Overlap: \(overlapSeconds)s (\(String(format: "%.0f", overlapRatio * 100))%)")
logger.info(" Hop size: \(hopSize)s")
logger.info(" Clustering threshold: \(threshold)")
logger.info(" Assignment threshold: \(assignmentThreshold)")
logger.info(" Update threshold: \(updateThreshold)")
// Determine streaming mode
let mode: String
if overlapSeconds == 0 {
mode = "Batch (no overlap)"
} else if overlapRatio >= 0.6 {
mode = "Real-time (high overlap)"
} else {
mode = "Balanced"
} }
logger.info(" Mode: \(mode)\n")
logger.info("🚀 Starting Diarization Benchmark (\(mode.uppercased()) MODE)")
logger.info(" Dataset: \(dataset)")
logger.info(" Clustering threshold: \(threshold)")
if mode == "streaming" {
// Validate streaming settings
let hopSize = max(chunkSeconds - overlapSeconds, 1.0)
let overlapRatio = overlapSeconds / chunkSeconds
logger.info(" Chunk size: \(chunkSeconds)s")
logger.info(" Overlap: \(overlapSeconds)s (\(String(format: "%.0f", overlapRatio * 100))%)")
logger.info(" Hop size: \(hopSize)s")
logger.info(" Assignment threshold: \(assignmentThreshold)")
logger.info(" Update threshold: \(updateThreshold)")
// Determine streaming mode
let streamingMode: String
if overlapSeconds == 0 {
streamingMode = "Batch (no overlap)"
} else if overlapRatio >= 0.6 {
streamingMode = "Real-time (high overlap)"
} else {
streamingMode = "Balanced"
}
logger.info(" Streaming mode: \(streamingMode)")
} else {
logger.info(" Using VBx clustering with optimal speaker mapping")
}
logger.info("")
// Download dataset if needed // Download dataset if needed
if autoDownload { if autoDownload {
@@ -230,8 +256,22 @@ enum StreamDiarizationBenchmark {
logger.info("🔧 Initializing models...") logger.info("🔧 Initializing models...")
let modelStartTime = Date() let modelStartTime = Date()
let models: DiarizerModels let models: DiarizerModels
var offlineManager: OfflineDiarizerManager?
do { do {
models = try await DiarizerModels.downloadIfNeeded() models = try await DiarizerModels.downloadIfNeeded()
// For offline mode, also initialize the offline manager
if mode == "offline" {
let modelDir = OfflineDiarizerModels.defaultModelsDirectory()
let offlineConfig = OfflineDiarizerConfig(
clusteringThreshold: Double(threshold)
)
offlineManager = OfflineDiarizerManager(config: offlineConfig)
let offlineModels = try await OfflineDiarizerModels.load(from: modelDir)
offlineManager?.initialize(models: offlineModels)
logger.info("✅ Offline manager initialized")
}
} catch { } catch {
logger.error("❌ Failed to initialize models: \(error)") logger.error("❌ Failed to initialize models: \(error)")
return return
@@ -254,18 +294,31 @@ enum StreamDiarizationBenchmark {
logger.info(" Iteration \(iteration)/\(iterations)") logger.info(" Iteration \(iteration)/\(iterations)")
} }
if let result = await processMeeting( let result: BenchmarkResult?
meetingName: meetingName, if mode == "streaming" {
models: models, result = await processStreamingMeeting(
modelInitTime: modelInitTime, meetingName: meetingName,
chunkSeconds: chunkSeconds, models: models,
overlapSeconds: overlapSeconds, modelInitTime: modelInitTime,
threshold: threshold, chunkSeconds: chunkSeconds,
assignmentThreshold: assignmentThreshold, overlapSeconds: overlapSeconds,
updateThreshold: updateThreshold, threshold: threshold,
verbose: verbose, assignmentThreshold: assignmentThreshold,
debugMode: debugMode updateThreshold: updateThreshold,
) { verbose: verbose,
debugMode: debugMode
)
} else {
result = await processOfflineMeeting(
meetingName: meetingName,
controller: offlineManager!,
modelInitTime: modelInitTime,
verbose: verbose,
debugMode: debugMode
)
}
if let result = result {
iterationResults.append(result) iterationResults.append(result)
// Print summary for this iteration // Print summary for this iteration
@@ -340,7 +393,7 @@ enum StreamDiarizationBenchmark {
} }
} }
private static func processMeeting( private static func processStreamingMeeting(
meetingName: String, meetingName: String,
models: DiarizerModels, models: DiarizerModels,
modelInitTime: Double, modelInitTime: Double,
@@ -543,6 +596,104 @@ enum StreamDiarizationBenchmark {
} }
} }
private static func processOfflineMeeting(
meetingName: String,
controller: OfflineDiarizerManager,
modelInitTime: Double,
verbose: Bool,
debugMode: Bool
) async -> BenchmarkResult? {
// Load audio
let audioPath = getAudioPath(for: meetingName)
guard FileManager.default.fileExists(atPath: audioPath) else {
logger.error("❌ Audio file not found: \(audioPath)")
return nil
}
do {
// Track audio loading time
let audioLoadStart = Date()
let audioData = try await loadAudioFile(at: audioPath)
let audioLoadTime = Date().timeIntervalSince(audioLoadStart)
let totalDuration = Double(audioData.count) / 16000.0
if verbose {
logger.info(" Audio duration: \(String(format: "%.1f", totalDuration))s")
logger.info(" Audio load time: \(String(format: "%.3f", audioLoadTime))s")
}
// Process with offline controller
let startTime = Date()
let result = try await controller.process(audio: audioData)
let totalElapsed = Date().timeIntervalSince(startTime)
let finalRTFx = totalDuration / totalElapsed
if verbose {
logger.info(" Processing time: \(String(format: "%.3f", totalElapsed))s")
logger.info(" RTFx: \(String(format: "%.1f", finalRTFx))x")
}
// Load ground truth
let groundTruth = await AMIParser.loadAMIGroundTruth(
for: meetingName,
duration: Float(totalDuration)
)
guard !groundTruth.isEmpty else {
logger.warning("⚠️ No ground truth found for \(meetingName)")
return nil
}
// Calculate metrics with Hungarian algorithm (optimal mapping for offline)
let metrics = DiarizationMetricsCalculator.offlineMetrics(
predicted: result.segments,
groundTruth: groundTruth,
frameSize: 0.01,
audioDurationSeconds: totalDuration,
logger: logger
)
// Extract timing breakdown if available
let segmentationTime = result.timings?.segmentationSeconds ?? 0
let embeddingTime = result.timings?.embeddingExtractionSeconds ?? 0
let clusteringTime = result.timings?.speakerClusteringSeconds ?? 0
let totalInferenceTime = segmentationTime + embeddingTime + clusteringTime
// Count detected speakers
let detectedSpeakers = Set(result.segments.map { $0.speakerId }).count
return BenchmarkResult(
meetingName: meetingName,
der: metrics.der,
missRate: metrics.missRate,
falseAlarmRate: metrics.falseAlarmRate,
speakerErrorRate: metrics.speakerErrorRate,
jer: metrics.jer,
rtfx: Float(finalRTFx),
processingTime: totalElapsed,
chunksProcessed: 1, // Offline processes entire file at once
detectedSpeakers: detectedSpeakers,
groundTruthSpeakers: AMIParser.getGroundTruthSpeakerCount(for: meetingName),
speakerFragmentation: 1.0, // No fragmentation in offline mode
latency90th: totalElapsed,
latency99th: totalElapsed,
// Timing breakdown
modelDownloadTime: modelInitTime * 0.7,
modelCompileTime: modelInitTime * 0.3,
audioLoadTime: audioLoadTime,
segmentationTime: segmentationTime,
embeddingTime: embeddingTime,
clusteringTime: clusteringTime,
totalInferenceTime: totalInferenceTime
)
} catch {
logger.error("❌ Error processing \(meetingName): \(error)")
return nil
}
}
/// Calculate DER metrics with first-occurrence mapping for streaming evaluation /// Calculate DER metrics with first-occurrence mapping for streaming evaluation
private static func calculateStreamingMetrics( private static func calculateStreamingMetrics(
predicted: [TimedSpeakerSegment], predicted: [TimedSpeakerSegment],
@@ -1,26 +1,39 @@
#if os(macOS) #if os(macOS)
import AVFoundation import AVFoundation
import FluidAudio import FluidAudio
import Foundation
var standardError = FileHandle.standardError
/// Handler for the 'process' command - processes a single audio file /// Handler for the 'process' command - processes a single audio file
enum ProcessCommand { enum ProcessCommand {
private static let logger = AppLogger(category: "Process") private static let logger = AppLogger(category: "Process")
static func run(arguments: [String]) async { static func run(arguments: [String]) async {
guard !arguments.isEmpty else { guard !arguments.isEmpty else {
fputs("ERROR: No audio file specified\n", stderr)
fflush(stderr)
logger.error("No audio file specified") logger.error("No audio file specified")
printUsage() printUsage()
exit(1) exit(1)
} }
let audioFile = arguments[0] let audioFile = arguments[0]
var threshold: Float = 0.7 var mode = "streaming" // Default to streaming
var threshold: Float = 0.7045655 // PyAnnote community-1 default
var debugMode = false var debugMode = false
var outputFile: String? var outputFile: String?
var rttmFile: String?
var embeddingExportPath: String?
// Parse remaining arguments // Parse remaining arguments
var i = 1 var i = 1
while i < arguments.count { while i < arguments.count {
switch arguments[i] { switch arguments[i] {
case "--mode":
if i + 1 < arguments.count {
mode = arguments[i + 1]
i += 1
}
case "--threshold": case "--threshold":
if i + 1 < arguments.count { if i + 1 < arguments.count {
threshold = Float(arguments[i + 1]) ?? 0.8 threshold = Float(arguments[i + 1]) ?? 0.8
@@ -33,71 +46,184 @@ enum ProcessCommand {
outputFile = arguments[i + 1] outputFile = arguments[i + 1]
i += 1 i += 1
} }
case "--rttm":
if i + 1 < arguments.count {
rttmFile = arguments[i + 1]
i += 1
}
case "--export-embeddings":
if i + 1 < arguments.count {
embeddingExportPath = arguments[i + 1]
i += 1
}
default: default:
logger.warning("Unknown option: \(arguments[i])") logger.warning("Unknown option: \(arguments[i])")
} }
i += 1 i += 1
} }
logger.info("🎵 Processing audio file: \(audioFile)") // Validate mode
logger.info(" Clustering threshold: \(threshold)") guard mode == "streaming" || mode == "offline" else {
fputs("ERROR: Invalid mode: \(mode)\n", stderr)
let config = DiarizerConfig( fflush(stderr)
clusteringThreshold: threshold, logger.error("Invalid mode: \(mode). Must be 'streaming' or 'offline'")
debugMode: debugMode printUsage()
)
let manager = DiarizerManager(config: config)
do {
let models = try await DiarizerModels.downloadIfNeeded()
manager.initialize(models: models)
logger.info("Models initialized")
} catch {
logger.error("Failed to initialize models: \(error)")
exit(1) exit(1)
} }
// Load and process audio file logger.info("🎵 Processing audio file (\(mode.uppercased()) MODE): \(audioFile)")
do { logger.info(" Clustering threshold: \(threshold)")
let audioSamples = try AudioConverter().resampleAudioFile(path: audioFile)
logger.info("Loaded audio: \(audioSamples.count) samples")
let startTime = Date() if mode == "streaming" {
let result = try manager.performCompleteDiarization( // Streaming mode - use DiarizerManager
audioSamples, sampleRate: 16000) let config = DiarizerConfig(
let processingTime = Date().timeIntervalSince(startTime) clusteringThreshold: threshold,
debugMode: debugMode
let duration = Float(audioSamples.count) / 16000.0
let rtfx = duration / Float(processingTime)
logger.info("Diarization completed in \(String(format: "%.1f", processingTime))s")
logger.info(" Real-time factor (RTFx): \(String(format: "%.2f", rtfx))x")
logger.info(" Found \(result.segments.count) segments")
logger.info(" Detected \(result.speakerDatabase?.count ?? 0) speakers (total), mapped: TBD")
// Create output
let output = ProcessingResult(
audioFile: audioFile,
durationSeconds: duration,
processingTimeSeconds: processingTime,
realTimeFactor: rtfx,
segments: result.segments,
speakerCount: result.speakerDatabase?.count ?? 0,
config: config
) )
// Output results let manager = DiarizerManager(config: config)
if let outputFile = outputFile {
try await ResultsFormatter.saveResults(output, to: outputFile) do {
logger.info("💾 Results saved to: \(outputFile)") let models = try await DiarizerModels.downloadIfNeeded()
} else { manager.initialize(models: models)
await ResultsFormatter.printResults(output) logger.info("Models initialized")
} catch {
logger.error("Failed to initialize models: \(error)")
exit(1)
} }
} catch { // Load and process audio file
logger.error("Failed to process audio file: \(error)") do {
exit(1) let audioSamples = try AudioConverter().resampleAudioFile(path: audioFile)
logger.info("Loaded audio: \(audioSamples.count) samples")
let startTime = Date()
let result = try manager.performCompleteDiarization(
audioSamples, sampleRate: 16000)
let processingTime = Date().timeIntervalSince(startTime)
let duration = Float(audioSamples.count) / 16000.0
let rtfx = duration / Float(processingTime)
logger.info("Diarization completed in \(String(format: "%.1f", processingTime))s")
logger.info(" Real-time factor (RTFx): \(String(format: "%.2f", rtfx))x")
logger.info(" Found \(result.segments.count) segments")
logger.info(" Detected \(result.speakerDatabase?.count ?? 0) speakers")
// Create output
let output = ProcessingResult(
audioFile: audioFile,
durationSeconds: duration,
processingTimeSeconds: processingTime,
realTimeFactor: rtfx,
segments: result.segments,
speakerCount: result.speakerDatabase?.count ?? 0,
config: config,
metrics: nil,
timings: result.timings
)
// Output results
if let outputFile = outputFile {
try await ResultsFormatter.saveResults(output, to: outputFile)
logger.info("💾 Results saved to: \(outputFile)")
} else {
await ResultsFormatter.printResults(output)
}
} catch {
logger.error("Failed to process audio file: \(error)")
exit(1)
}
} else {
// Offline mode - use OfflineDiarizerManager
do {
let modelDir = OfflineDiarizerModels.defaultModelsDirectory()
let offlineConfig = OfflineDiarizerConfig(
clusteringThreshold: Double(threshold),
embeddingExportPath: embeddingExportPath
)
let manager = OfflineDiarizerManager(config: offlineConfig)
let models = try await OfflineDiarizerModels.load(from: modelDir)
manager.initialize(models: models)
logger.info("Offline manager initialized")
// Load and process audio file without materializing the full sample buffer.
let audioURL = URL(fileURLWithPath: audioFile)
let factory = StreamingAudioSourceFactory()
let targetSampleRate = offlineConfig.segmentation.sampleRate
let diskSourceResult = try factory.makeDiskBackedSource(
from: audioURL,
targetSampleRate: targetSampleRate
)
let diskSource = diskSourceResult.source
defer { diskSource.cleanup() }
let loadDurationText = String(format: "%.2f", diskSourceResult.loadDuration)
logger.info(
"Prepared disk-backed audio source: \(diskSource.sampleCount) samples (\(loadDurationText)s)")
let startTime = Date()
let result = try await manager.process(
audioSource: diskSource,
audioLoadingSeconds: diskSourceResult.loadDuration
)
let processingTime = Date().timeIntervalSince(startTime)
let durationSeconds = Double(diskSource.sampleCount) / Double(targetSampleRate)
let rtfx = durationSeconds / processingTime
logger.info("Diarization completed in \(String(format: "%.1f", processingTime))s")
logger.info(" Real-time factor (RTFx): \(String(format: "%.2f", rtfx))x")
logger.info(" Found \(result.segments.count) segments")
let speakerCount = Set(result.segments.map { $0.speakerId }).count
logger.info(" Detected \(speakerCount) speakers")
var metrics: DiarizationMetrics?
if let rttmFile = rttmFile {
do {
let groundTruth = try RTTMParser.loadSegments(from: rttmFile)
metrics = DiarizationMetricsCalculator.offlineMetrics(
predicted: result.segments,
groundTruth: groundTruth,
frameSize: 0.01,
audioDurationSeconds: durationSeconds,
logger: logger
)
} catch {
logger.error("Failed to compute offline metrics: \(error.localizedDescription)")
}
}
// Create simplified output for offline mode
let output = ProcessingResult(
audioFile: audioFile,
durationSeconds: Float(durationSeconds),
processingTimeSeconds: processingTime,
realTimeFactor: Float(rtfx),
segments: result.segments,
speakerCount: speakerCount,
config: nil,
metrics: metrics,
timings: result.timings
)
// Output results
if let outputFile = outputFile {
try await ResultsFormatter.saveResults(output, to: outputFile)
logger.info("💾 Results saved to: \(outputFile)")
} else {
await ResultsFormatter.printResults(output)
}
} catch {
fputs("ERROR: Failed to process audio file (offline mode): \(error)\n", stderr)
fflush(stderr)
logger.error("Failed to process audio file (offline mode): \(error)")
exit(1)
}
} }
} }
@@ -109,12 +235,23 @@ enum ProcessCommand {
fluidaudio process <audio_file> [options] fluidaudio process <audio_file> [options]
Options: Options:
--threshold <float> Clustering threshold (default: 0.8) --mode <streaming|offline> Diarization mode (default: streaming)
--debug Enable debug mode --threshold <float> Clustering threshold (default: 0.7045655, pyannote community-1)
--output <file> Save results to file instead of stdout --debug Enable debug mode
--output <file> Save results to file instead of stdout
--rttm <file> Compute offline DER/JER metrics against RTTM annotations
--export-embeddings <file> Export embeddings to JSON for debugging (offline mode only)
Example:
fluidaudio process audio.wav --threshold 0.5 --output results.json Examples:
# Streaming mode (default)
fluidaudio process audio.wav --output results.json
# Offline mode with VBx clustering (default threshold 0.7045655)
fluidaudio process audio.wav --mode offline --output results.json
# Offline mode with embedding export for debugging
fluidaudio process audio.wav --mode offline --export-embeddings embeddings.json
""" """
) )
} }
+7 -11
View File
@@ -11,13 +11,16 @@ struct ProcessingResult: Codable {
let realTimeFactor: Float let realTimeFactor: Float
let segments: [TimedSpeakerSegment] let segments: [TimedSpeakerSegment]
let speakerCount: Int let speakerCount: Int
let config: DiarizerConfig let config: DiarizerConfig?
let metrics: DiarizationMetrics?
let timings: PipelineTimings?
let timestamp: Date let timestamp: Date
init( init(
audioFile: String, durationSeconds: Float, processingTimeSeconds: TimeInterval, audioFile: String, durationSeconds: Float, processingTimeSeconds: TimeInterval,
realTimeFactor: Float, segments: [TimedSpeakerSegment], speakerCount: Int, realTimeFactor: Float, segments: [TimedSpeakerSegment], speakerCount: Int,
config: DiarizerConfig config: DiarizerConfig?, metrics: DiarizationMetrics? = nil,
timings: PipelineTimings? = nil
) { ) {
self.audioFile = audioFile self.audioFile = audioFile
self.durationSeconds = durationSeconds self.durationSeconds = durationSeconds
@@ -26,6 +29,8 @@ struct ProcessingResult: Codable {
self.segments = segments self.segments = segments
self.speakerCount = speakerCount self.speakerCount = speakerCount
self.config = config self.config = config
self.metrics = metrics
self.timings = timings
self.timestamp = Date() self.timestamp = Date()
} }
} }
@@ -243,13 +248,4 @@ struct VadBenchmarkResult {
let correctPredictions: Int let correctPredictions: Int
} }
struct DiarizationMetrics {
let der: Float
let jer: Float
let missRate: Float
let falseAlarmRate: Float
let speakerErrorRate: Float
let mappedSpeakerCount: Int // Number of predicted speakers that mapped to ground truth
}
#endif #endif
@@ -0,0 +1,713 @@
#if os(macOS)
import FluidAudio
import Foundation
/// Aggregate diarization quality metrics.
struct DiarizationMetrics: Codable {
let der: Float
let missRate: Float
let falseAlarmRate: Float
let speakerErrorRate: Float
let jer: Float
let speakerMapping: [String: String]
let evaluationCollarSeconds: Float
let evaluationIgnoresOverlap: Bool
private enum CodingKeys: String, CodingKey {
case der
case missRate
case falseAlarmRate
case speakerErrorRate
case jer
case speakerMapping
case evaluationCollarSeconds
case evaluationIgnoresOverlap
}
init(
der: Float,
missRate: Float,
falseAlarmRate: Float,
speakerErrorRate: Float,
jer: Float,
speakerMapping: [String: String],
evaluationCollarSeconds: Float,
evaluationIgnoresOverlap: Bool
) {
self.der = der
self.missRate = missRate
self.falseAlarmRate = falseAlarmRate
self.speakerErrorRate = speakerErrorRate
self.jer = jer
self.speakerMapping = speakerMapping
self.evaluationCollarSeconds = evaluationCollarSeconds
self.evaluationIgnoresOverlap = evaluationIgnoresOverlap
}
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
der = try container.decode(Float.self, forKey: .der)
missRate = try container.decode(Float.self, forKey: .missRate)
falseAlarmRate = try container.decode(Float.self, forKey: .falseAlarmRate)
speakerErrorRate = try container.decode(Float.self, forKey: .speakerErrorRate)
jer = try container.decode(Float.self, forKey: .jer)
speakerMapping = try container.decode([String: String].self, forKey: .speakerMapping)
evaluationCollarSeconds =
try container.decodeIfPresent(
Float.self,
forKey: .evaluationCollarSeconds
) ?? 0.25
evaluationIgnoresOverlap =
try container.decodeIfPresent(
Bool.self,
forKey: .evaluationIgnoresOverlap
) ?? true
}
}
/// Utility for computing diarization metrics that can be shared between CLI commands.
enum DiarizationMetricsCalculator {
private static let scoringCollarSeconds: Double = 0.25
private static let ignoreOverlap = true
private struct ScoringSegment {
let start: Double
let end: Double
let speaker: String
}
private enum EventType: Int {
case start = 0
case end = 1
}
private static func zeroMetrics() -> DiarizationMetrics {
DiarizationMetrics(
der: 0,
missRate: 0,
falseAlarmRate: 0,
speakerErrorRate: 0,
jer: 0,
speakerMapping: [:],
evaluationCollarSeconds: Float(scoringCollarSeconds),
evaluationIgnoresOverlap: ignoreOverlap
)
}
/// Compute offline diarization metrics using segment-level overlap analysis.
/// - Parameters:
/// - predicted: Predicted speaker segments.
/// - groundTruth: Ground-truth speaker segments.
/// - frameSize: Retained for compatibility; unused in segment-based evaluation.
/// - logger: Optional logger for emitting debug information.
/// - Returns: Aggregate metrics including DER, JER, and the speaker mapping.
static func offlineMetrics(
predicted: [TimedSpeakerSegment],
groundTruth: [TimedSpeakerSegment],
frameSize: Float = 0.01,
audioDurationSeconds: Double? = nil,
logger: AppLogger? = nil
) -> DiarizationMetrics {
_ = frameSize // Retained for API compatibility; no longer needed.
guard !groundTruth.isEmpty else {
return zeroMetrics()
}
let predictedSegments =
predicted
.map {
ScoringSegment(
start: Double($0.startTimeSeconds),
end: Double($0.endTimeSeconds),
speaker: $0.speakerId
)
}
.sorted { $0.start < $1.start }
let groundTruthSegments =
groundTruth
.map {
ScoringSegment(
start: Double($0.startTimeSeconds),
end: Double($0.endTimeSeconds),
speaker: $0.speakerId
)
}
.sorted { $0.start < $1.start }
let (processedGroundTruth, excludedIntervals) = applyOfficialReferenceProcessing(groundTruthSegments)
let maxPredictedEnd = predictedSegments.map { $0.end }.max() ?? 0
let maxGroundTruthEnd = groundTruthSegments.map { $0.end }.max() ?? 0
let evaluationDuration = max(audioDurationSeconds ?? 0, maxPredictedEnd, maxGroundTruthEnd)
guard evaluationDuration > 0 else {
return zeroMetrics()
}
let evaluationIntervals = subtractIntervals(
from: (0.0, evaluationDuration),
removing: excludedIntervals
)
guard !evaluationIntervals.isEmpty else {
return zeroMetrics()
}
let predictedEvaluated = clipSegments(predictedSegments, to: evaluationIntervals)
let processedGroundTruthEvaluated = clipSegments(processedGroundTruth, to: evaluationIntervals)
guard !processedGroundTruthEvaluated.isEmpty else {
return zeroMetrics()
}
let groundTruthIntervals = mergeIntervals(processedGroundTruthEvaluated.map { ($0.start, $0.end) })
let predictedIntervals = mergeIntervals(predictedEvaluated.map { ($0.start, $0.end) })
let referenceSpeech = groundTruthIntervals.reduce(0.0) { $0 + ($1.1 - $1.0) }
guard referenceSpeech > 0 else {
return zeroMetrics()
}
let predictedSpeech = predictedIntervals.reduce(0.0) { $0 + ($1.1 - $1.0) }
let overlapSpeech = intervalsOverlap(groundTruthIntervals, predictedIntervals)
let miss = max(0.0, referenceSpeech - overlapSpeech)
let falseAlarm = max(0.0, predictedSpeech - overlapSpeech)
let groundTruthBySpeaker = segmentsBySpeaker(processedGroundTruthEvaluated)
let predictedBySpeaker = segmentsBySpeaker(predictedEvaluated)
let speakerMapping = computeSpeakerMapping(
predicted: predictedBySpeaker,
groundTruth: groundTruthBySpeaker
)
var correctlyAssigned = 0.0
for (predId, truthId) in speakerMapping {
if let predSegments = predictedBySpeaker[predId],
let truthSegments = groundTruthBySpeaker[truthId]
{
correctlyAssigned += overlapDuration(predSegments, truthSegments)
}
}
let confusion = max(0.0, overlapSpeech - correctlyAssigned)
let missRate = Float((miss / referenceSpeech) * 100)
let falseAlarmRate = Float((falseAlarm / referenceSpeech) * 100)
let speakerErrorRate = Float((confusion / referenceSpeech) * 100)
let der = missRate + falseAlarmRate + speakerErrorRate
var jaccardScores: [Double] = []
let inverseMapping = Dictionary(uniqueKeysWithValues: speakerMapping.map { ($0.value, $0.key) })
for (truthId, truthSegments) in groundTruthBySpeaker {
let matchedPred = inverseMapping[truthId]
let predictedSegmentsForSpeaker = matchedPred.flatMap { predictedBySpeaker[$0] } ?? []
let intersection = overlapDuration(predictedSegmentsForSpeaker, truthSegments)
let union = unionDuration(predictedSegmentsForSpeaker, truthSegments)
if union > 0 {
jaccardScores.append(intersection / union)
}
}
for (predId, predSegments) in predictedBySpeaker where speakerMapping[predId] == nil {
if unionDuration(predSegments) > 0 {
jaccardScores.append(0.0)
}
}
let jer: Float
if jaccardScores.isEmpty {
jer = 0
} else {
let averageJaccard = jaccardScores.reduce(0.0, +) / Double(jaccardScores.count)
jer = Float((1.0 - averageJaccard) * 100)
}
if let logger = logger {
logger.debug("🎯 Offline mapping: \(speakerMapping)")
let formattedDer = String(format: "%.1f", der)
let formattedMiss = String(format: "%.1f", missRate)
let formattedFalseAlarm = String(format: "%.1f", falseAlarmRate)
let formattedSpeakerError = String(format: "%.1f", speakerErrorRate)
let formattedJer = String(format: "%.1f", jer)
let formattedCollar = String(format: "%.2f", scoringCollarSeconds)
let summary =
"📊 OFFLINE METRICS: DER=\(formattedDer)% "
+ "(Miss=\(formattedMiss)%, FA=\(formattedFalseAlarm)%, "
+ "SE=\(formattedSpeakerError)%, JER=\(formattedJer)%) "
+ "(collar=\(formattedCollar)s, ignoreOverlap=\(ignoreOverlap))"
logger.info(summary)
}
return DiarizationMetrics(
der: der,
missRate: missRate,
falseAlarmRate: falseAlarmRate,
speakerErrorRate: speakerErrorRate,
jer: jer,
speakerMapping: speakerMapping,
evaluationCollarSeconds: Float(scoringCollarSeconds),
evaluationIgnoresOverlap: ignoreOverlap
)
}
// MARK: - Segment-level helpers
private static func applyOfficialReferenceProcessing(
_ segments: [ScoringSegment]
) -> ([ScoringSegment], [(Double, Double)]) {
guard !segments.isEmpty else { return ([], []) }
var trimmed: [ScoringSegment] = []
var excluded: [(Double, Double)] = []
for segment in segments {
let trimmedStart = segment.start + scoringCollarSeconds
let trimmedEnd = segment.end - scoringCollarSeconds
if trimmedEnd <= trimmedStart {
excluded.append((segment.start, segment.end))
continue
}
if trimmedStart > segment.start {
excluded.append((segment.start, trimmedStart))
}
if trimmedEnd < segment.end {
excluded.append((trimmedEnd, segment.end))
}
trimmed.append(
ScoringSegment(start: trimmedStart, end: trimmedEnd, speaker: segment.speaker)
)
}
if trimmed.isEmpty {
return ([], mergeIntervals(excluded))
}
var processed = trimmed
if ignoreOverlap {
let isolated = isolateSingleSpeakerSegments(processed)
processed = isolated.segments
excluded.append(contentsOf: isolated.excluded)
} else {
processed = mergeAdjacentSegments(processed)
}
processed.removeAll { $0.end <= $0.start }
return (processed, mergeIntervals(excluded))
}
private static func mergeAdjacentSegments(
_ segments: [ScoringSegment]
) -> [ScoringSegment] {
var merged: [ScoringSegment] = []
for segment in segments {
guard segment.end > segment.start else { continue }
if let last = merged.last, last.speaker == segment.speaker, segment.start <= last.end {
let updated = ScoringSegment(
start: last.start,
end: max(last.end, segment.end),
speaker: last.speaker
)
merged[merged.count - 1] = updated
} else if let last = merged.last,
last.speaker == segment.speaker,
abs(segment.start - last.end) < 1e-9
{
let updated = ScoringSegment(
start: last.start,
end: max(last.end, segment.end),
speaker: last.speaker
)
merged[merged.count - 1] = updated
} else {
merged.append(segment)
}
}
return merged
}
private static func isolateSingleSpeakerSegments(
_ segments: [ScoringSegment]
) -> (segments: [ScoringSegment], excluded: [(Double, Double)]) {
struct Event {
let time: Double
let type: EventType
let speaker: String
}
var events: [Event] = []
for segment in segments where segment.end > segment.start {
events.append(Event(time: segment.start, type: .start, speaker: segment.speaker))
events.append(Event(time: segment.end, type: .end, speaker: segment.speaker))
}
guard !events.isEmpty else { return ([], []) }
events.sort { lhs, rhs in
if lhs.time == rhs.time {
return lhs.type.rawValue < rhs.type.rawValue
}
return lhs.time < rhs.time
}
var activeCounts: [String: Int] = [:]
var singleSpeaker: [ScoringSegment] = []
var excluded: [(Double, Double)] = []
var index = 0
var previousTime: Double?
while index < events.count {
let currentTime = events[index].time
if let prev = previousTime, currentTime > prev {
let activeSpeakers = activeCounts.filter { $0.value > 0 }.map(\.key)
if activeSpeakers.count == 1, let speaker = activeSpeakers.first {
singleSpeaker.append(
ScoringSegment(start: prev, end: currentTime, speaker: speaker)
)
} else if activeSpeakers.count > 1 {
excluded.append((prev, currentTime))
}
}
while index < events.count && events[index].time == currentTime {
let event = events[index]
switch event.type {
case .start:
activeCounts[event.speaker, default: 0] += 1
case .end:
let count = activeCounts[event.speaker] ?? 0
if count <= 1 {
activeCounts.removeValue(forKey: event.speaker)
} else {
activeCounts[event.speaker] = count - 1
}
}
index += 1
}
previousTime = currentTime
}
return (mergeAdjacentSegments(singleSpeaker), mergeIntervals(excluded))
}
private static func clipSegments(
_ segments: [ScoringSegment],
to intervals: [(Double, Double)]
) -> [ScoringSegment] {
guard !segments.isEmpty, !intervals.isEmpty else { return [] }
let sortedSegments = segments.sorted { $0.start < $1.start }
let sortedIntervals = intervals.sorted { $0.0 < $1.0 }
var clipped: [ScoringSegment] = []
var intervalIndex = 0
for segment in sortedSegments where segment.end > segment.start {
while intervalIndex < sortedIntervals.count && sortedIntervals[intervalIndex].1 <= segment.start {
intervalIndex += 1
}
var probeIndex = intervalIndex
while probeIndex < sortedIntervals.count {
let interval = sortedIntervals[probeIndex]
if interval.0 >= segment.end {
break
}
let overlapStart = max(segment.start, interval.0)
let overlapEnd = min(segment.end, interval.1)
if overlapEnd > overlapStart {
clipped.append(
ScoringSegment(start: overlapStart, end: overlapEnd, speaker: segment.speaker)
)
}
if interval.1 >= segment.end {
break
}
probeIndex += 1
}
}
return clipped
}
private static func subtractIntervals(
from span: (Double, Double),
removing intervals: [(Double, Double)]
) -> [(Double, Double)] {
let (start, end) = span
guard end > start else { return [] }
let merged = mergeIntervals(
intervals.map { (max(start, $0.0), min(end, $0.1)) }
.filter { $0.1 > start && $0.0 < end }
)
var remaining: [(Double, Double)] = []
var cursor = start
for interval in merged {
if interval.0 > cursor {
remaining.append((cursor, min(interval.0, end)))
}
cursor = max(cursor, interval.1)
if cursor >= end {
break
}
}
if cursor < end {
remaining.append((cursor, end))
}
return remaining
}
private static func mergeIntervals(
_ intervals: [(Double, Double)]
) -> [(Double, Double)] {
guard !intervals.isEmpty else { return [] }
let sorted = intervals.sorted { lhs, rhs in
if lhs.0 == rhs.0 {
return lhs.1 < rhs.1
}
return lhs.0 < rhs.0
}
var merged: [(Double, Double)] = []
var current = sorted[0]
for interval in sorted.dropFirst() {
if interval.0 <= current.1 {
current.1 = max(current.1, interval.1)
} else {
merged.append(current)
current = interval
}
}
merged.append(current)
return merged
}
private static func intervalsOverlap(
_ lhs: [(Double, Double)],
_ rhs: [(Double, Double)]
) -> Double {
var total = 0.0
var i = 0
var j = 0
while i < lhs.count && j < rhs.count {
let a = lhs[i]
let b = rhs[j]
let start = max(a.0, b.0)
let end = min(a.1, b.1)
if end > start {
total += end - start
}
if a.1 <= b.1 {
i += 1
} else {
j += 1
}
}
return total
}
private static func segmentsBySpeaker(
_ segments: [ScoringSegment]
) -> [String: [ScoringSegment]] {
var grouped: [String: [ScoringSegment]] = [:]
for segment in segments {
grouped[segment.speaker, default: []].append(segment)
}
for key in grouped.keys {
grouped[key]?.sort { $0.start < $1.start }
}
return grouped
}
private static func overlapDuration(
_ lhs: [ScoringSegment],
_ rhs: [ScoringSegment]
) -> Double {
var total = 0.0
var i = 0
var j = 0
while i < lhs.count && j < rhs.count {
let a = lhs[i]
let b = rhs[j]
let start = max(a.start, b.start)
let end = min(a.end, b.end)
if end > start {
total += end - start
}
if a.end <= b.end {
i += 1
} else {
j += 1
}
}
return total
}
private static func unionDuration(
_ segments: [ScoringSegment]
) -> Double {
let intervals = segments.map { ($0.start, $0.end) }
return mergeIntervals(intervals).reduce(0.0) { $0 + ($1.1 - $1.0) }
}
private static func unionDuration(
_ lhs: [ScoringSegment],
_ rhs: [ScoringSegment]
) -> Double {
return unionDuration(lhs + rhs)
}
private static func computeSpeakerMapping(
predicted: [String: [ScoringSegment]],
groundTruth: [String: [ScoringSegment]]
) -> [String: String] {
guard !predicted.isEmpty, !groundTruth.isEmpty else { return [:] }
let predictedIds = Array(predicted.keys)
let groundTruthIds = Array(groundTruth.keys)
var confusionMatrix = Array(
repeating: Array(repeating: 0, count: predictedIds.count),
count: groundTruthIds.count
)
let scale = 1_000.0
for (gtIndex, gtId) in groundTruthIds.enumerated() {
for (predIndex, predId) in predictedIds.enumerated() {
let overlap = overlapDuration(
predicted[predId] ?? [],
groundTruth[gtId] ?? []
)
confusionMatrix[gtIndex][predIndex] = Int((overlap * scale).rounded())
}
}
let assignment = AssignmentSolver.bestAssignment(confusionMatrix: confusionMatrix)
var mapping: [String: String] = [:]
for (predIndex, gtIndex) in assignment {
guard predIndex < predictedIds.count, gtIndex < groundTruthIds.count else { continue }
if confusionMatrix[gtIndex][predIndex] > 0 {
mapping[predictedIds[predIndex]] = groundTruthIds[gtIndex]
}
}
return mapping
}
// MARK: - Assignment solver (DP over subsets)
private enum AssignmentSolver {
struct Key: Hashable {
let predIndex: Int
let mask: Int
}
struct Result {
let score: Int
let mapping: [Int: Int]
}
static func bestAssignment(confusionMatrix: [[Int]]) -> [Int: Int] {
let gtCount = confusionMatrix.count
guard gtCount > 0 else { return [:] }
let predCount = confusionMatrix.first?.count ?? 0
guard predCount > 0 else { return [:] }
if gtCount >= Int.bitWidth {
return greedyAssignment(confusionMatrix: confusionMatrix)
}
var memo: [Key: Result] = [:]
func dfs(predIndex: Int, mask: Int) -> Result {
if predIndex == predCount {
return Result(score: 0, mapping: [:])
}
let key = Key(predIndex: predIndex, mask: mask)
if let cached = memo[key] {
return cached
}
var bestResult = dfs(predIndex: predIndex + 1, mask: mask)
for gtIndex in 0..<gtCount where (mask & (1 << gtIndex)) == 0 {
let nextResult = dfs(predIndex: predIndex + 1, mask: mask | (1 << gtIndex))
let candidateScore = nextResult.score + confusionMatrix[gtIndex][predIndex]
if candidateScore > bestResult.score {
var updatedMapping = nextResult.mapping
updatedMapping[predIndex] = gtIndex
bestResult = Result(score: candidateScore, mapping: updatedMapping)
}
}
memo[key] = bestResult
return bestResult
}
return dfs(predIndex: 0, mask: 0).mapping
}
private static func greedyAssignment(confusionMatrix: [[Int]]) -> [Int: Int] {
let gtCount = confusionMatrix.count
let predCount = confusionMatrix.first?.count ?? 0
var assignments: [Int: Int] = [:]
var usedGroundTruth = Set<Int>()
for predIndex in 0..<predCount {
var bestGt = -1
var bestScore = Int.min
for gtIndex in 0..<gtCount where !usedGroundTruth.contains(gtIndex) {
let score = confusionMatrix[gtIndex][predIndex]
if score > bestScore {
bestScore = score
bestGt = gtIndex
}
}
if bestGt >= 0 {
assignments[predIndex] = bestGt
usedGroundTruth.insert(bestGt)
}
}
return assignments
}
}
}
#endif
@@ -0,0 +1,65 @@
#if os(macOS)
import FluidAudio
import Foundation
enum RTTMParserError: Error, LocalizedError {
case fileNotFound(String)
case invalidLine(String)
var errorDescription: String? {
switch self {
case .fileNotFound(let path):
return "RTTM file not found at \(path)"
case .invalidLine(let line):
return "Invalid RTTM line: \(line)"
}
}
}
/// Lightweight RTTM parser for converting ground-truth annotations into `TimedSpeakerSegment`s.
enum RTTMParser {
static func loadSegments(from path: String) throws -> [TimedSpeakerSegment] {
guard FileManager.default.fileExists(atPath: path) else {
throw RTTMParserError.fileNotFound(path)
}
let contents = try String(contentsOfFile: path, encoding: .utf8)
var segments: [TimedSpeakerSegment] = []
for rawLine in contents.components(separatedBy: .newlines) {
let line = rawLine.trimmingCharacters(in: .whitespaces)
if line.isEmpty || line.hasPrefix("#") {
continue
}
let fields = line.split(whereSeparator: { $0.isWhitespace })
guard fields.count >= 8, fields[0] == "SPEAKER" else {
throw RTTMParserError.invalidLine(line)
}
guard
let start = Float(fields[3]),
let duration = Float(fields[4])
else {
throw RTTMParserError.invalidLine(line)
}
let speakerId = String(fields[7])
let endTime = start + duration
segments.append(
TimedSpeakerSegment(
speakerId: speakerId,
embedding: [],
startTimeSeconds: start,
endTimeSeconds: endTime,
qualityScore: 1.0
)
)
}
return segments.sorted { $0.startTimeSeconds < $1.startTimeSeconds }
}
}
#endif
@@ -6,23 +6,31 @@ import Foundation
struct ResultsFormatter { struct ResultsFormatter {
static func printResults(_ result: ProcessingResult) async { static func printResults(_ result: ProcessingResult) async {
print("📊 Diarization Results:") print("Diarization Results:")
print(" Audio File: \(result.audioFile)") print("Audio File: \(result.audioFile)")
print(" Duration: \(String(format: "%.1f", result.durationSeconds))s") print("Duration: \(String(format: "%.1f", result.durationSeconds))s")
print(" Processing Time: \(String(format: "%.1f", result.processingTimeSeconds))s") print("Processing Time: \(String(format: "%.1f", result.processingTimeSeconds))s")
let rtfx = result.realTimeFactor let rtfx = result.realTimeFactor
print(" Speed Factor (RTFx): \(String(format: "%.2f", rtfx))x") print("Speed Factor (RTFx): \(String(format: "%.2f", rtfx))x")
print(" Detected Speakers: \(result.speakerCount)") print("Detected Speakers: \(result.speakerCount)")
print("🎤 Speaker Segments:") if let metrics = result.metrics {
for (index, segment) in result.segments.enumerated() {
let startTime = formatTime(segment.startTimeSeconds)
let endTime = formatTime(segment.endTimeSeconds)
let duration = segment.endTimeSeconds - segment.startTimeSeconds
print( print(
" \(index + 1). \(segment.speakerId): \(startTime) - \(endTime) (\(String(format: "%.1f", duration))s)" "DER: \(String(format: "%.1f", metrics.der))% (Miss="
+ "\(String(format: "%.1f", metrics.missRate))%, FA="
+ "\(String(format: "%.1f", metrics.falseAlarmRate))%, SE="
+ "\(String(format: "%.1f", metrics.speakerErrorRate))%, JER="
+ "\(String(format: "%.1f", metrics.jer))%)"
) )
if !metrics.speakerMapping.isEmpty {
print("Speaker Mapping:")
for (pred, truth) in metrics.speakerMapping.sorted(by: { $0.key < $1.key }) {
print(" \(pred)\(truth)")
}
}
}
if let timings = result.timings {
print("")
printSingleRunTimings(timings, durationSeconds: result.durationSeconds)
} }
} }
@@ -35,6 +43,48 @@ struct ResultsFormatter {
try data.write(to: URL(fileURLWithPath: file)) try data.write(to: URL(fileURLWithPath: file))
} }
private static func printSingleRunTimings(
_ timings: PipelineTimings,
durationSeconds: Float
) {
let stages: [(String, TimeInterval)] = [
("Model Compilation", timings.modelCompilationSeconds),
("Audio Loading", timings.audioLoadingSeconds),
("Segmentation", timings.segmentationSeconds),
("Embedding Extraction", timings.embeddingExtractionSeconds),
("Speaker Clustering", timings.speakerClusteringSeconds),
("Post Processing", timings.postProcessingSeconds),
]
let total = timings.totalProcessingSeconds
let totalAudioMinutes = Double(durationSeconds) / 60.0
print("⏱️ Pipeline Timing Breakdown")
let separator = String(repeating: "=", count: 95)
print(separator)
print("│ Stage │ Time │ Percentage │ Per Audio Minute │")
let headerSeparator = "├───────────────────────┼──────────┼────────────┼──────────────────┤"
print(headerSeparator)
for (stageName, stageTime) in stages {
let stageNamePadded = stageName.padding(toLength: 19, withPad: " ", startingAt: 0)
let timeStr = String(format: "%.3fs", stageTime).padding(
toLength: 8, withPad: " ", startingAt: 0)
let percentage = total > 0 ? (stageTime / total) * 100 : 0
let percentageStr = String(format: "%.1f%%", percentage).padding(
toLength: 10, withPad: " ", startingAt: 0)
let perMinute = totalAudioMinutes > 0 ? stageTime / totalAudioMinutes : 0
let perMinuteStr = String(format: "%.3fs/min", perMinute).padding(
toLength: 16, withPad: " ", startingAt: 0)
print("\(stageNamePadded)\(timeStr)\(percentageStr)\(perMinuteStr)")
}
let footerSeparator = "└───────────────────────┴──────────┴────────────┴──────────────────┘"
print(footerSeparator)
let totalText = String(format: "%.3f", total)
print("Bottleneck Stage: \(timings.bottleneckStage) • Total Processing: \(totalText)s")
}
static func saveBenchmarkResults(_ summary: BenchmarkSummary, to file: String) async throws { static func saveBenchmarkResults(_ summary: BenchmarkSummary, to file: String) async throws {
let encoder = JSONEncoder() let encoder = JSONEncoder()
encoder.outputFormatting = [.prettyPrinted, .sortedKeys] encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
@@ -210,7 +260,6 @@ struct ResultsFormatter {
// Print each stage // Print each stage
let stages: [(String, TimeInterval)] = [ let stages: [(String, TimeInterval)] = [
("Model Download", avgTimings.modelDownloadSeconds),
("Model Compilation", avgTimings.modelCompilationSeconds), ("Model Compilation", avgTimings.modelCompilationSeconds),
("Audio Loading", avgTimings.audioLoadingSeconds), ("Audio Loading", avgTimings.audioLoadingSeconds),
("Segmentation", avgTimings.segmentationSeconds), ("Segmentation", avgTimings.segmentationSeconds),
@@ -256,15 +305,11 @@ struct ResultsFormatter {
" Inference Only: \(String(format: "%.3f", avgTimings.totalInferenceSeconds))s (\(String(format: "%.1f", (avgTimings.totalInferenceSeconds / totalAvgTime) * 100))% of total)" " Inference Only: \(String(format: "%.3f", avgTimings.totalInferenceSeconds))s (\(String(format: "%.1f", (avgTimings.totalInferenceSeconds / totalAvgTime) * 100))% of total)"
) )
print( print(
" Setup Overhead: \(String(format: "%.3f", avgTimings.modelDownloadSeconds + avgTimings.modelCompilationSeconds))s (\(String(format: "%.1f", ((avgTimings.modelDownloadSeconds + avgTimings.modelCompilationSeconds) / totalAvgTime) * 100))% of total)" " Setup Overhead: \(String(format: "%.3f", avgTimings.modelCompilationSeconds))s (\(String(format: "%.1f", (avgTimings.modelCompilationSeconds / totalAvgTime) * 100))% of total)"
) )
// Optimization suggestions // Optimization suggestions
if avgTimings.modelDownloadSeconds > avgTimings.totalInferenceSeconds { if avgTimings.segmentationSeconds > avgTimings.embeddingExtractionSeconds * 2 {
print(
"💡 Optimization Suggestion: Model download is dominating execution time - consider model caching"
)
} else if avgTimings.segmentationSeconds > avgTimings.embeddingExtractionSeconds * 2 {
print( print(
"💡 Optimization Suggestion: Segmentation is the bottleneck - consider model optimization" "💡 Optimization Suggestion: Segmentation is the bottleneck - consider model optimization"
) )
@@ -280,7 +325,6 @@ struct ResultsFormatter {
let count = Double(results.count) let count = Double(results.count)
guard count > 0 else { return PipelineTimings() } guard count > 0 else { return PipelineTimings() }
let avgModelDownload = results.reduce(0.0) { $0 + $1.timings.modelDownloadSeconds } / count
let avgModelCompilation = let avgModelCompilation =
results.reduce(0.0) { $0 + $1.timings.modelCompilationSeconds } / count results.reduce(0.0) { $0 + $1.timings.modelCompilationSeconds } / count
let avgAudioLoading = results.reduce(0.0) { $0 + $1.timings.audioLoadingSeconds } / count let avgAudioLoading = results.reduce(0.0) { $0 + $1.timings.audioLoadingSeconds } / count
@@ -292,7 +336,6 @@ struct ResultsFormatter {
results.reduce(0.0) { $0 + $1.timings.postProcessingSeconds } / count results.reduce(0.0) { $0 + $1.timings.postProcessingSeconds } / count
return PipelineTimings( return PipelineTimings(
modelDownloadSeconds: avgModelDownload,
modelCompilationSeconds: avgModelCompilation, modelCompilationSeconds: avgModelCompilation,
audioLoadingSeconds: avgAudioLoading, audioLoadingSeconds: avgAudioLoading,
segmentationSeconds: avgSegmentation, segmentationSeconds: avgSegmentation,
@@ -0,0 +1,346 @@
import Accelerate
import CoreML
import XCTest
@testable import FluidAudio
final class WeightInterpolationTests: XCTestCase {
func testResampleUsesHalfPixelOffsetMapping() {
let input: [Float] = [0, 10, 20, 30]
let result = WeightInterpolation.resample(input, to: 2)
XCTAssertEqual(result.count, 2)
XCTAssertEqual(result[0], 5, accuracy: 1e-5)
XCTAssertEqual(result[1], 25, accuracy: 1e-5)
}
func testResampleMatchesInterpolationCoefficients() {
let input = (0..<16).map { Float($0) * 0.25 }
let outputLength = 7
let direct = WeightInterpolation.resample(input, to: outputLength)
let coefficients = WeightInterpolation.InterpolationCoefficients(
inputLength: input.count,
outputLength: outputLength
)
let gathered = coefficients.interpolate(input)
XCTAssertEqual(direct.count, gathered.count)
for (lhs, rhs) in zip(direct, gathered) {
XCTAssertEqual(lhs, rhs, accuracy: 1e-5)
}
}
func testResample2DBroadcastsRows() {
let inputs: [[Float]] = [
[1, 3, 5, 7],
[2, 4, 6, 8],
]
let outputs = WeightInterpolation.resample2D(inputs, to: 2)
XCTAssertEqual(outputs.count, 2)
XCTAssertEqual(outputs[0][0], 2, accuracy: 1e-5)
XCTAssertEqual(outputs[0][1], 6, accuracy: 1e-5)
XCTAssertEqual(outputs[1][0], 3, accuracy: 1e-5)
XCTAssertEqual(outputs[1][1], 7, accuracy: 1e-5)
}
func testZoomFactorProducesExpectedLength() {
let input = (0..<10).map(Float.init)
let zoomed = WeightInterpolation.zoom(input, factor: 0.5)
XCTAssertEqual(zoomed.count, 5)
}
}
@available(macOS 13.0, iOS 16.0, *)
final class VDSPOperationsTests: XCTestCase {
func testL2NormalizeProducesUnitVector() {
let input: [Float] = [3, 4]
let normalized = VDSPOperations.l2Normalize(input)
XCTAssertEqual(normalized[0], 0.6, accuracy: 1e-6)
XCTAssertEqual(normalized[1], 0.8, accuracy: 1e-6)
XCTAssertEqual(VDSPOperations.dotProduct(normalized, normalized), 1, accuracy: 1e-5)
}
func testMatrixVectorMultiplyMatchesManualComputation() {
let matrix: [[Float]] = [
[1, 2, 3],
[4, 5, 6],
]
let vector: [Float] = [7, 8, 9]
let result = VDSPOperations.matrixVectorMultiply(matrix: matrix, vector: vector)
XCTAssertEqual(result.count, 2)
XCTAssertEqual(result[0], 50, accuracy: 1e-6)
XCTAssertEqual(result[1], 122, accuracy: 1e-6)
}
func testMatrixMultiplyMatchesExpected() {
let a: [[Float]] = [
[1, 2],
[3, 4],
]
let b: [[Float]] = [
[5, 6, 7],
[8, 9, 10],
]
let result = VDSPOperations.matrixMultiply(a: a, b: b)
XCTAssertEqual(result.count, 2)
XCTAssertEqual(result[0].count, 3)
XCTAssertEqual(result[1].count, 3)
XCTAssertEqual(result[0][0], 21, accuracy: 1e-6)
XCTAssertEqual(result[0][1], 24, accuracy: 1e-6)
XCTAssertEqual(result[0][2], 27, accuracy: 1e-6)
XCTAssertEqual(result[1][0], 47, accuracy: 1e-6)
XCTAssertEqual(result[1][1], 54, accuracy: 1e-6)
XCTAssertEqual(result[1][2], 61, accuracy: 1e-6)
}
func testLogSumExpMatchesAnalyticalValue() {
let vector: [Float] = [0.0, 1.0, 2.0]
let expected = log(exp(0.0) + exp(1.0) + exp(2.0))
XCTAssertEqual(VDSPOperations.logSumExp(vector), Float(expected), accuracy: 1e-5)
}
func testSoftmaxProducesProbabilityDistribution() {
let vector: [Float] = [1.0, 2.0, 3.0]
let result = VDSPOperations.softmax(vector)
let sum = result.reduce(0, +)
XCTAssertEqual(sum, 1.0, accuracy: 1e-5)
XCTAssertTrue(result[2] > result[1] && result[1] > result[0])
}
func testPairwiseEuclideanDistances() {
let a: [[Float]] = [
[0, 0],
[1, 1],
]
let b: [[Float]] = [
[0, 1],
[2, 3],
]
let distances = VDSPOperations.pairwiseEuclideanDistances(a: a, b: b)
XCTAssertEqual(distances.count, 2)
// a[0]=[0,0] vs b[0]=[0,1]: sqrt(0^2 + 1^2) = 1
XCTAssertEqual(distances[0][0], 1, accuracy: 1e-6)
// a[0]=[0,0] vs b[1]=[2,3]: sqrt(4 + 9) = sqrt(13)
XCTAssertEqual(distances[0][1], Float(sqrt(13)), accuracy: 1e-6)
// a[1]=[1,1] vs b[0]=[0,1]: sqrt(1 + 0) = 1
XCTAssertEqual(distances[1][0], 1, accuracy: 1e-6)
// a[1]=[1,1] vs b[1]=[2,3]: sqrt(1 + 4) = sqrt(5)
XCTAssertEqual(distances[1][1], Float(sqrt(5)), accuracy: 1e-6)
}
}
@available(macOS 13.0, iOS 16.0, *)
final class OfflineDiarizerConfigTests: XCTestCase {
func testDefaultConfigurationMatchesExpectedValues() throws {
let config = OfflineDiarizerConfig.default
XCTAssertEqual(config.clusteringThreshold, 0.6, accuracy: 1e-12)
XCTAssertEqual(config.Fa, 0.07)
XCTAssertEqual(config.Fb, 0.8)
XCTAssertEqual(config.maxVBxIterations, 20)
XCTAssertTrue(config.embeddingExcludeOverlap)
XCTAssertEqual(config.samplesPerWindow, 160_000)
XCTAssertNoThrow(try config.validate())
}
func testValidateThrowsForInvalidClusteringThreshold() {
let config = OfflineDiarizerConfig(clusteringThreshold: 1.5)
XCTAssertThrowsError(try config.validate()) { error in
guard case OfflineDiarizationError.invalidConfiguration(let message) = error else {
XCTFail("Expected invalidConfiguration, got \(error)")
return
}
XCTAssertTrue(message.contains("clustering.threshold"))
}
}
func testValidateThrowsForInvalidBatchSize() {
let config = OfflineDiarizerConfig(embeddingBatchSize: 0)
XCTAssertThrowsError(try config.validate()) { error in
guard case OfflineDiarizationError.invalidBatchSize(let message) = error else {
XCTFail("Expected invalidBatchSize, got \(error)")
return
}
XCTAssertTrue(message.contains("embeddingBatchSize"))
}
}
func testValidateThrowsForInvalidSegmentationMinDurationOn() {
var config = OfflineDiarizerConfig()
config.segmentationMinDurationOn = -0.5
XCTAssertThrowsError(try config.validate()) { error in
guard case OfflineDiarizationError.invalidConfiguration(let message) = error else {
XCTFail("Expected invalidConfiguration, got \(error)")
return
}
XCTAssertTrue(message.contains("segmentation.minDurationOn"))
}
}
}
@available(macOS 13.0, iOS 16.0, *)
final class OfflineTypesTests: XCTestCase {
func testErrorDescriptionsAreHumanReadable() {
XCTAssertEqual(
OfflineDiarizationError.modelNotLoaded("segmentation").localizedDescription,
"Model not loaded: segmentation"
)
XCTAssertEqual(
OfflineDiarizationError.noSpeechDetected.localizedDescription,
"No speech detected in audio"
)
XCTAssertEqual(
OfflineDiarizationError.invalidBatchSize("embedding batch").localizedDescription,
"Invalid batch size: embedding batch"
)
}
func testSegmentationOutputInitialization() {
let output = SegmentationOutput(
logProbs: [[[0.1, 0.9]]],
numChunks: 1,
numFrames: 1,
numSpeakers: 2
)
XCTAssertEqual(output.numChunks, 1)
XCTAssertEqual(output.numFrames, 1)
XCTAssertEqual(output.numSpeakers, 2)
}
func testVBxOutputInitialization() {
let output = VBxOutput(
gamma: [[0.6, 0.4]],
pi: [0.5, 0.5],
hardClusters: [[0, 1]],
centroids: [[0.1, 0.2], [0.3, 0.4]],
numClusters: 2,
elbos: [1.0, 1.1]
)
XCTAssertEqual(output.gamma.count, 1)
XCTAssertEqual(output.numClusters, 2)
XCTAssertEqual(output.centroids[1][1], 0.4, accuracy: 1e-6)
}
}
@available(macOS 13.0, iOS 16.0, *)
final class ModelWarmupTests: XCTestCase {
func testWarmupSingleInputInvokesPredictionsWithExpectedShape() throws {
let model = WarmupMockModel()
let iterations = 3
let duration = try ModelWarmup.warmup(
model: model,
inputName: "audio",
inputShape: [1, 160],
iterations: iterations
)
XCTAssertGreaterThanOrEqual(duration, 0)
XCTAssertEqual(model.receivedInputs.count, iterations)
for invocation in model.receivedInputs {
let array = invocation["audio"]
XCTAssertNotNil(array)
XCTAssertEqual(array?.shape.map { $0.intValue }, [1, 160])
}
}
func testWarmupEmbeddingModelUsesFbankInputsWhenAvailable() throws {
let model = WarmupMockModel()
let weightFrames = 64
try ModelWarmup.warmupEmbeddingModel(model, weightFrames: weightFrames)
guard let lastInvocation = model.receivedInputs.last else {
XCTFail("Expected at least one invocation")
return
}
let features = lastInvocation["fbank_features"]
let weights = lastInvocation["weights"]
XCTAssertNotNil(features)
XCTAssertNotNil(weights)
XCTAssertEqual(features?.shape.map { $0.intValue }, [1, 1, 80, 998])
XCTAssertEqual(weights?.shape.map { $0.intValue }, [1, weightFrames])
}
func testWarmupEmbeddingModelFallsBackToCombinedWhenFbankFails() throws {
let model = WarmupMockModel()
model.failureKeys = ["fbank_features"]
let weightFrames = 32
try ModelWarmup.warmupEmbeddingModel(model, weightFrames: weightFrames)
// Expect one invocation: only the successful combined fallback is recorded
XCTAssertEqual(model.receivedInputs.count, 1)
guard let lastInvocation = model.receivedInputs.last else {
XCTFail("Expected fallback invocation")
return
}
XCTAssertNotNil(lastInvocation["audio_and_weights"])
XCTAssertNil(lastInvocation["fbank_features"])
}
// MARK: - Helpers
private final class WarmupMockModel: MLModel {
private(set) var receivedInputs: [[String: MLMultiArray]] = []
var failureKeys: Set<String> = []
override func prediction(
from input: MLFeatureProvider,
options: MLPredictionOptions = MLPredictionOptions()
) throws -> MLFeatureProvider {
for name in input.featureNames {
if failureKeys.contains(name) {
throw MockError.simulatedFailure
}
}
var captured: [String: MLMultiArray] = [:]
for name in input.featureNames {
if let array = input.featureValue(for: name)?.multiArrayValue {
captured[name] = array
}
}
receivedInputs.append(captured)
return try MLDictionaryFeatureProvider(dictionary: [
"output": MLFeatureValue(double: 0.0)
])
}
private enum MockError: Error {
case simulatedFailure
}
}
}
@@ -153,7 +153,8 @@ final class SpeakerManagerTests: XCTestCase {
let info = manager.getSpeaker(for: id) let info = manager.getSpeaker(for: id)
XCTAssertNotNil(info) XCTAssertNotNil(info)
XCTAssertEqual(info?.id, id) XCTAssertEqual(info?.id, id)
XCTAssertEqual(info?.currentEmbedding, embedding) let normalizedExpected = VDSPOperations.l2Normalize(embedding)
XCTAssertEqual(info?.currentEmbedding, normalizedExpected)
XCTAssertEqual(info?.duration, 3.5) XCTAssertEqual(info?.duration, 3.5)
} }
} }
@@ -176,7 +177,8 @@ final class SpeakerManagerTests: XCTestCase {
// Verify the values // Verify the values
XCTAssertEqual(publicId, id) XCTAssertEqual(publicId, id)
XCTAssertEqual(publicEmbedding, embedding) let normalizedExpected = VDSPOperations.l2Normalize(embedding)
XCTAssertEqual(publicEmbedding, normalizedExpected)
XCTAssertEqual(publicDuration, 5.0) XCTAssertEqual(publicDuration, 5.0)
XCTAssertNotNil(publicUpdatedAt) XCTAssertNotNil(publicUpdatedAt)
XCTAssertEqual(publicUpdateCount, 1) XCTAssertEqual(publicUpdateCount, 1)
@@ -350,7 +352,8 @@ final class SpeakerManagerTests: XCTestCase {
let info = manager.getSpeaker(for: "TestSpeaker1") let info = manager.getSpeaker(for: "TestSpeaker1")
XCTAssertNotNil(info) XCTAssertNotNil(info)
XCTAssertEqual(info?.id, "TestSpeaker1") XCTAssertEqual(info?.id, "TestSpeaker1")
XCTAssertEqual(info?.currentEmbedding, embedding) let normalizedExpected = VDSPOperations.l2Normalize(embedding)
XCTAssertEqual(info?.currentEmbedding, normalizedExpected)
XCTAssertEqual(info?.duration, 5.0) XCTAssertEqual(info?.duration, 5.0)
XCTAssertEqual(info?.updateCount, 1) XCTAssertEqual(info?.updateCount, 1)
} }
@@ -417,7 +420,8 @@ final class SpeakerManagerTests: XCTestCase {
let info = manager.getSpeaker(for: "Alice") let info = manager.getSpeaker(for: "Alice")
XCTAssertNotNil(info) XCTAssertNotNil(info)
XCTAssertEqual(info?.id, "Alice") XCTAssertEqual(info?.id, "Alice")
XCTAssertEqual(info?.currentEmbedding, embedding) let normalizedExpected = VDSPOperations.l2Normalize(embedding)
XCTAssertEqual(info?.currentEmbedding, normalizedExpected)
XCTAssertEqual(info?.duration, 7.5) XCTAssertEqual(info?.duration, 7.5)
XCTAssertEqual(info?.rawEmbeddings.count, 1) XCTAssertEqual(info?.rawEmbeddings.count, 1)
} }
@@ -215,7 +215,14 @@ final class SpeakerOperationsTests: XCTestCase {
XCTAssertNotNil(speaker) XCTAssertNotNil(speaker)
XCTAssertEqual(speaker?.id, "test1") XCTAssertEqual(speaker?.id, "test1")
XCTAssertEqual(speaker?.name, "Test Speaker") XCTAssertEqual(speaker?.name, "Test Speaker")
XCTAssertEqual(speaker?.currentEmbedding, embedding) let expectedEmbedding = VDSPOperations.l2Normalize(embedding)
if let current = speaker?.currentEmbedding {
for (value, expected) in zip(current, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.0001)
}
} else {
XCTFail("Speaker embedding missing")
}
XCTAssertEqual(speaker?.duration, 5.0) XCTAssertEqual(speaker?.duration, 5.0)
} }
@@ -225,17 +232,25 @@ final class SpeakerOperationsTests: XCTestCase {
let oldEmb = [Float](repeating: 1.0, count: 256) let oldEmb = [Float](repeating: 1.0, count: 256)
let newEmb = [Float](repeating: 0.5, count: 256) // Use non-zero values to pass validation let newEmb = [Float](repeating: 0.5, count: 256) // Use non-zero values to pass validation
let alpha: Float = 0.7
let updated = SpeakerUtilities.updateEmbedding( let updated = SpeakerUtilities.updateEmbedding(
current: oldEmb, current: oldEmb,
new: newEmb, new: newEmb,
alpha: 0.7 alpha: alpha
) )
XCTAssertNotNil(updated) XCTAssertNotNil(updated)
// With alpha=0.7: result = 0.7 * 1.0 + 0.3 * 0.5 = 0.85 // Embeddings are averaged in normalized space and then renormalized.
if let updatedValues = updated { if let updatedValues = updated {
for value in updatedValues { let normalizedCurrent = VDSPOperations.l2Normalize(oldEmb)
XCTAssertEqual(value, 0.85, accuracy: 0.001) let normalizedNew = VDSPOperations.l2Normalize(newEmb)
var combined = [Float](repeating: 0, count: normalizedCurrent.count)
for i in 0..<combined.count {
combined[i] = alpha * normalizedCurrent[i] + (1 - alpha) * normalizedNew[i]
}
let expectedValues = VDSPOperations.l2Normalize(combined)
for (value, expected) in zip(updatedValues, expectedValues) {
XCTAssertEqual(value, expected, accuracy: 0.001)
} }
} }
} }
@@ -383,9 +398,14 @@ final class SpeakerOperationsTests: XCTestCase {
let average = SpeakerUtilities.averageEmbeddings([emb1, emb2, emb3]) let average = SpeakerUtilities.averageEmbeddings([emb1, emb2, emb3])
XCTAssertNotNil(average) XCTAssertNotNil(average)
// Average should be 2.0 // Average should reflect normalized mean of normalized embeddings.
for value in average! { let expected = VDSPOperations.l2Normalize([Float](repeating: 2.0, count: 256))
XCTAssertEqual(value, 2.0, accuracy: 0.001) if let average = average {
for (value, expectedValue) in zip(average, expected) {
XCTAssertEqual(value, expectedValue, accuracy: 0.001)
}
} else {
XCTFail("Average should not be nil")
} }
} }
@@ -402,10 +422,11 @@ final class SpeakerOperationsTests: XCTestCase {
// Should return average of valid embeddings only (emb1 in this case) // Should return average of valid embeddings only (emb1 in this case)
XCTAssertNotNil(average) XCTAssertNotNil(average)
XCTAssertEqual(average?.count, 256) XCTAssertEqual(average?.count, 256)
// Should be 1.0 since only emb1 is valid // Should match the normalized emb1 since only it is valid
if let avg = average { if let avg = average {
for value in avg { let expected = VDSPOperations.l2Normalize(emb1)
XCTAssertEqual(value, 1.0, accuracy: 0.001) for (value, expectedValue) in zip(avg, expected) {
XCTAssertEqual(value, expectedValue, accuracy: 0.001)
} }
} }
} }
+23 -12
View File
@@ -27,7 +27,10 @@ final class SpeakerTests: XCTestCase {
XCTAssertEqual(speaker.id, "test1") XCTAssertEqual(speaker.id, "test1")
XCTAssertEqual(speaker.name, "Alice") XCTAssertEqual(speaker.name, "Alice")
XCTAssertEqual(speaker.currentEmbedding, embedding) let expectedEmbedding = VDSPOperations.l2Normalize(embedding)
for (value, expected) in zip(speaker.currentEmbedding, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.0001)
}
XCTAssertEqual(speaker.duration, 5.0) XCTAssertEqual(speaker.duration, 5.0)
XCTAssertEqual(speaker.updateCount, 1) XCTAssertEqual(speaker.updateCount, 1)
XCTAssertTrue(speaker.rawEmbeddings.isEmpty) XCTAssertTrue(speaker.rawEmbeddings.isEmpty)
@@ -105,10 +108,11 @@ final class SpeakerTests: XCTestCase {
alpha: 0.8 // 80% old, 20% new alpha: 0.8 // 80% old, 20% new
) )
// The main embedding is recalculated as an average of raw embeddings after adding // The speaker stores embeddings in L2-normalized form. With a single raw embedding,
// Since we have only one raw embedding (0.5), the main embedding becomes 0.5 // the recalculated main embedding should equal the normalized segment embedding.
for value in speaker.currentEmbedding { let expectedEmbedding = VDSPOperations.l2Normalize(embedding2)
XCTAssertEqual(value, 0.5, accuracy: 0.001) for (value, expected) in zip(speaker.currentEmbedding, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.001)
} }
// Verify that the raw embedding was added // Verify that the raw embedding was added
@@ -132,7 +136,10 @@ final class SpeakerTests: XCTestCase {
) )
// Embedding should not have been updated (magnitude too low) // Embedding should not have been updated (magnitude too low)
XCTAssertEqual(speaker.currentEmbedding, embedding1) let expectedEmbedding = VDSPOperations.l2Normalize(embedding1)
for (value, expected) in zip(speaker.currentEmbedding, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.0001)
}
XCTAssertEqual(speaker.rawEmbeddings.count, 0) // No raw embedding added XCTAssertEqual(speaker.rawEmbeddings.count, 0) // No raw embedding added
XCTAssertEqual(speaker.updateCount, 1) // No update XCTAssertEqual(speaker.updateCount, 1) // No update
XCTAssertEqual(speaker.duration, 0.0) // Duration NOT updated due to early return XCTAssertEqual(speaker.duration, 0.0) // Duration NOT updated due to early return
@@ -166,7 +173,7 @@ final class SpeakerTests: XCTestCase {
// First embedding should be from pattern 10 (0-9 were removed) // First embedding should be from pattern 10 (0-9 were removed)
let firstEmbedding = speaker.rawEmbeddings.first?.embedding let firstEmbedding = speaker.rawEmbeddings.first?.embedding
let expectedFirst = createDistinctEmbedding(pattern: 10) let expectedFirst = VDSPOperations.l2Normalize(createDistinctEmbedding(pattern: 10))
if let firstValue = firstEmbedding?[0] { if let firstValue = firstEmbedding?[0] {
XCTAssertEqual(firstValue, expectedFirst[0], accuracy: 0.001) XCTAssertEqual(firstValue, expectedFirst[0], accuracy: 0.001)
} }
@@ -210,9 +217,10 @@ final class SpeakerTests: XCTestCase {
speaker.recalculateMainEmbedding() speaker.recalculateMainEmbedding()
// Average should be (1 + 2 + 3) / 3 = 2.0 // Raw embeddings are stored normalized; recalculating should keep the unit-normalized vector.
for value in speaker.currentEmbedding { let expectedEmbedding = VDSPOperations.l2Normalize([Float](repeating: 1.0, count: 256))
XCTAssertEqual(value, 2.0, accuracy: 0.001) for (value, expected) in zip(speaker.currentEmbedding, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.0001)
} }
} }
@@ -223,8 +231,11 @@ final class SpeakerTests: XCTestCase {
// No raw embeddings // No raw embeddings
speaker.recalculateMainEmbedding() speaker.recalculateMainEmbedding()
// Should keep original embedding // Should keep the previously normalized embedding
XCTAssertEqual(speaker.currentEmbedding, original) let expectedEmbedding = VDSPOperations.l2Normalize(original)
for (value, expected) in zip(speaker.currentEmbedding, expectedEmbedding) {
XCTAssertEqual(value, expected, accuracy: 0.0001)
}
} }
// MARK: - Speaker Merging Tests // MARK: - Speaker Merging Tests
+11
View File
@@ -0,0 +1,11 @@
Copyright:
* Until package version 1.1.23: © 2011 Daniel Müllner <https://danifold.net>
* All changes from version 1.1.24 on: © Google Inc. <https://www.google.com>
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but not
limited to compiled object code, generated documentation, and
conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021-2024 BUT Speech@FIT (original VBx project)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.