fix: clean up diarization test infrastructure (#395)

## Summary
- Extract shared fixture helpers into `DiarizationTestFixtures` enum,
removing ~200 lines of duplicate code across `LSEENDIntegrationTests`
and `SpeakerEnrollmentTests`
- Replace fragile `Mirror`-based private state inspection with
`internal` `hasActiveSession` property on `LSEENDDiarizerAPI`
- Fix non-deterministic `srand48` seed in `SortformerTests` (use
constant `42` instead of time-based seed)
- Fix asymmetric skip guards in Sortformer enrollment tests (`XCTSkipIf`
instead of `XCTAssertNotNil` for host-dependent segments)

## Test plan
- [x] `swift build --build-tests` passes
- [ ] `swift test --filter SortformerTests` passes
- [ ] `swift test --filter LSEENDIntegrationTests` passes
- [ ] `swift test --filter SpeakerEnrollmentTests` passes
<!-- devin-review-badge-begin -->

---

<a href="https://app.devin.ai/review/fluidinference/fluidaudio/pull/395"
target="_blank">
  <picture>
<source media="(prefers-color-scheme: dark)"
srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1">
<img
src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1"
alt="Open with Devin">
  </picture>
</a>
<!-- devin-review-badge-end -->
This commit is contained in:
Alex
2026-03-18 12:51:34 -04:00
committed by GitHub
parent 32b4b4d4a8
commit 8aa0dfcdac
34 changed files with 1251 additions and 1685 deletions
+13 -123
View File
@@ -75,161 +75,51 @@ Use `OfflineDiarizerManager` when you need offline DER parity or want to run the
### Diarizer Protocol
`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, which provides a unified streaming and offline API.
`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, providing a unified streaming and offline API.
**Protocol Properties:**
- `isAvailable: Bool` — Whether the model is loaded and ready
- `numFramesProcessed: Int` — Confirmed frames processed so far
- `targetSampleRate: Int?` — Model's expected audio sample rate in Hz
- `modelFrameHz: Double?` — Output frame rate in Hz (frames per second)
- `numSpeakers: Int?` — Number of real speaker output tracks
- `timeline: DiarizerTimeline` — Accumulated diarization results
**Streaming:** `addAudio(_:sourceSampleRate:)``process()` → read `timeline`. Convenience `process(samples:sourceSampleRate:)` combines both steps. Returns `DiarizerTimelineUpdate?` (`nil` when not enough audio has accumulated).
**Streaming:**
- `addAudio<C: Collection>(_ samples: C, sourceSampleRate: Double?) throws` — Buffer audio for processing; pass a non-nil `sourceSampleRate` to resample on the fly
- `process() throws -> DiarizerTimelineUpdate?` — Run inference on buffered audio; returns `nil` if not enough audio has accumulated
- `process<C: Collection>(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate?` — Convenience combining `addAudio` + `process` in one call
**Offline:** `processComplete(_:sourceSampleRate:...)` or `processComplete(audioFileURL:...)` to process a full recording in one call.
**Offline:**
- `processComplete<C: Collection>(_ samples: C, sourceSampleRate:, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Process a complete audio buffer in one call
- `processComplete(audioFileURL: URL, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Read, resample, and process an audio file end-to-end
**Speaker Enrollment:** `enrollSpeaker(withAudio:sourceSampleRate:named:...)` feeds known-speaker audio before streaming to label a slot.
**Speaker Enrollment:**
- `enrollSpeaker<C: Collection>(withAudio samples: C, sourceSampleRate:, named:, overwritingAssignedSpeakerName:) throws -> DiarizerSpeaker?` — Feed audio of a known speaker before streaming begins; warms model state and labels that speaker's slot for subsequent `process()` calls
**Lifecycle:**
- `reset()` — Clear all streaming state (session, buffers, timeline) while keeping the model loaded
- `cleanup()` — Release all resources including the loaded model
**Lifecycle:** `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything.
---
### DiarizerTimeline
### DiarizerTimeline & DiarizerSpeaker
Holds accumulated streaming predictions and derived speaker segments. Returned by `Diarizer.timeline` and `processComplete(...)`.
`DiarizerTimeline` accumulates per-frame speaker probabilities and derives `DiarizerSpeaker` segments. Each speaker has `finalizedSegments` (confirmed) and `tentativeSegments` (may be revised). Segments expose `startTime`, `endTime`, `duration`, and `isFinalized`.
**Key Properties:**
- `config: DiarizerTimelineConfig` — Post-processing configuration used to build segments
- `speakers: [Int: DiarizerSpeaker]` — Speaker slots keyed by output track index
- `finalizedPredictions: [Float]` — Flat `[frames × numSpeakers]` array of finalized per-frame probabilities
- `tentativePredictions: [Float]` — Same layout; frames still within the right-context window that may be revised
- `numFinalizedFrames: Int` — Count of finalized frames
- `numTentativeFrames: Int` — Count of tentative frames
- `finalizedDuration: Float` — Duration in seconds of finalized audio
- `hasSegments: Bool` — Whether any speaker has at least one segment
**Mutation:**
- `addChunk(_ chunk: DiarizerChunkResult) throws -> DiarizerTimelineUpdate` — Append new predictions and rebuild segments; called internally by the diarizer
- `rebuild(finalizedPredictions:tentativePredictions:keepingSpeakers:isComplete:) throws` — Replace all predictions from scratch (used by offline processing)
- `reset(keepingSpeakers:)` / `reset(keepingSpeakersWhere:)` — Clear segments and optionally preserve named speakers or speaker metadata
- `finalize()` — Promote all tentative segments to finalized
**`DiarizerTimelineConfig`** — Shared configuration used by both diarizers:
| Parameter | Default | Description |
|---|---|---|
| `numSpeakers` | model-specific | Number of speaker output tracks |
| `frameDurationSeconds` | model-specific | Duration of one output frame |
| `onsetThreshold` | 0.5 | Probability threshold to begin a speech segment |
| `offsetThreshold` | 0.5 | Probability threshold to end a speech segment |
| `onsetPadFrames` | 0 | Frames prepended to each segment onset |
| `offsetPadFrames` | 0 | Frames appended to each segment offset |
| `minFramesOn` | 0 | Minimum segment length; shorter segments are dropped |
| `minFramesOff` | 0 | Minimum gap; shorter silences are closed |
| `maxStoredFrames` | nil | Rolling window cap on finalized frames (nil = unlimited) |
---
### DiarizerSpeaker
Represents a single speaker track within a `DiarizerTimeline`.
**Key Properties:**
- `id: UUID` — Stable identity across resets
- `index: Int` — Slot index in the diarizer output (0-based)
- `name: String?` — Optional display name (set via enrollment or manually)
- `finalizedSegments: [DiarizerSegment]` — Confirmed speech segments
- `tentativeSegments: [DiarizerSegment]` — Speculative segments within the right-context window
- `hasSegments: Bool` — Whether any finalized or tentative segments exist
- `numSpeechFrames: Int` — Total frames spanned by all segments (finalized + tentative)
- `speechDuration: Float` — Total speech duration in seconds
**`DiarizerSegment`** — A single time-range for one speaker:
- `startFrame / endFrame: Int` — Frame indices (convert using `frameDurationSeconds`)
- `startTime / endTime: Float` — Seconds
- `duration: Float` — Segment length in seconds
- `isFinalized: Bool` — Whether the segment has been confirmed
**`DiarizerTimelineConfig`** controls post-processing (onset/offset thresholds default to 0.5, min segment/gap duration, optional rolling window cap). Both diarizers accept this at init.
---
### SortformerDiarizer
End-to-end streaming diarization using NVIDIA's Sortformer model. Tracks **4 fixed speaker slots**.
Streaming diarization using NVIDIA's Sortformer. 4 fixed speaker slots, 16 kHz input, 80 ms frame duration.
- **Sample rate:** 16 kHz
- **Frame duration:** 80 ms (12.5 Hz output)
- **Streaming latency:** ~0.64 s (`default` config) or ~1.04 s (`nvidiaLowLatency` configs)
- **Accuracy:** 31.7% DER on AMI SDM (`nvidiaHighLatencyV2_1`; other configs untested)
**Initialization:**
```swift
// Preferred: download and compile model automatically
let diarizer = SortformerDiarizer(config: .default, timelineConfig: .sortformerDefault)
try await diarizer.initialize(mainModelPath: modelURL)
// Or with pre-loaded models
diarizer.initialize(models: sortformerModels)
```
**`SortformerConfig` Presets:**
| Preset | Latency | Notes |
|---|---|---|
| `.default` / `.fastestV2_1` | 1.04 s | Fastest inference speed and update rate. |
| `.fastestV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights. |
| `.nvidiaLowLatencyV2_1` | 1.04 s | Uses more context than `fastest`; Improvement is minimal |
| `.nvidiaLowLatencyV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights |
| `.nvidiaHighLatencyV2_1` | 30.4 s | 31.7% DER on AMI SDM |
| `.nvidiaHighLatencyV2_0` | 30.4 s | Uses NVIDIA's Sortformer v2 weights |
All streaming methods are defined by the `Diarizer` protocol above. Additionally:
- `state: SortformerStreamingState` — Live speaker cache and FIFO queue state (for diagnostics)
- `config: SortformerConfig` — The configuration this instance was created with
**Config presets:** `.default` / `.fastV2_1` (1.04 s latency), `.balancedV2_1` (1.04 s, 20.6% DER on AMI SDM), `.highContextV2_1` (30.4 s latency). v2 variants also available.
---
### LSEENDDiarizer
End-to-end streaming diarization using LS-EEND (Linear Streaming End-to-End Neural Diarization). Supports a **variable number of speaker slots** depending on the model variant.
Streaming diarization using LS-EEND. Variable speaker slots, 8 kHz input, 100 ms frame duration, 20.7% DER on AMI SDM.
- **Sample rate:** 8 kHz
- **Frame duration:** 100 ms (10 Hz output)
- **Accuracy:** 20.7% DER on AMI SDM (AMI variant)
- **Variants:** `LSEENDVariant` (`LSEENDModelDescriptor.LSEENDVariant`)
**Initialization:**
```swift
// Auto-download from HuggingFace
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
try await diarizer.initialize(variant: .dihard3)
// Or with an explicit descriptor
let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: .dihard3)
try diarizer.initialize(descriptor: descriptor)
```
**LS-EENDSpecific Properties:**
- `computeUnits: MLComputeUnits` — CoreML compute target (`.cpuOnly` is typically fastest)
- `streamingLatencySeconds: Double?` — Minimum audio required before first output frame
- `decodeMaxSpeakers: Int?` — Total output slots including internal boundary tracks
- `timelineConfig: DiarizerTimelineConfig` — Active post-processing configuration
**Variants:** ami, callhome, dihard2, dihard3 (via `LSEENDModelDescriptor.loadFromHuggingFace(variant:)`).
**Additional Method:**
- `finalizeSession() throws -> DiarizerChunkResult?` — Flush pending audio and finalize the timeline; call at end of a stream before reading the final timeline
**`LSEENDModelDescriptor`:**
- `LSEENDModelDescriptor.loadFromHuggingFace(variant:cacheDirectory:computeUnits:) async throws -> LSEENDModelDescriptor` — Download and cache all model files; returns a descriptor ready for `initialize(descriptor:)`
- `init(variant:modelURL:metadataURL:)` — Construct from local paths if already cached
All streaming and offline methods are defined by the `Diarizer` protocol above.
Call `finalizeSession()` at end-of-stream to flush pending audio before reading the final timeline.
## Voice Activity Detection
+1 -1
View File
@@ -312,7 +312,7 @@ let result: LSEENDInferenceResult = try engine.infer(audioFileURL: url)
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate)
// Or with a caller-owned mel spectrogram (for thread isolation)
let mel = NeMoMelSpectrogram(...)
let mel = AudioMelSpectrogram(...)
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate, melSpectrogram: mel)
```
+24 -22
View File
@@ -38,7 +38,7 @@ Audio (16kHz) → Mel Spectrogram → CoreML Model → Speaker Probabilities
The pipeline consists of:
1. **Mel Spectrogram** (`NeMoMelSpectrogram`): Converts raw audio to 128-bin mel features
1. **Mel Spectrogram** (`AudioMelSpectrogram`): Converts raw audio to 128-bin mel features
2. **CoreML Model** (`DiarizerInference`): Combined encoder + attention + head
3. **Streaming State** (`SortformerStreamingState`): Maintains speaker cache and FIFO queue
4. **Post-processing** (`SortformerTimeline`): Converts probabilities to speaker segments
@@ -79,8 +79,8 @@ FIFO Queue Role:
| Config | `fifoLen` | Effect |
|--------|-----------|--------|
| Default | 40 | Smaller memory, faster compression cycles |
| NVIDIA Low | 188 | Larger context before compression |
| NVIDIA High | 40 | Same as default |
| Balanced | 188 | Larger context before compression |
| High Context | 40 | Same as default |
When `fifoLen + newChunkFrames > fifoLen` capacity, frames are popped from FIFO and either:
1. Added to speaker cache (if speaker was active)
@@ -105,8 +105,8 @@ Chunk with Context:
| Config | `rightContext` | Look-ahead | Latency Impact |
|--------|----------------|------------|----------------|
| Default | 7 | 7 × 80ms = 560ms | Low latency |
| NVIDIA Low | 7 | 7 × 80ms = 560ms | Low latency |
| NVIDIA High | 40 | 40 × 80ms = 3.2s | High latency, better quality |
| Balanced | 7 | 7 × 80ms = 560ms | Low latency |
| High Context | 40 | 40 × 80ms = 3.2s | High latency, better quality |
**Why Right Context Matters:**
@@ -142,7 +142,7 @@ Default config:
= 13 × 8 × 0.01
= 1.04 seconds
NVIDIA High Latency config:
High Context config:
= (340 + 40) × 8 × 160 / 16000
= 380 × 8 × 0.01
= 30.4 seconds
@@ -191,11 +191,11 @@ Defines streaming parameters that must match the CoreML model's static shapes:
// Default (~1.04s latency, lowest latency)
SortformerConfig.default
// NVIDIA High Latency (30.4s latency, best quality)
SortformerConfig.nvidiaHighLatency
// Balanced (1.04s latency, best quality on AMI SDM)
SortformerConfig.balancedV2_1
// NVIDIA Low Latency (1.04s latency)
SortformerConfig.nvidiaLowLatency
// High Context (30.4s latency, most context)
SortformerConfig.highContextV2_1
```
### Pipeline.swift
@@ -375,9 +375,11 @@ public struct SortformerSegment {
| Config | Chunk Size | Latency | Quality |
|--------|------------|---------|---------|
| `default` | 6 frames | ~1.04s | Good |
| `nvidiaLowLatency` | 6 frames | ~1.04s | Better |
| `nvidiaHighLatency` | 340 frames | ~30.4s | Best |
| `default` / `fastV2_1` | 6 frames | ~1.04s | Good |
| `balancedV2_1` | 6 frames | ~1.04s | Best (20.6% DER on AMI SDM) |
| `highContextV2_1` | 340 frames | ~30.4s | Good (31.7% DER on AMI SDM) |
> **Note:** v2.1 variants may degrade when many speakers are talking simultaneously. v2 variants (`fastV2`, `balancedV2`, `highContextV2`) are available as alternatives.
Latency is determined by:
- `chunkLen * subsamplingFactor * melStride / sampleRate`
@@ -396,10 +398,10 @@ This preserves the most informative historical context while bounding memory usa
## Post-Processing
`SortformerPostProcessingConfig` controls segment extraction:
`DiarizerTimelineConfig` controls segment extraction:
```swift
let config = SortformerPostProcessingConfig(
let config = DiarizerTimelineConfig(
onsetThreshold: 0.5, // Probability to start speech
offsetThreshold: 0.5, // Probability to end speech
minDurationOn: 0.25, // Min speech segment (seconds)
@@ -414,8 +416,8 @@ Three CoreML models are available on HuggingFace:
| Variant | File | Config |
|---------|------|--------|
| Default | `Sortformer.mlmodelc` | `SortformerConfig.default` |
| NVIDIA Low | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.nvidiaLowLatency` |
| NVIDIA High | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.nvidiaHighLatency` |
| Balanced | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.balancedV2_1` |
| High Context | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.highContextV2_1` |
**Important:** Each model has baked-in static shapes. You must use the matching configuration.
@@ -447,8 +449,8 @@ audioEngine.installTap { buffer in
### Batch Processing
```swift
let diarizer = SortformerDiarizer(config: .nvidiaHighLatency)
let models = try await SortformerModels.loadFromHuggingFace(config: .nvidiaHighLatency)
let diarizer = SortformerDiarizer(config: .highContextV2_1)
let models = try await SortformerModels.loadFromHuggingFace(config: .highContextV2_1)
diarizer.initialize(models: models)
let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_000)
@@ -457,9 +459,9 @@ let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_0
let fileTimeline = try diarizer.processComplete(audioFileURL: audioURL)
// Get segments per speaker
for (speakerIndex, segments) in timeline.segments.enumerated() {
for segment in segments {
print("Speaker \(speakerIndex): \(segment.startTime)s - \(segment.endTime)s")
for (index, speaker) in timeline.speakers {
for segment in speaker.finalizedSegments {
print("Speaker \(index): \(segment.startTime)s - \(segment.endTime)s")
}
}
```
+3 -38
View File
@@ -335,25 +335,7 @@ swift run fluidaudiocli diarization-benchmark --mode offline --auto-download \
### LS-EEND (LongForm Streaming End-to-End Neural Diarization)
State-of-the-art end-to-end streaming diarization with fully local CoreML inference. This is the best default option for online and streaming diarization when you want low-latency speaker activity updates from a single model without a separate clustering pipeline. Multiple exported variants are available for different benchmark domains, and the diarizer supports both streaming and complete-buffer processing.
Why use LS-EEND:
- Robust to noise and high speaker overlap
- Supports up to 10 speakers in a session
- Generally achieves better benchmark results than Sortformer on CALLHOME, AMI, and DIHARD III
- Frame-by-frame inference with 100ms frames, and 900ms of tentative frames.
- Much more lightweight than Sortformer, and can be faster on CPU than Sortformer on ANE
- Does not suffer from the missed-quiet-speech issues that Sortformer can show
- End-to-end model without separate segmentation, embedding extraction, and clustering stages
- Also works well for complete-buffer inference when you want the same model in offline mode
Tradeoffs:
- Can pick up background speakers more readily than Sortformer
- Speaker identity stability is usually somewhat weaker than Sortformer
- Speaker indices are expected to follow chronological arrival order, but Sortformer tends to maintain that ordering more reliably
- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline
See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for model loading and integration details.
End-to-end streaming diarization with CoreML inference. Default choice for online diarization — single model, no clustering pipeline, up to 10 speakers, 100ms frame updates with 900ms tentative preview. Supports both streaming and complete-buffer processing. See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for details.
```swift
import FluidAudio
@@ -373,26 +355,9 @@ Task {
### Sortformer (End-to-End Neural Diarization)
End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). This is the secondary streaming/online diarizer behind LS-EEND. It offers nearly the same live-update workflow and API flexibility, but trades LS-EEND's stronger benchmark results and higher speaker capacity for better identity stability. No separate VAD, segmentation, or clustering needed. Limited to 4 speakers and does not remember speakers across recordings. Licensed under NVIDIA Open Model License (no restrictions).
End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). Secondary streaming diarizer — trades LS-EEND's higher speaker capacity and benchmark results for better speaker identity stability. Limited to 4 speakers. No separate VAD, segmentation, or clustering needed. Licensed under NVIDIA Open Model License.
Why use Sortformer:
- Simple streaming pipeline with low latency
- Supports live-update flexibility like LS-EEND
- Handles overlapping speech without an external clustering stage
- Better speaker identity stability than LS-EEND
- Ignores speech from background conversations
- Good choice when the 4-speaker limit is acceptable
Tradeoffs:
- Limited to 4 unique speakers
- Updates are performed less frequently than for LS-EEND
- Can get overwhelmed when there are too many background speakers
- Can miss speech when the audio is too quiet
- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline
Like LS-EEND, Sortformer supports ultra-low-latency inference with 0.5s updates, 0.5s preview frames, and a fixed 1.0s latency before finalized predictions. Both models emit their results into a `DiarizerTimeline`, so you do not need to manage incremental diarization state externally.
See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage, comparison with Pyannote, streaming config, and architecture details.
Both LS-EEND and Sortformer emit results into a `DiarizerTimeline` with ultra-low-latency updates. See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage and comparison.
### Streaming/Online Speaker Diarization (Pyannote)
+1 -1
View File
@@ -82,7 +82,7 @@ HF_TOKEN="your_token" python nemo_ami_benchmark.py --output results.json
### High-Latency Streaming Config
These settings match the Swift `SortformerConfig.nvidiaHighLatency`:
These settings match the Swift `SortformerConfig.highContextV2_1`:
| Parameter | Value | Description |
|-----------|-------|-------------|
@@ -172,7 +172,7 @@ public actor StreamingEouAsrManager {
private var rnntDecoder: RnntDecoder?
private let audioConverter = AudioConverter()
private var tokenizer: Tokenizer?
private let melProcessor = NeMoMelSpectrogram() // Native Swift mel spectrogram
private let melProcessor = AudioMelSpectrogram() // Native Swift mel spectrogram
// Configuration - now based on chunkSize
public let chunkSize: StreamingChunkSize
@@ -247,7 +247,7 @@ public actor StreamingEouAsrManager {
public func loadModels(modelDir: URL) async throws {
logger.info("Loading CoreML models from \(modelDir.path)...")
// No longer loading preprocessor - using native Swift NeMoMelSpectrogram instead
// No longer loading preprocessor - using native Swift AudioMelSpectrogram instead
self.streamingEncoder = try await MLModel.load(
contentsOf: modelDir.appendingPathComponent("streaming_encoder.mlmodelc"), configuration: self.configuration
)
@@ -450,7 +450,7 @@ public actor StreamingEouAsrManager {
let mel = try MLMultiArray(shape: [1, 128, NSNumber(value: numFrames)], dataType: .float32)
let melPtr = mel.dataPointer.bindMemory(to: Float.self, capacity: mel.count)
// NeMoMelSpectrogram returns [nMels, T] row-major (mel bin, then time)
// AudioMelSpectrogram returns [nMels, T] row-major (mel bin, then time)
// CoreML expects [1, 128, T] which is the same layout
melPtr.update(from: melFlat, count: melFlat.count)
@@ -648,16 +648,6 @@ public final class DiarizerTimeline {
queue.sync { _tentativePredictions }
}
/// Total number of frames (finalized + tentative)
@available(
*, deprecated,
message: "`numFrames` now includes tentative frames. Use 'numFinalizedFrames' for only finalized frames.",
renamed: "numFinalizedFrames"
)
public var numFrames: Int {
queue.sync { _numFinalizedFrames + _tentativePredictions.count / config.numSpeakers }
}
/// Total number of finalized frames
public var numFinalizedFrames: Int {
queue.sync { _numFinalizedFrames }
@@ -677,31 +667,11 @@ public final class DiarizerTimeline {
speakers.values.contains(where: \.hasSegments)
}
/// Duration of all predictions in seconds
@available(
*, deprecated,
message: "`duration` now includes tentative frames. Use 'finalizedDuration' for only finalized frames.",
renamed: "finalizedDuration"
)
public var duration: Float {
Float(numFrames) * config.frameDurationSeconds
}
/// Duration of finalized predictions in seconds
public var finalizedDuration: Float {
Float(numFinalizedFrames) * config.frameDurationSeconds
}
/// Duration of tentative predictions in seconds
@available(
*, deprecated,
message: "tentativeDuration now excludes finalized frames. Use 'duration' for the full timeline duration.",
renamed: "duration"
)
public var tentativeDuration: Float {
Float(numTentativeFrames) * config.frameDurationSeconds
}
private var _finalizedPredictions: [Float] = []
private var _tentativePredictions: [Float] = []
private var _speakers: [Int: DiarizerSpeaker] = [:]
@@ -1106,42 +1076,6 @@ public final class DiarizerTimeline {
}
}
extension DiarizerTimeline {
@available(*, deprecated, renamed: "finalizedPredictions")
public var framePredictions: [Float] { finalizedPredictions }
@available(
*, deprecated,
message: "Use Timeline.speakers[index].confirmedSegments to access a speaker's confirmed segments."
)
public var segments: [[DiarizerSegment]] {
queue.sync {
var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers)
for (index, speaker) in _speakers {
result[index] = speaker.finalizedSegments
}
return result
}
}
@available(
*, deprecated,
message: "Use Timeline.speakers[index].tentativeSegments to access a speaker's tentative segments."
)
public var tentativeSegments: [[DiarizerSegment]] {
queue.sync {
var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers)
for (index, speaker) in _speakers {
result[index] = speaker.tentativeSegments
}
return result
}
}
@available(*, deprecated, renamed: "numTentativeFrames")
public var numTentative: Int { numTentativeFrames }
}
// MARK: - Timeline Update
public struct DiarizerTimelineUpdate: Sendable {
@@ -80,11 +80,18 @@ public final class LSEENDDiarizer: Diarizer {
return _engine?.decodeMaxSpeakers
}
/// Whether a streaming session is currently active.
var hasActiveSession: Bool {
lock.lock()
defer { lock.unlock() }
return _session != nil
}
// MARK: - Private State
private var _engine: LSEENDInferenceHelper?
private var _session: LSEENDStreamingSession?
private var _melSpectrogram: NeMoMelSpectrogram?
private var _melSpectrogram: AudioMelSpectrogram?
private var _timeline: DiarizerTimeline
private var _numFramesProcessed: Int = 0
private var _timelineConfig: DiarizerTimelineConfig
@@ -737,8 +744,8 @@ public final class LSEENDDiarizer: Diarizer {
}
/// Create a new mel spectrogram instance owned by this diarizer.
private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> NeMoMelSpectrogram {
NeMoMelSpectrogram(
private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> AudioMelSpectrogram {
AudioMelSpectrogram(
sampleRate: featureConfig.sampleRate,
nMels: featureConfig.nMels,
nFFT: featureConfig.nFFT,
@@ -54,7 +54,7 @@ private final class LSEENDInferenceSharedResources {
let modelFrameHz: Double
let streamingLatencySeconds: Double
let decodeMaxSpeakers: Int
let melSpectrogram: NeMoMelSpectrogram
let melSpectrogram: AudioMelSpectrogram
let offlineFeatureExtractor: LSEENDOfflineFeatureExtractor
// Preallocated ANE-aligned input arrays reused across predictStep calls
@@ -77,7 +77,7 @@ private final class LSEENDInferenceSharedResources {
modelFrameHz = metadata.frameHz
streamingLatencySeconds = metadata.streamingLatencySeconds
decodeMaxSpeakers = metadata.maxNspks
melSpectrogram = NeMoMelSpectrogram(
melSpectrogram = AudioMelSpectrogram(
sampleRate: featureConfig.sampleRate,
nMels: featureConfig.nMels,
nFFT: featureConfig.nFFT,
@@ -150,7 +150,7 @@ public final class LSEENDInferenceHelper {
/// Maximum number of speaker slots in the model output (including boundary tracks).
public var decodeMaxSpeakers: Int { sharedResources.decodeMaxSpeakers }
fileprivate var melSpectrogram: NeMoMelSpectrogram { sharedResources.melSpectrogram }
fileprivate var melSpectrogram: AudioMelSpectrogram { sharedResources.melSpectrogram }
private var offlineFeatureExtractor: LSEENDOfflineFeatureExtractor { sharedResources.offlineFeatureExtractor }
private let lock = NSLock()
@@ -191,8 +191,9 @@ public final class LSEENDInferenceHelper {
/// - inputSampleRate: Must match ``targetSampleRate``.
/// - melSpectrogram: A mel spectrogram instance owned by the caller.
/// - Returns: A session that accepts audio via ``LSEENDStreamingSession/pushAudio(_:)``.
public func createSession(inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram) throws -> LSEENDStreamingSession
{
public func createSession(
inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram
) throws -> LSEENDStreamingSession {
try LSEENDStreamingSession(engine: self, inputSampleRate: inputSampleRate, melSpectrogram: melSpectrogram)
}
@@ -472,7 +473,7 @@ public final class LSEENDStreamingSession {
fileprivate var emittedFrames = 0
fileprivate init(
engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram? = nil
engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram? = nil
) throws {
guard inputSampleRate == engine.targetSampleRate else {
throw LSEENDError.unsupportedAudio(
@@ -45,8 +45,8 @@ public struct LSEENDFeatureConfig: Sendable, Hashable {
}
}
private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpectrogram {
NeMoMelSpectrogram(
private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> AudioMelSpectrogram {
AudioMelSpectrogram(
sampleRate: config.sampleRate,
nMels: config.nMels,
nFFT: config.nFFT,
@@ -70,14 +70,14 @@ private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpe
/// For incremental processing, use ``LSEENDStreamingFeatureExtractor`` instead.
public final class LSEENDOfflineFeatureExtractor {
private let config: LSEENDFeatureConfig
private let spectrogram: NeMoMelSpectrogram
private let spectrogram: AudioMelSpectrogram
/// Creates an offline feature extractor.
///
/// - Parameters:
/// - metadata: Model metadata from which feature parameters are derived.
/// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`.
public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) {
public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) {
let featureConfig = LSEENDFeatureConfig(metadata: metadata)
config = featureConfig
self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig)
@@ -182,7 +182,7 @@ public final class LSEENDOfflineFeatureExtractor {
/// - Important: This class is **not** thread-safe. All calls must be serialized externally.
public final class LSEENDStreamingFeatureExtractor {
private let config: LSEENDFeatureConfig
private let spectrogram: NeMoMelSpectrogram
private let spectrogram: AudioMelSpectrogram
private var audioBuffer: [Float] = []
private var audioStartSample = 0
@@ -201,7 +201,7 @@ public final class LSEENDStreamingFeatureExtractor {
/// - Parameters:
/// - metadata: Model metadata from which feature parameters are derived.
/// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`.
public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) {
public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) {
let featureConfig = LSEENDFeatureConfig(metadata: metadata)
config = featureConfig
self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig)
@@ -409,12 +409,6 @@ public struct OfflineDiarizerConfig: Sendable {
set { embedding.excludeOverlap = newValue }
}
@available(*, deprecated, renamed: "embeddingExcludeOverlap")
public var shouldExcludeOverlaps: Bool {
get { embeddingExcludeOverlap }
set { embeddingExcludeOverlap = newValue }
}
public var minSegmentDuration: Double {
get { embedding.minSegmentDurationSeconds }
set { embedding.minSegmentDurationSeconds = newValue }
@@ -65,7 +65,7 @@ public final class SortformerDiarizer: Diarizer {
private var _models: SortformerModels?
// Native mel spectrogram (used when useNativePreprocessing is enabled)
private let melSpectrogram = NeMoMelSpectrogram()
private let melSpectrogram = AudioMelSpectrogram()
// Audio buffering
private var audioBuffer: [Float] = []
@@ -80,17 +80,6 @@ public final class SortformerDiarizer: Diarizer {
// MARK: - Initialization
@available(
*, deprecated, renamed: "init(config:timelineConfig:)",
message: "Use `timelineConfig` instead of `postProcessingConfig`."
)
public convenience init(
config: SortformerConfig = .default,
postProcessingConfig: DiarizerTimelineConfig = .default(numSpeakers: 4, frameDurationSeconds: 0.08)
) {
self.init(config: config, timelineConfig: postProcessingConfig)
}
public init(
config: SortformerConfig = .default,
timelineConfig: DiarizerTimelineConfig = .sortformerDefault
@@ -103,12 +92,6 @@ public final class SortformerDiarizer: Diarizer {
self._timeline = DiarizerTimeline(config: timelineConfig)
}
/// Backward-compatible initializer accepting the legacy SortformerPostProcessingConfig.
@available(*, deprecated, message: "Use init(config:timelineConfig:) with DiarizerTimelineConfig instead")
public convenience init(config: SortformerConfig = .default, postProcessingConfig: SortformerPostProcessingConfig) {
self.init(config: config, timelineConfig: postProcessingConfig.toDiarizerConfig())
}
/// Initialize with CoreML models (combined pipeline mode).
///
/// - Parameters:
@@ -399,18 +382,6 @@ public final class SortformerDiarizer: Diarizer {
///
/// Convenience method that combines `addAudio()` and `process()`.
///
/// - Parameters:
/// - samples: Audio samples (16kHz mono)
/// - sourceSampleRate: Source audio sample rate
/// - Returns: New chunk results if enough audio was processed
@available(*, deprecated, renamed: "process(samples:)")
public func processSamples(
_ samples: [Float],
sourceSampleRate: Double? = nil
) throws -> DiarizerTimelineUpdate? {
return try process(samples: samples, sourceSampleRate: sourceSampleRate)
}
/// Process a chunk of audio in one call.
///
/// Convenience method that combines `addAudio()` and `process()`.
@@ -121,9 +121,10 @@ public struct SortformerConfig: Sendable {
spkcacheUpdatePeriod: 31
)
/// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2 weights
public static let `fastestV2` = SortformerConfig(
modelVariant: .fastestV2,
/// Fast config with Sortformer v2 weights (~1.04s latency, smallest context).
/// May handle high-speaker-count scenarios better than v2.1 (v2.1 can degrade when many speakers overlap).
public static let `fastV2` = SortformerConfig(
modelVariant: .fastV2,
chunkLen: 6,
chunkLeftContext: 1,
chunkRightContext: 7,
@@ -132,9 +133,10 @@ public struct SortformerConfig: Sendable {
spkcacheUpdatePeriod: 31
)
/// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2.1 weights
public static let `fastestV2_1` = SortformerConfig(
modelVariant: .fastestV2_1,
/// Fast config with Sortformer v2.1 weights (~1.04s latency, smallest context).
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
public static let `fastV2_1` = SortformerConfig(
modelVariant: .fastV2_1,
chunkLen: 6,
chunkLeftContext: 1,
chunkRightContext: 7,
@@ -143,39 +145,10 @@ public struct SortformerConfig: Sendable {
spkcacheUpdatePeriod: 31
)
/// Backwards compatible alias for NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights
@available(*, deprecated, renamed: "nvidiaHighLatencyV2_1")
public static let nvidiaHighLatency = nvidiaHighLatencyV2_1
/// NVIDIA's 30.4s latency configuration with Sortformer v2 weights
public static let nvidiaHighLatencyV2 = SortformerConfig(
modelVariant: .nvidiaHighLatencyV2,
chunkLen: 340,
chunkLeftContext: 1,
chunkRightContext: 40,
fifoLen: 40,
spkcacheLen: 188,
spkcacheUpdatePeriod: 300
)
/// NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights
public static let nvidiaHighLatencyV2_1 = SortformerConfig(
modelVariant: .nvidiaHighLatencyV2_1,
chunkLen: 340,
chunkLeftContext: 1,
chunkRightContext: 40,
fifoLen: 40,
spkcacheLen: 188,
spkcacheUpdatePeriod: 300
)
/// Backwards compatible alias for NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights
@available(*, deprecated, renamed: "nvidiaLowLatencyV2_1")
public static let nvidiaLowLatency = nvidiaLowLatencyV2_1
/// NVIDIA's 1.04s latency configuration with Sortformer v2 weights
public static let nvidiaLowLatencyV2 = SortformerConfig(
modelVariant: .nvidiaLowLatencyV2,
/// Balanced config with Sortformer v2 weights (~1.04s latency, larger FIFO for better quality).
/// 20.57% DER on AMI SDM. May handle high-speaker-count scenarios better than v2.1.
public static let balancedV2 = SortformerConfig(
modelVariant: .balancedV2,
chunkLen: 6,
chunkLeftContext: 1,
chunkRightContext: 7,
@@ -184,9 +157,11 @@ public struct SortformerConfig: Sendable {
spkcacheUpdatePeriod: 144
)
/// NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights (20.57% DER on AMI SDM)
public static let nvidiaLowLatencyV2_1 = SortformerConfig(
modelVariant: .nvidiaLowLatencyV2_1,
/// Balanced config with Sortformer v2.1 weights (~1.04s latency, larger FIFO for better quality).
/// 20.57% DER on AMI SDM.
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
public static let balancedV2_1 = SortformerConfig(
modelVariant: .balancedV2_1,
chunkLen: 6,
chunkLeftContext: 1,
chunkRightContext: 7,
@@ -195,9 +170,33 @@ public struct SortformerConfig: Sendable {
spkcacheUpdatePeriod: 144
)
/// High-context config with Sortformer v2 weights (~30.4s latency, most context window).
/// May handle high-speaker-count scenarios better than v2.1.
public static let highContextV2 = SortformerConfig(
modelVariant: .highContextV2,
chunkLen: 340,
chunkLeftContext: 1,
chunkRightContext: 40,
fifoLen: 40,
spkcacheLen: 188,
spkcacheUpdatePeriod: 300
)
/// High-context config with Sortformer v2.1 weights (~30.4s latency, most context window).
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
public static let highContextV2_1 = SortformerConfig(
modelVariant: .highContextV2_1,
chunkLen: 340,
chunkLeftContext: 1,
chunkRightContext: 40,
fifoLen: 40,
spkcacheLen: 188,
spkcacheUpdatePeriod: 300
)
/// - Warning: If you don't use one of the default configurations, you must use a local model converted with that configuration.
public init(
modelVariant: ModelVariant? = .fastestV2_1,
modelVariant: ModelVariant? = .fastV2_1,
chunkLen: Int = 6,
chunkLeftContext: Int = 1,
chunkRightContext: Int = 7,
@@ -239,131 +238,6 @@ public struct SortformerConfig: Sendable {
}
}
/// Configuration for post-processing Sortformer diarizer predictions
@available(*, deprecated, message: "Use DiarizerTimelineConfig instead", renamed: "DiarizerTimelineConfig")
public struct SortformerPostProcessingConfig {
/// Onset threshold for detecting the beginning and end of a speech
public var onsetThreshold: Float
/// Offset threshold for detecting the end of a speech
public var offsetThreshold: Float
/// Adding frames before each speech segment
public var onsetPadFrames: Int
/// Adding frames after each speech segment
public var offsetPadFrames: Int
/// Threshold for short speech segment deletion in frames
public var minFramesOn: Int
/// Threshold for small non-speech deletion in frames
public var minFramesOff: Int
/// Adding durations before each speech segment
public var onsetPadSeconds: Float {
get { Float(onsetPadFrames) * frameDurationSeconds }
set { onsetPadFrames = Int(round(newValue / frameDurationSeconds)) }
}
/// Adding durations after each speech segment
public var offsetPadSeconds: Float {
get { Float(offsetPadFrames) * frameDurationSeconds }
set { offsetPadFrames = Int(round(newValue / frameDurationSeconds)) }
}
/// Threshold for short speech segment deletion (seconds)
public var minDurationOn: Float {
get { Float(minFramesOn) * frameDurationSeconds }
set { minFramesOn = Int(round(newValue / frameDurationSeconds)) }
}
/// Threshold for small non-speech deletion (seconds)
public var minDurationOff: Float {
get { Float(minFramesOff) * frameDurationSeconds }
set { minFramesOff = Int(round(newValue / frameDurationSeconds)) }
}
/// Maximum number of predictions to retain
public var maxStoredFrames: Int? = nil
/// Number of speakers
public let numSpeakers: Int = 4
/// Number of speakers
public let frameDurationSeconds: Float = 0.08
/// Default configurations
public static var `default`: SortformerPostProcessingConfig {
SortformerPostProcessingConfig(
onsetThreshold: 0.5,
offsetThreshold: 0.5,
onsetPadFrames: 0,
offsetPadFrames: 0,
minFramesOn: 0,
minFramesOff: 0
)
}
public init(
onsetThreshold: Float = 0.5,
offsetThreshold: Float = 0.5,
onsetPadSeconds: Float = 0,
offsetPadSeconds: Float = 0,
minDurationOn: Float = 0,
minDurationOff: Float = 0,
maxStoredFrames: Int? = nil
) {
self.onsetThreshold = onsetThreshold
self.offsetThreshold = offsetThreshold
self.onsetPadFrames = Int(round(onsetPadSeconds / frameDurationSeconds))
self.offsetPadFrames = Int(round(offsetPadSeconds / frameDurationSeconds))
self.minFramesOn = Int(round(minDurationOn / frameDurationSeconds))
self.minFramesOff = Int(round(minDurationOff / frameDurationSeconds))
self.maxStoredFrames = maxStoredFrames
}
public init(
onsetThreshold: Float = 0.5,
offsetThreshold: Float = 0.5,
onsetPadFrames: Int = 0,
offsetPadFrames: Int = 0,
minFramesOn: Int = 0,
minFramesOff: Int = 0,
maxStoredFrames: Int? = nil
) {
self.onsetThreshold = onsetThreshold
self.offsetThreshold = offsetThreshold
self.onsetPadFrames = onsetPadFrames
self.offsetPadFrames = offsetPadFrames
self.minFramesOn = minFramesOn
self.minFramesOff = minFramesOff
self.maxStoredFrames = maxStoredFrames
}
/// Convert to the unified DiarizerPostProcessingConfig
public func toDiarizerConfig() -> DiarizerTimelineConfig {
DiarizerTimelineConfig(
numSpeakers: numSpeakers,
frameDurationSeconds: frameDurationSeconds,
onsetThreshold: onsetThreshold,
offsetThreshold: offsetThreshold,
onsetPadFrames: onsetPadFrames,
offsetPadFrames: offsetPadFrames,
minFramesOn: minFramesOn,
minFramesOff: minFramesOff,
maxStoredFrames: maxStoredFrames
)
}
}
/// Backward-compatible typealiases SortformerTimeline is now DiarizerTimeline
@available(*, deprecated, renamed: "DiarizerTimeline")
public typealias SortformerTimeline = DiarizerTimeline
@available(*, deprecated, renamed: "DiarizerSegment")
public typealias SortformerSegment = DiarizerSegment
// MARK: - Streaming State
/// State maintained across streaming chunks for Sortformer diarization.
@@ -455,7 +329,7 @@ public struct SortformerFeatureLoader: Sendable {
self.startFeat = 0
self.endFeat = 0
(self.featSeq, self.featLength, self.featSeqLength) = NeMoMelSpectrogram().computeFlatTransposed(audio: audio)
(self.featSeq, self.featLength, self.featSeqLength) = AudioMelSpectrogram().computeFlatTransposed(audio: audio)
// numChunks accounts for right context requirement: need endFeat + rc <= featLength
// Chunk n has endFeat = (n+1) * chunkLen, so valid when (n+1) * chunkLen + rc <= featLength
// numChunks = floor((featLength - rc) / chunkLen)
@@ -522,10 +396,6 @@ public struct StreamingUpdateResult: Sendable {
}
}
/// Backward-compatible typealias SortformerChunkResult is now DiarizerChunkResult
@available(*, deprecated, renamed: "DiarizerChunkResult")
public typealias SortformerChunkResult = DiarizerChunkResult
// MARK: - Errors
public enum SortformerError: Error, LocalizedError {
+113 -54
View File
@@ -381,14 +381,15 @@ public class DownloadUtils {
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedFilePath)
let request = authorizedRequest(url: fileURL)
let tempFileURL: URL
let httpResponse: HTTPURLResponse
if let handler = progressHandler {
let baseBytes = completedBytes
let fileCount = filesToDownload.count
let totalBytesSnapshot = totalBytes
httpResponse = try await downloadWithProgress(
(tempFileURL, httpResponse) = try await downloadWithProgress(
request: request,
destinationURL: destPath,
onProgress: { bytesWritten, _ in
guard totalBytesSnapshot > 0 else { return }
let current = baseBytes + bytesWritten
@@ -406,41 +407,11 @@ public class DownloadUtils {
guard let resp = response as? HTTPURLResponse else {
throw HuggingFaceDownloadError.invalidResponse
}
tempFileURL = url
httpResponse = resp
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
message: "Rate limited while downloading \(file.path)")
}
guard (200..<300).contains(httpResponse.statusCode) else {
throw HuggingFaceDownloadError.downloadFailed(
path: file.path,
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
)
}
// Remove existing file if present (handles parallel download race conditions)
if FileManager.default.fileExists(atPath: destPath.path) {
try? FileManager.default.removeItem(at: destPath)
}
try FileManager.default.moveItem(at: url, to: destPath)
completedBytes += Int64(max(0, file.size))
if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 {
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files")
}
// Report completed-file progress
progressHandler?(
DownloadProgress(
fractionCompleted: totalBytes > 0
? 0.5 * Double(completedBytes) / Double(totalBytes)
: 0.5 * Double(index + 1) / Double(filesToDownload.count),
phase: .downloading(completedFiles: index + 1, totalFiles: filesToDownload.count)
))
continue
}
// Validate response
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
@@ -454,13 +425,18 @@ public class DownloadUtils {
)
}
// Move downloaded file to destination
if FileManager.default.fileExists(atPath: destPath.path) {
try? FileManager.default.removeItem(at: destPath)
}
try FileManager.default.moveItem(at: tempFileURL, to: destPath)
completedBytes += Int64(max(0, file.size))
if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 {
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files")
}
// Report completed-file progress
progressHandler?(
DownloadProgress(
fractionCompleted: totalBytes > 0
@@ -485,16 +461,17 @@ public class DownloadUtils {
/// Download a single file using a delegate to get byte-level progress.
///
/// This is a pure transport helper the caller is responsible for validating
/// the HTTP status and moving the temporary file to its final destination.
///
/// - Parameters:
/// - request: The URLRequest to download.
/// - onProgress: Called with `(bytesWritten, totalBytesExpected)` as data arrives.
/// - Parameter destinationURL: Final destination path for the downloaded file.
/// - Returns: The HTTP response for the completed download.
/// - onProgress: Called with `(totalBytesWritten, totalBytesExpected)` as data arrives.
/// - Returns: The temporary file URL and HTTP response.
private static func downloadWithProgress(
request: URLRequest,
destinationURL: URL,
onProgress: @escaping @Sendable (Int64, Int64) -> Void
) async throws -> HTTPURLResponse {
) async throws -> (URL, HTTPURLResponse) {
let delegate = DownloadProgressDelegate(onProgress: onProgress)
// Dedicated session with delegate one per download to avoid cross-talk.
let session = URLSession(
@@ -510,25 +487,107 @@ public class DownloadUtils {
throw HuggingFaceDownloadError.invalidResponse
}
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
message: "HTTP \(httpResponse.statusCode)"
)
return (tempURL, httpResponse)
}
/// Download a specific subdirectory from a HuggingFace repository.
///
/// Use this for optional model components that aren't part of the required model set
/// (e.g., the Mimi encoder for PocketTTS voice cloning).
///
/// - Parameters:
/// - repo: The HuggingFace repository.
/// - subdirectory: Path within the repo to download (e.g. `"mimi_encoder.mlmodelc"`).
/// - repoDirectory: Local directory corresponding to the repo root.
/// Files are saved at `repoDirectory/<remote_path>`.
public static func downloadSubdirectory(
_ repo: Repo,
subdirectory: String,
to repoDirectory: URL
) async throws {
var filesToDownload: [(path: String, size: Int)] = []
func listFiles(at path: String) async throws {
let dirURL = try ModelRegistry.apiModels(repo.remotePath, "tree/main/\(path)")
let (dirData, response) = try await fetchWithAuth(from: dirURL)
if let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 429 || httpResponse.statusCode == 503
{
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
message: "Rate limited while listing files in \(path)")
}
guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else {
return
}
for item in items {
guard let itemPath = item["path"] as? String,
let itemType = item["type"] as? String
else { continue }
if itemType == "directory" {
try await listFiles(at: itemPath)
} else if itemType == "file" {
let fileSize = item["size"] as? Int ?? -1
filesToDownload.append((path: itemPath, size: fileSize))
}
}
}
guard (200..<300).contains(httpResponse.statusCode) else {
throw HuggingFaceDownloadError.downloadFailed(
path: destinationURL.lastPathComponent,
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
try await listFiles(at: subdirectory)
logger.info("Found \(filesToDownload.count) files in \(subdirectory)")
for (index, file) in filesToDownload.enumerated() {
let destPath = repoDirectory.appendingPathComponent(file.path)
if FileManager.default.fileExists(atPath: destPath.path) {
continue
}
try FileManager.default.createDirectory(
at: destPath.deletingLastPathComponent(),
withIntermediateDirectories: true
)
if file.size == 0 {
FileManager.default.createFile(atPath: destPath.path, contents: Data())
continue
}
let encodedPath =
file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedPath)
let request = authorizedRequest(url: fileURL)
let (tempURL, response) = try await sharedSession.download(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw HuggingFaceDownloadError.invalidResponse
}
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
message: "Rate limited while downloading \(file.path)")
}
guard (200..<300).contains(httpResponse.statusCode) else {
throw HuggingFaceDownloadError.downloadFailed(
path: file.path,
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
)
}
if FileManager.default.fileExists(atPath: destPath.path) {
try? FileManager.default.removeItem(at: destPath)
}
try FileManager.default.moveItem(at: tempURL, to: destPath)
if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 {
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) \(subdirectory) files")
}
}
if FileManager.default.fileExists(atPath: destinationURL.path) {
try? FileManager.default.removeItem(at: destinationURL)
}
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
return httpResponse
logger.info("Downloaded \(subdirectory) from \(repo.folderName)")
}
/// Fetch a single file from HuggingFace with retry
+25 -25
View File
@@ -240,44 +240,44 @@ public enum ModelNames {
/// Sortformer streaming diarization model names
public enum Sortformer {
public enum Variant: CaseIterable, Sendable {
case fastestV2
case fastestV2_1
case nvidiaLowLatencyV2
case nvidiaLowLatencyV2_1
case nvidiaHighLatencyV2
case nvidiaHighLatencyV2_1
case fastV2
case fastV2_1
case balancedV2
case balancedV2_1
case highContextV2
case highContextV2_1
public var name: String {
switch self {
case .fastestV2:
case .fastV2:
return "Sortformer_v2"
case .fastestV2_1:
case .fastV2_1:
return "Sortformer_v2.1"
case .nvidiaLowLatencyV2:
case .balancedV2:
return "SortformerNvidiaLow_v2"
case .nvidiaLowLatencyV2_1:
case .balancedV2_1:
return "SortformerNvidiaLow_v2.1"
case .nvidiaHighLatencyV2:
case .highContextV2:
return "SortformerNvidiaHigh_v2"
case .nvidiaHighLatencyV2_1:
case .highContextV2_1:
return "SortformerNvidiaHigh_v2.1"
}
}
public var defaultConfiguration: SortformerConfig {
switch self {
case .fastestV2:
return .fastestV2
case .fastestV2_1:
return .fastestV2_1
case .nvidiaLowLatencyV2:
return .nvidiaLowLatencyV2
case .nvidiaLowLatencyV2_1:
return .nvidiaLowLatencyV2_1
case .nvidiaHighLatencyV2:
return .nvidiaHighLatencyV2
case .nvidiaHighLatencyV2_1:
return .nvidiaHighLatencyV2_1
case .fastV2:
return .fastV2
case .fastV2_1:
return .fastV2_1
case .balancedV2:
return .balancedV2
case .balancedV2_1:
return .balancedV2_1
case .highContextV2:
return .highContextV2
case .highContextV2_1:
return .highContextV2_1
}
}
@@ -291,7 +291,7 @@ public enum ModelNames {
}
/// Lowest latency for streaming
public static let defaultVariant: Variant = .fastestV2_1
public static let defaultVariant: Variant = .fastV2_1
/// Bundle name for a specific variant
public static func bundle(for variant: Variant) -> String {
+38 -53
View File
@@ -1,8 +1,6 @@
import Accelerate
import CoreML
import Darwin
import Foundation
import Metal
/// Shared ANE optimization utilities for all ML pipelines
public enum ANEMemoryUtils {
@@ -120,13 +118,16 @@ public enum ANEMemoryUtils {
shape: [NSNumber],
strides: [NSNumber]? = nil
) throws -> MLMultiArray {
// Validate bounds
// Validate bounds using stride-aware backing size (accounts for ANE padding)
let elementSize = getElementSize(for: array.dataType)
let totalElements = shape.map { $0.intValue }.reduce(1, *)
let bytesNeeded = totalElements * elementSize
let viewStrides = strides ?? calculateOptimalStrides(for: shape)
let viewBackingElements =
shape.isEmpty ? 0 : viewStrides[0].intValue * shape[0].intValue
let sourceBackingElements =
array.shape.isEmpty ? 0 : array.strides[0].intValue * array.shape[0].intValue
let byteOffset = offset * elementSize
guard byteOffset + bytesNeeded <= array.count * elementSize else {
guard byteOffset + viewBackingElements * elementSize <= sourceBackingElements * elementSize else {
throw ANEMemoryError.invalidShape
}
@@ -144,34 +145,47 @@ public enum ANEMemoryUtils {
/// Stride-aware copy between two MLMultiArrays that may have different stride layouts.
///
/// Copies all logical elements from `source` to `destination` (which must have the same shape).
/// The innermost dimension is copied as a contiguous block (stride-1), while outer dimensions
/// Copies all logical elements from `source` to `destination` (which must have the same shape
/// and data type). The innermost dimension must have stride 1 in both arrays. Outer dimensions
/// are iterated respecting each array's strides.
public static func strideAwareCopy(from source: MLMultiArray, to destination: MLMultiArray) {
let shape = source.shape.map { $0.intValue }
let srcStrides = source.strides.map { $0.intValue }
let dstStrides = destination.strides.map { $0.intValue }
let srcPtr = source.dataPointer.assumingMemoryBound(to: Float.self)
let dstPtr = destination.dataPointer.assumingMemoryBound(to: Float.self)
let ndim = shape.count
guard ndim > 0 else { return }
// Validate shapes and types match
precondition(
source.shape == destination.shape,
"strideAwareCopy: shape mismatch \(source.shape) vs \(destination.shape)"
)
precondition(
source.dataType == destination.dataType,
"strideAwareCopy: dataType mismatch"
)
precondition(
srcStrides[ndim - 1] == 1 && dstStrides[ndim - 1] == 1,
"strideAwareCopy: innermost stride must be 1"
)
let elementSize = getElementSize(for: source.dataType)
let srcPtr = source.dataPointer
let dstPtr = destination.dataPointer
// If strides match, a single memcpy suffices (fast path).
if srcStrides == dstStrides {
// Total backing storage = first-dim stride * first-dim size
let totalBacking = srcStrides[0] * shape[0]
memcpy(dstPtr, srcPtr, totalBacking * MemoryLayout<Float>.size)
let totalBytes = srcStrides[0] * shape[0] * elementSize
memcpy(dstPtr, srcPtr, totalBytes)
return
}
// Innermost dimension length
let innerLen = shape[ndim - 1]
// Innermost dimension byte count
let innerBytes = shape[ndim - 1] * elementSize
if ndim == 1 {
// 1-D: just copy innerLen elements (both have stride 1 for innermost)
memcpy(dstPtr, srcPtr, innerLen * MemoryLayout<Float>.size)
memcpy(dstPtr, srcPtr, innerBytes)
return
}
@@ -182,16 +196,16 @@ public enum ANEMemoryUtils {
var indices = [Int](repeating: 0, count: ndim - 1)
for _ in 0..<outerCount {
// Compute flat offset for source and destination
var srcOffset = 0
var dstOffset = 0
// Compute flat byte offset for source and destination
var srcByteOffset = 0
var dstByteOffset = 0
for d in 0..<(ndim - 1) {
srcOffset += indices[d] * srcStrides[d]
dstOffset += indices[d] * dstStrides[d]
srcByteOffset += indices[d] * srcStrides[d] * elementSize
dstByteOffset += indices[d] * dstStrides[d] * elementSize
}
// Copy innermost dimension as contiguous block
memcpy(dstPtr + dstOffset, srcPtr + srcOffset, innerLen * MemoryLayout<Float>.size)
memcpy(dstPtr + dstByteOffset, srcPtr + srcByteOffset, innerBytes)
// Increment multi-index (odometer style)
var carry = ndim - 2
@@ -206,33 +220,4 @@ public enum ANEMemoryUtils {
}
}
/// Prefetch memory pages for ANE processing
public static func prefetchForANE(_ array: MLMultiArray) {
let dataPointer = array.dataPointer
let elementSize = getElementSize(for: array.dataType)
let totalBytes = array.count * elementSize
// Touch first and last cache lines to trigger ANE DMA prefetch
if totalBytes > 0 {
_ = dataPointer.load(as: UInt8.self)
if totalBytes > 1 {
_ = dataPointer.advanced(by: totalBytes - 1).load(as: UInt8.self)
}
}
}
}
/// Extension for MLMultiArray to add ANE optimization methods
extension MLMultiArray {
/// Check if this array is ANE-aligned
public var isANEAligned: Bool {
let address = Int(bitPattern: self.dataPointer)
return address % ANEMemoryUtils.aneAlignment == 0
}
/// Prefetch this array for ANE processing
public func prefetchForANE() {
ANEMemoryUtils.prefetchForANE(self)
}
}
+15 -9
View File
@@ -193,23 +193,25 @@ final public class AudioConverter {
throw AudioConverterError.failedToCreateSourceFormat
}
// Use AVAudioConverter for channel mixing and format conversion
// Use AVAudioConverter for channel mixing (same sample rate, no resampling needed)
guard let converter = AVAudioConverter(from: format, to: monoFormat) else {
throw AudioConverterError.failedToCreateConverter
}
configure(converter: converter)
guard let outputBuffer = AVAudioPCMBuffer(pcmFormat: monoFormat, frameCapacity: buffer.frameCapacity) else {
throw AudioConverterError.failedToCreateBuffer
}
nonisolated(unsafe) var provided = false
nonisolated(unsafe) let capturedBuffer = buffer
let provided = OSAllocatedUnfairLock(initialState: false)
let inputBlock: AVAudioConverterInputBlock = { _, status in
if !provided {
provided = true
let wasProvided = provided.withLock { state -> Bool in
if state { return true }
state = true
return false
}
if !wasProvided {
status.pointee = .haveData
return capturedBuffer
return buffer
} else {
status.pointee = .endOfStream
return nil
@@ -309,8 +311,12 @@ final public class AudioConverter {
// but Swift 6 rejects mutation of captured vars in this callback.
let provided = OSAllocatedUnfairLock(initialState: false)
let inputBlock: AVAudioConverterInputBlock = { _, status in
if !provided.withLock({ $0 }) {
provided.withLock { $0 = true }
let wasProvided = provided.withLock { state -> Bool in
if state { return true }
state = true
return false
}
if !wasProvided {
status.pointee = .haveData
return buffer
} else {
@@ -15,7 +15,7 @@ import Foundation
/// - center: True with pad_mode='constant' (zero padding)
/// - normalize: "NA" (no normalization)
/// - dither: 0.0 (disabled for determinism)
public final class NeMoMelSpectrogram {
public final class AudioMelSpectrogram {
public enum PaddingMode: Sendable {
case center
case prePadded
@@ -66,94 +66,11 @@ public enum PocketTtsResourceDownloader {
/// Download the Mimi encoder model files from HuggingFace.
private static func downloadMimiEncoder(to repoDir: URL) async throws {
let modelName = ModelNames.PocketTTS.mimiEncoderFile
let modelDir = repoDir.appendingPathComponent(modelName)
try FileManager.default.createDirectory(at: modelDir, withIntermediateDirectories: true)
// List files in the mimi_encoder.mlmodelc directory
let apiPath = "tree/main/\(modelName)"
let dirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, apiPath)
let (dirData, _) = try await DownloadUtils.fetchWithAuth(from: dirURL)
guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else {
throw PocketTTSError.downloadFailed("Failed to list Mimi encoder files")
}
// Collect all files recursively
var filesToDownload: [(path: String, size: Int)] = []
func collectFiles(from items: [[String: Any]], basePath: String) async throws {
for item in items {
guard let itemPath = item["path"] as? String,
let itemType = item["type"] as? String
else { continue }
if itemType == "directory" {
let subDirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, "tree/main/\(itemPath)")
let (subDirData, _) = try await DownloadUtils.fetchWithAuth(from: subDirURL)
if let subItems = try JSONSerialization.jsonObject(with: subDirData) as? [[String: Any]] {
try await collectFiles(from: subItems, basePath: itemPath)
}
} else if itemType == "file" {
let fileSize = item["size"] as? Int ?? -1
filesToDownload.append((path: itemPath, size: fileSize))
}
}
}
try await collectFiles(from: items, basePath: modelName)
logger.info("Found \(filesToDownload.count) files in Mimi encoder")
// Download each file
for (index, file) in filesToDownload.enumerated() {
// Local path relative to modelName
let relativePath =
file.path.hasPrefix("\(modelName)/")
? String(file.path.dropFirst(modelName.count + 1))
: file.path
let destPath = modelDir.appendingPathComponent(relativePath)
if FileManager.default.fileExists(atPath: destPath.path) {
continue
}
try FileManager.default.createDirectory(
at: destPath.deletingLastPathComponent(),
withIntermediateDirectories: true
)
// Handle empty files
if file.size == 0 {
FileManager.default.createFile(atPath: destPath.path, contents: Data())
continue
}
let encodedPath = file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
let fileURL = try ModelRegistry.resolveModel(Repo.pocketTts.remotePath, encodedPath)
let (tempURL, response) = try await DownloadUtils.sharedSession.download(
for: URLRequest(url: fileURL, timeoutInterval: 1800)
)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw PocketTTSError.downloadFailed("Failed to download \(file.path)")
}
if FileManager.default.fileExists(atPath: destPath.path) {
try? FileManager.default.removeItem(at: destPath)
}
try FileManager.default.moveItem(at: tempURL, to: destPath)
if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 {
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) Mimi encoder files")
}
}
logger.info("Mimi encoder download complete")
try await DownloadUtils.downloadSubdirectory(
.pocketTts,
subdirectory: ModelNames.PocketTTS.mimiEncoderFile,
to: repoDir
)
}
/// Ensure constants (binary blobs + tokenizer) are available.
@@ -180,9 +180,7 @@ enum StreamDiarizationBenchmark {
printUsage()
return
default:
if !arguments[i].starts(with: "--") {
logger.warning("Unknown argument: \(arguments[i])")
}
logger.warning("Unknown argument: \(arguments[i])")
}
i += 1
}
@@ -0,0 +1,332 @@
#if os(macOS)
import FluidAudio
import Foundation
/// Shared utilities for diarization benchmark commands (LS-EEND and Sortformer).
enum DiarizationBenchmarkUtils {
/// Dataset corpora supported by diarization benchmarks.
enum Dataset: String {
case ami = "ami"
case voxconverse = "voxconverse"
case callhome = "callhome"
}
/// Per-meeting benchmark result shared across diarization benchmark commands.
struct BenchmarkResult {
let meetingName: String
let der: Float
let missRate: Float
let falseAlarmRate: Float
let speakerErrorRate: Float
let rtfx: Float
let processingTime: Double
let totalFrames: Int
let detectedSpeakers: Int
let groundTruthSpeakers: Int
let modelLoadTime: Double
let audioLoadTime: Double?
}
// MARK: - File Paths
static func getAMIFiles(maxFiles: Int?) -> [String] {
let allMeetings = [
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
]
var availableMeetings: [String] = []
for meeting in allMeetings {
let path = getAudioPath(for: meeting, dataset: .ami)
if FileManager.default.fileExists(atPath: path) {
availableMeetings.append(meeting)
}
}
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
).path
case .voxconverse:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
).path
case .callhome:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
).path
}
}
static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/ami_official/rttm/\(meeting).rttm"
)
case .voxconverse:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm"
)
case .callhome:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm"
)
}
}
static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let voxDir = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
)
guard
let files = try? FileManager.default.contentsOfDirectory(
at: voxDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
let rttmPath = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
)
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
guard
let files = try? FileManager.default.contentsOfDirectory(
at: callhomeDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
/// Returns files for the given dataset, filtering by availability.
static func getFiles(for dataset: Dataset, maxFiles: Int?) -> [String] {
switch dataset {
case .ami:
return getAMIFiles(maxFiles: maxFiles)
case .voxconverse:
return getVoxConverseFiles(maxFiles: maxFiles)
case .callhome:
return getCALLHOMEFiles(maxFiles: maxFiles)
}
}
// MARK: - Summary & Output
/// Prints a formatted benchmark summary table.
///
/// - Parameters:
/// - results: Benchmark results to summarize.
/// - title: Header title (e.g. "LS-EEND BENCHMARK SUMMARY").
/// - derTargets: DER percentage thresholds to check, ordered from strictest to most lenient
/// (e.g. `[15, 25]` prints "DER < 15%" if met, else "DER < 25%", else "DER > 25%").
static func printFinalSummary(
results: [BenchmarkResult],
title: String,
derTargets: [Float]
) {
guard !results.isEmpty else { return }
print("\n" + String(repeating: "=", count: 80))
print(title)
print(String(repeating: "=", count: 80))
print("Results Sorted by DER:")
print(String(repeating: "-", count: 70))
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
print(String(repeating: "-", count: 70))
for result in results.sorted(by: { $0.der < $1.der }) {
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
print(
String(
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
meetingCol,
result.der,
result.missRate,
result.falseAlarmRate,
result.speakerErrorRate,
speakerCol,
result.rtfx))
}
print(String(repeating: "-", count: 70))
let count = Float(results.count)
let avgDER = results.map { $0.der }.reduce(0, +) / count
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
print(
String(
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
print(String(repeating: "=", count: 70))
print("\nTarget Check:")
var matched = false
for target in derTargets.sorted() {
if avgDER < target {
print(" DER < \(String(format: "%.0f", target))% (achieved: \(String(format: "%.1f", avgDER))%)")
matched = true
break
}
}
if !matched, let highest = derTargets.max() {
print(
" DER > \(String(format: "%.0f", highest))% (achieved: \(String(format: "%.1f", avgDER))%)")
}
if avgRTFx > 1 {
print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
} else {
print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
}
}
static func saveJSONResults(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { resultToDict($0) }
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
print("JSON results saved to: \(path)")
} catch {
print("Failed to save JSON: \(error)")
}
}
// MARK: - Progress Save/Load
static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
var dict: [String: Any] = [
"meeting": result.meetingName,
"der": result.der,
"missRate": result.missRate,
"falseAlarmRate": result.falseAlarmRate,
"speakerErrorRate": result.speakerErrorRate,
"rtfx": result.rtfx,
"processingTime": result.processingTime,
"totalFrames": result.totalFrames,
"detectedSpeakers": result.detectedSpeakers,
"groundTruthSpeakers": result.groundTruthSpeakers,
"modelLoadTime": result.modelLoadTime,
]
if let audioLoadTime = result.audioLoadTime {
dict["audioLoadTime"] = audioLoadTime
}
return dict
}
static func saveProgress(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { resultToDict($0) }
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
} catch {
print("Failed to save progress: \(error)")
}
}
static func loadProgress(from path: String) -> [BenchmarkResult]? {
guard FileManager.default.fileExists(atPath: path) else { return nil }
do {
let data = try Data(contentsOf: URL(fileURLWithPath: path))
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
return nil
}
return jsonArray.compactMap { dict -> BenchmarkResult? in
guard let meeting = dict["meeting"] as? String,
let der = (dict["der"] as? NSNumber)?.floatValue,
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue
else {
return nil
}
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
return BenchmarkResult(
meetingName: meeting,
der: der,
missRate: missRate,
falseAlarmRate: falseAlarmRate,
speakerErrorRate: speakerErrorRate,
rtfx: rtfx,
processingTime: processingTime,
totalFrames: totalFrames,
detectedSpeakers: detectedSpeakers,
groundTruthSpeakers: groundTruthSpeakers,
modelLoadTime: modelLoadTime,
audioLoadTime: audioLoadTime
)
}
} catch {
print("Failed to load progress: \(error)")
return nil
}
}
}
#endif
@@ -1,5 +1,4 @@
#if os(macOS)
import AVFoundation
import FluidAudio
import Foundation
@@ -7,26 +6,8 @@ import Foundation
enum LSEENDBenchmark {
private static let logger = AppLogger(category: "LSEENDBench")
enum Dataset: String {
case ami = "ami"
case voxconverse = "voxconverse"
case callhome = "callhome"
}
struct BenchmarkResult {
let meetingName: String
let der: Float
let missRate: Float
let falseAlarmRate: Float
let speakerErrorRate: Float
let rtfx: Float
let processingTime: Double
let totalFrames: Int
let detectedSpeakers: Int
let groundTruthSpeakers: Int
let modelLoadTime: Double
let audioLoadTime: Double
}
typealias Dataset = DiarizationBenchmarkUtils.Dataset
typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult
static func printUsage() {
print(
@@ -197,9 +178,7 @@ enum LSEENDBenchmark {
printUsage()
return
default:
if !arguments[i].starts(with: "--") {
logger.warning("Unknown argument: \(arguments[i])")
}
logger.warning("Unknown argument: \(arguments[i])")
}
i += 1
}
@@ -228,14 +207,7 @@ enum LSEENDBenchmark {
if let meeting = singleFile {
filesToProcess = [meeting]
} else {
switch dataset {
case .ami:
filesToProcess = getAMIFiles(maxFiles: maxFiles)
case .voxconverse:
filesToProcess = getVoxConverseFiles(maxFiles: maxFiles)
case .callhome:
filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles)
}
filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles)
}
if filesToProcess.isEmpty {
@@ -252,7 +224,7 @@ enum LSEENDBenchmark {
var completedResults: [BenchmarkResult] = []
var completedMeetings: Set<String> = []
if resumeFromProgress {
if let loaded = loadProgress(from: progressFile) {
if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) {
completedResults = loaded
completedMeetings = Set(loaded.map { $0.meetingName })
print("Resuming: loaded \(completedResults.count) previous results")
@@ -339,7 +311,7 @@ enum LSEENDBenchmark {
print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth")
// Save progress after each file
saveProgress(results: allResults, to: progressFile)
DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile)
print("Progress saved (\(allResults.count) files complete)")
}
fflush(stdout)
@@ -349,11 +321,15 @@ enum LSEENDBenchmark {
}
// Print final summary
printFinalSummary(results: allResults)
DiarizationBenchmarkUtils.printFinalSummary(
results: allResults,
title: "LS-EEND BENCHMARK SUMMARY",
derTargets: [15, 25]
)
// Save results
if let outputPath = outputFile {
saveJSONResults(results: allResults, to: outputPath)
DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath)
}
}
@@ -369,7 +345,7 @@ enum LSEENDBenchmark {
numSpeakers: Int,
verbose: Bool
) async -> BenchmarkResult? {
let audioPath = getAudioPath(for: meetingName, dataset: dataset)
let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset)
guard FileManager.default.fileExists(atPath: audioPath) else {
print("Audio file not found: \(audioPath)")
fflush(stdout)
@@ -378,10 +354,7 @@ enum LSEENDBenchmark {
do {
// Load and process audio
let startLoadTime = Date()
let audioURL = URL(fileURLWithPath: audioPath)
let audioLoadTime = Date().timeIntervalSince(startLoadTime)
let startTime = Date()
let timeline = try diarizer.processComplete(audioFileURL: audioURL)
let processingTime = Date().timeIntervalSince(startTime)
@@ -400,7 +373,7 @@ enum LSEENDBenchmark {
let rttmEntries: [LSEENDRTTMEntry]
let rttmSpeakers: [String]
let rttmURL = getRTTMURL(for: meetingName, dataset: dataset)
let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset)
if let rttmURL = rttmURL, FileManager.default.fileExists(atPath: rttmURL.path) {
let parsed = try LSEENDEvaluation.parseRTTM(url: rttmURL)
rttmEntries = parsed.entries
@@ -413,7 +386,7 @@ enum LSEENDBenchmark {
duration: duration
)
guard !groundTruth.isEmpty else {
print("⚠️ No ground truth found for \(meetingName)")
print("No ground truth found for \(meetingName)")
return nil
}
// Convert TimedSpeakerSegment to LSEENDRTTMEntry
@@ -507,7 +480,7 @@ enum LSEENDBenchmark {
detectedSpeakers: detectedSpeakerIndices.count,
groundTruthSpeakers: rttmSpeakers.count,
modelLoadTime: modelLoadTime,
audioLoadTime: audioLoadTime
audioLoadTime: nil
)
} catch {
@@ -516,279 +489,5 @@ enum LSEENDBenchmark {
}
}
// MARK: - File Paths
private static func getAMIFiles(maxFiles: Int?) -> [String] {
let allMeetings = [
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
]
var availableMeetings: [String] = []
for meeting in allMeetings {
let path = getAudioPath(for: meeting, dataset: .ami)
if FileManager.default.fileExists(atPath: path) {
availableMeetings.append(meeting)
}
}
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
private static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
).path
case .voxconverse:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
).path
case .callhome:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
).path
}
}
private static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
// Try local RTTM first, then fall back to dataset directory
let localPath = "Streaming-Sortformer-Conversion/\(meeting).rttm"
if FileManager.default.fileExists(atPath: localPath) {
return URL(fileURLWithPath: localPath)
}
let datasetPath = homeDir.appendingPathComponent(
"FluidAudioDatasets/ami_official/rttm/\(meeting).rttm"
)
return datasetPath
case .voxconverse:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm"
)
case .callhome:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm"
)
}
}
private static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let voxDir = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
)
guard
let files = try? FileManager.default.contentsOfDirectory(
at: voxDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
let rttmPath = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
)
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
guard
let files = try? FileManager.default.contentsOfDirectory(
at: callhomeDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
// MARK: - Summary & Output
private static func printFinalSummary(results: [BenchmarkResult]) {
guard !results.isEmpty else { return }
print("\n" + String(repeating: "=", count: 80))
print("LS-EEND BENCHMARK SUMMARY")
print(String(repeating: "=", count: 80))
print("Results Sorted by DER:")
print(String(repeating: "-", count: 70))
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
print(String(repeating: "-", count: 70))
for result in results.sorted(by: { $0.der < $1.der }) {
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
print(
String(
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
meetingCol,
result.der,
result.missRate,
result.falseAlarmRate,
result.speakerErrorRate,
speakerCol,
result.rtfx))
}
print(String(repeating: "-", count: 70))
let count = Float(results.count)
let avgDER = results.map { $0.der }.reduce(0, +) / count
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
print(
String(
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
print(String(repeating: "=", count: 70))
print("\nTarget Check:")
if avgDER < 15 {
print(" DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)")
} else if avgDER < 25 {
print(" DER < 25% (achieved: \(String(format: "%.1f", avgDER))%)")
} else {
print(" DER > 25% (achieved: \(String(format: "%.1f", avgDER))%)")
}
if avgRTFx > 1 {
print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
} else {
print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
}
}
private static func saveJSONResults(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { resultToDict($0) }
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
print("JSON results saved to: \(path)")
} catch {
print("Failed to save JSON: \(error)")
}
}
// MARK: - Progress Save/Load
private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
return [
"meeting": result.meetingName,
"der": result.der,
"missRate": result.missRate,
"falseAlarmRate": result.falseAlarmRate,
"speakerErrorRate": result.speakerErrorRate,
"rtfx": result.rtfx,
"processingTime": result.processingTime,
"totalFrames": result.totalFrames,
"detectedSpeakers": result.detectedSpeakers,
"groundTruthSpeakers": result.groundTruthSpeakers,
"modelLoadTime": result.modelLoadTime,
"audioLoadTime": result.audioLoadTime,
]
}
private static func saveProgress(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { resultToDict($0) }
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
} catch {
print("Failed to save progress: \(error)")
}
}
private static func loadProgress(from path: String) -> [BenchmarkResult]? {
guard FileManager.default.fileExists(atPath: path) else { return nil }
do {
let data = try Data(contentsOf: URL(fileURLWithPath: path))
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
return nil
}
return jsonArray.compactMap { dict -> BenchmarkResult? in
guard let meeting = dict["meeting"] as? String,
let der = (dict["der"] as? NSNumber)?.floatValue,
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue,
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
else {
return nil
}
return BenchmarkResult(
meetingName: meeting,
der: der,
missRate: missRate,
falseAlarmRate: falseAlarmRate,
speakerErrorRate: speakerErrorRate,
rtfx: rtfx,
processingTime: processingTime,
totalFrames: totalFrames,
detectedSpeakers: detectedSpeakers,
groundTruthSpeakers: groundTruthSpeakers,
modelLoadTime: modelLoadTime,
audioLoadTime: audioLoadTime
)
}
} catch {
print("Failed to load progress: \(error)")
return nil
}
}
}
#endif
@@ -1,5 +1,4 @@
#if os(macOS)
import AVFoundation
import FluidAudio
import Foundation
@@ -20,8 +19,6 @@ enum LSEENDCommand {
var outputFile: String?
var variant: LSEENDVariant = .dihard3
var threshold: Float = 0.5
var medianWidth: Int = 1
var collarSeconds: Double = 0.25
// Post-processing parameters
var onset: Float?
@@ -62,16 +59,6 @@ enum LSEENDCommand {
threshold = v
i += 1
}
case "--median-width":
if i + 1 < arguments.count, let v = Int(arguments[i + 1]) {
medianWidth = v
i += 1
}
case "--collar":
if i + 1 < arguments.count, let v = Double(arguments[i + 1]) {
collarSeconds = v
i += 1
}
case "--onset":
if i + 1 < arguments.count, let v = Float(arguments[i + 1]) {
onset = v
@@ -244,7 +231,7 @@ enum LSEENDCommand {
}
private static func printUsage() {
logger.info(
print(
"""
LS-EEND Command Usage:
@@ -253,8 +240,6 @@ enum LSEENDCommand {
Options:
--variant <name> Model variant: ami, callhome, dihard2, dihard3 (default: dihard3)
--threshold <value> Speaker activity threshold (default: 0.5)
--median-width <value> Median filter width for post-processing (default: 1)
--collar <value> Collar duration in seconds for evaluation (default: 0.25)
--onset <value> Onset threshold for speech detection (default: 0.5)
--offset <value> Offset threshold for speech detection (default: 0.5)
--pad-onset <value> Padding before speech segments in seconds
@@ -273,8 +258,7 @@ enum LSEENDCommand {
# Save results to file
fluidaudio lseend audio.wav --output results.json
"""
)
""")
}
}
#endif
@@ -85,7 +85,7 @@ public enum LSEENDEvaluation {
var speakers: [String] = []
for line in text.split(whereSeparator: \.isNewline) {
let parts = line.split(separator: " ")
guard parts.count >= 8 else { continue }
guard parts.count >= 8, parts[0] == "SPEAKER" else { continue }
let speaker = String(parts[7])
if !speakers.contains(speaker) {
speakers.append(speaker)
@@ -444,6 +444,7 @@ public enum LSEENDEvaluation {
}
private static func solveAssignmentRowsToColumns(cost: [Float], rows: Int, columns: Int) -> [Int] {
precondition(columns <= 20, "Assignment solver is O(2^columns); columns=\(columns) is too large")
let stateCount = 1 << columns
var dp = [Float](repeating: .greatestFiniteMagnitude, count: stateCount)
var parent = [Int](repeating: -1, count: stateCount)
@@ -1,5 +1,4 @@
#if os(macOS)
import AVFoundation
import FluidAudio
import Foundation
@@ -7,26 +6,8 @@ import Foundation
enum SortformerBenchmark {
private static let logger = AppLogger(category: "SortformerBench")
enum Dataset: String {
case ami = "ami"
case voxconverse = "voxconverse"
case callhome = "callhome"
}
struct BenchmarkResult {
let meetingName: String
let der: Float
let missRate: Float
let falseAlarmRate: Float
let speakerErrorRate: Float
let rtfx: Float
let processingTime: Double
let totalFrames: Int
let detectedSpeakers: Int
let groundTruthSpeakers: Int
let modelLoadTime: Double
let audioLoadTime: Double
}
typealias Dataset = DiarizationBenchmarkUtils.Dataset
typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult
static func printUsage() {
print(
@@ -42,7 +23,6 @@ enum SortformerBenchmark {
--single-file <name> Process a specific meeting (e.g., ES2004a)
--max-files <n> Maximum number of files to process
--threshold <value> Speaker activity threshold (default: 0.5)
--preprocessor <path> Path to SortformerPreprocessor.mlpackage
--model <path> Path to Sortformer.mlpackage
--nvidia-low-latency Use NVIDIA 1.04s latency config (20.57% DER target)
--nvidia-high-latency Use NVIDIA 30.4s latency config (20.57% DER target)
@@ -102,7 +82,7 @@ enum SortformerBenchmark {
if let d = Dataset(rawValue: arguments[i + 1].lowercased()) {
dataset = d
} else {
print("⚠️ Unknown dataset: \(arguments[i + 1]). Using ami.")
print("Unknown dataset: \(arguments[i + 1]). Using ami.")
}
i += 1
}
@@ -158,9 +138,7 @@ enum SortformerBenchmark {
printUsage()
return
default:
if !arguments[i].starts(with: "--") {
logger.warning("Unknown argument: \(arguments[i])")
}
logger.warning("Unknown argument: \(arguments[i])")
}
i += 1
}
@@ -170,7 +148,7 @@ enum SortformerBenchmark {
useHuggingFace = true
}
print("🚀 Starting Sortformer Benchmark")
print("Starting Sortformer Benchmark")
fflush(stdout)
print(" Dataset: \(dataset.rawValue)")
print(" Threshold: \(threshold)")
@@ -209,7 +187,7 @@ enum SortformerBenchmark {
// Download dataset if needed
if autoDownload && dataset == .ami {
print("📥 Downloading AMI dataset if needed...")
print("Downloading AMI dataset if needed...")
await DatasetDownloader.downloadAMIDataset(
variant: .sdm,
force: false,
@@ -223,23 +201,16 @@ enum SortformerBenchmark {
if let meeting = singleFile {
filesToProcess = [meeting]
} else {
switch dataset {
case .ami:
filesToProcess = getAMIFiles(maxFiles: maxFiles)
case .voxconverse:
filesToProcess = getVoxConverseFiles(maxFiles: maxFiles)
case .callhome:
filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles)
}
filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles)
}
if filesToProcess.isEmpty {
print("No files found to process")
print("No files found to process")
fflush(stdout)
return
}
print("📂 Processing \(filesToProcess.count) file(s)")
print("Processing \(filesToProcess.count) file(s)")
print(" Progress file: \(progressFile)")
fflush(stdout)
@@ -247,29 +218,29 @@ enum SortformerBenchmark {
var completedResults: [BenchmarkResult] = []
var completedMeetings: Set<String> = []
if resumeFromProgress {
if let loaded = loadProgress(from: progressFile) {
if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) {
completedResults = loaded
completedMeetings = Set(loaded.map { $0.meetingName })
print("📥 Resuming: loaded \(completedResults.count) previous results")
print("Resuming: loaded \(completedResults.count) previous results")
for result in completedResults {
print(" \(result.meetingName): \(String(format: "%.1f", result.der))% DER")
print(" \(result.meetingName): \(String(format: "%.1f", result.der))% DER")
}
} else {
print("📥 No previous progress found, starting fresh")
print("No previous progress found, starting fresh")
}
}
print("")
fflush(stdout)
// Initialize Sortformer
print("🔧 Loading Sortformer models...")
print("Loading Sortformer models...")
fflush(stdout)
let modelLoadStart = Date()
var config: SortformerConfig
if useNvidiaHighLatency {
config = SortformerConfig.nvidiaHighLatencyV2_1
config = SortformerConfig.highContextV2_1
} else if useNvidiaLowLatency {
config = SortformerConfig.nvidiaLowLatencyV2_1
config = SortformerConfig.balancedV2_1
} else {
config = SortformerConfig.default
}
@@ -287,12 +258,12 @@ enum SortformerBenchmark {
)
}
} catch {
print("Failed to initialize Sortformer: \(error)")
print("Failed to initialize Sortformer: \(error)")
return
}
let modelLoadTime = Date().timeIntervalSince(modelLoadStart)
print("Models loaded in \(String(format: "%.2f", modelLoadTime))s\n")
print("Models loaded in \(String(format: "%.2f", modelLoadTime))s\n")
fflush(stdout)
// Process each file
@@ -324,14 +295,14 @@ enum SortformerBenchmark {
allResults.append(result)
// Print summary
print("📊 Results for \(meetingName):")
print("Results for \(meetingName):")
print(" DER: \(String(format: "%.1f", result.der))%")
print(" RTFx: \(String(format: "%.1f", result.rtfx))x")
print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth")
// Save progress after each file
saveProgress(results: allResults, to: progressFile)
print("💾 Progress saved (\(allResults.count) files complete)")
DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile)
print("Progress saved (\(allResults.count) files complete)")
}
fflush(stdout)
@@ -340,11 +311,15 @@ enum SortformerBenchmark {
}
// Print final summary
printFinalSummary(results: allResults)
DiarizationBenchmarkUtils.printFinalSummary(
results: allResults,
title: "SORTFORMER BENCHMARK SUMMARY",
derTargets: [15, 20]
)
// Save results
if let outputPath = outputFile {
saveJSONResults(results: allResults, to: outputPath)
DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath)
}
}
@@ -357,9 +332,9 @@ enum SortformerBenchmark {
verbose: Bool
) async -> BenchmarkResult? {
let audioPath = getAudioPath(for: meetingName, dataset: dataset)
let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset)
guard FileManager.default.fileExists(atPath: audioPath) else {
print("Audio file not found: \(audioPath)")
print("Audio file not found: \(audioPath)")
fflush(stdout)
return nil
}
@@ -441,7 +416,7 @@ enum SortformerBenchmark {
}
guard !groundTruth.isEmpty else {
print("⚠️ No ground truth found for \(meetingName)")
print("No ground truth found for \(meetingName)")
return nil
}
@@ -453,7 +428,7 @@ enum SortformerBenchmark {
let simpleMetrics = calculateSimpleDER(
predictions: filteredPredictions,
numFrames: result.numFinalizedFrames,
numSpeakers: 4,
numSpeakers: result.config.numSpeakers,
groundTruth: groundTruth,
threshold: threshold,
frameShift: 0.08 // 80ms frames like NeMo
@@ -490,267 +465,7 @@ enum SortformerBenchmark {
)
} catch {
print("Error processing \(meetingName): \(error)")
return nil
}
}
private static func getAMIFiles(maxFiles: Int?) -> [String] {
// Official AMI SDM test set (16 meetings) - matches NeMo evaluation
let allMeetings = [
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
]
var availableMeetings: [String] = []
for meeting in allMeetings {
let path = getAudioPath(for: meeting, dataset: .ami)
if FileManager.default.fileExists(atPath: path) {
availableMeetings.append(meeting)
}
}
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
private static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
).path
case .voxconverse:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
).path
case .callhome:
return homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
).path
}
}
private static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let voxDir = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
)
guard
let files = try? FileManager.default.contentsOfDirectory(
at: voxDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
// Check that RTTM file exists
let rttmPath = homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
)
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
// Sort alphabetically for reproducibility
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
let homeDir = FileManager.default.homeDirectoryForCurrentUser
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
guard
let files = try? FileManager.default.contentsOfDirectory(
at: callhomeDir,
includingPropertiesForKeys: nil
)
else {
return []
}
var availableMeetings: [String] = []
for file in files where file.pathExtension == "wav" {
let name = file.deletingPathExtension().lastPathComponent
// Check that RTTM file exists
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
if FileManager.default.fileExists(atPath: rttmPath.path) {
availableMeetings.append(name)
}
}
// Sort alphabetically for reproducibility
availableMeetings.sort()
if let max = maxFiles {
return Array(availableMeetings.prefix(max))
}
return availableMeetings
}
private static func printFinalSummary(results: [BenchmarkResult]) {
guard !results.isEmpty else { return }
print("\n" + String(repeating: "=", count: 80))
print("SORTFORMER BENCHMARK SUMMARY")
print(String(repeating: "=", count: 80))
print("📋 Results Sorted by DER:")
print(String(repeating: "-", count: 70))
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
print(String(repeating: "-", count: 70))
for result in results.sorted(by: { $0.der < $1.der }) {
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
print(
String(
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
meetingCol,
result.der,
result.missRate,
result.falseAlarmRate,
result.speakerErrorRate,
speakerCol,
result.rtfx))
}
print(String(repeating: "-", count: 70))
let count = Float(results.count)
let avgDER = results.map { $0.der }.reduce(0, +) / count
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
print(
String(
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
print(String(repeating: "=", count: 70))
print("\n✅ Target Check:")
if avgDER < 15 {
print(" ✅ DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)")
} else if avgDER < 20 {
print(" 🟡 DER < 20% (achieved: \(String(format: "%.1f", avgDER))%)")
} else {
print(" ❌ DER > 20% (achieved: \(String(format: "%.1f", avgDER))%)")
}
if avgRTFx > 1 {
print(" ✅ RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
} else {
print(" ❌ RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
}
}
private static func saveJSONResults(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { result in
resultToDict(result)
}
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
print("💾 JSON results saved to: \(path)")
} catch {
print("❌ Failed to save JSON: \(error)")
}
}
// MARK: - Progress Save/Load
private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
return [
"meeting": result.meetingName,
"der": result.der,
"missRate": result.missRate,
"falseAlarmRate": result.falseAlarmRate,
"speakerErrorRate": result.speakerErrorRate,
"rtfx": result.rtfx,
"processingTime": result.processingTime,
"totalFrames": result.totalFrames,
"detectedSpeakers": result.detectedSpeakers,
"groundTruthSpeakers": result.groundTruthSpeakers,
"modelLoadTime": result.modelLoadTime,
"audioLoadTime": result.audioLoadTime,
]
}
private static func saveProgress(results: [BenchmarkResult], to path: String) {
let jsonData = results.map { resultToDict($0) }
do {
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
try data.write(to: URL(fileURLWithPath: path))
} catch {
print("⚠️ Failed to save progress: \(error)")
}
}
private static func loadProgress(from path: String) -> [BenchmarkResult]? {
guard FileManager.default.fileExists(atPath: path) else { return nil }
do {
let data = try Data(contentsOf: URL(fileURLWithPath: path))
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
return nil
}
return jsonArray.compactMap { dict -> BenchmarkResult? in
guard let meeting = dict["meeting"] as? String,
let der = (dict["der"] as? NSNumber)?.floatValue,
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue,
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
else {
return nil
}
return BenchmarkResult(
meetingName: meeting,
der: der,
missRate: missRate,
falseAlarmRate: falseAlarmRate,
speakerErrorRate: speakerErrorRate,
rtfx: rtfx,
processingTime: processingTime,
totalFrames: totalFrames,
detectedSpeakers: detectedSpeakers,
groundTruthSpeakers: groundTruthSpeakers,
modelLoadTime: modelLoadTime,
audioLoadTime: audioLoadTime
)
}
} catch {
print("⚠️ Failed to load progress: \(error)")
print("Error processing \(meetingName): \(error)")
return nil
}
}
@@ -760,23 +475,11 @@ enum SortformerBenchmark {
/// Load ground truth from RTTM file like Python does
/// Format: SPEAKER <meeting_id> 1 <start_time> <duration> <NA> <NA> <speaker_id> <NA> <NA>
private static func loadRTTMGroundTruth(for meetingName: String, dataset: Dataset) -> [TimedSpeakerSegment] {
// Determine RTTM path based on dataset
let rttmPath: String
let homeDir = FileManager.default.homeDirectoryForCurrentUser
switch dataset {
case .ami:
rttmPath = "Streaming-Sortformer-Conversion/\(meetingName).rttm"
case .voxconverse:
rttmPath =
homeDir.appendingPathComponent(
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meetingName).rttm"
).path
case .callhome:
rttmPath =
homeDir.appendingPathComponent(
"FluidAudioDatasets/callhome_eng/rttm/\(meetingName).rttm"
).path
guard let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset) else {
print(" [RTTM] No RTTM URL for \(meetingName)")
return []
}
let rttmPath = rttmURL.path
guard FileManager.default.fileExists(atPath: rttmPath) else {
print(" [RTTM] File not found: \(rttmPath)")
@@ -1,5 +1,4 @@
#if os(macOS)
import AVFoundation
import FluidAudio
import Foundation
@@ -131,9 +130,10 @@ enum SortformerCommand {
print(
"[DEBUG] First 10 audio samples: \((0..<min(10, audioSamples.count)).map { String(format: "%.6f", audioSamples[$0]) }.joined(separator: ", "))"
)
let debugPath = NSTemporaryDirectory() + "swift_audio_16k.bin"
let audioData = audioSamples.withUnsafeBytes { Data($0) }
try? audioData.write(to: URL(fileURLWithPath: "swift_audio_16k.bin"))
print("[DEBUG] Saved \(audioSamples.count) samples to swift_audio_16k.bin")
try? audioData.write(to: URL(fileURLWithPath: debugPath))
print("[DEBUG] Saved \(audioSamples.count) samples to \(debugPath)")
}
// Process with progress
@@ -178,12 +178,13 @@ enum SortformerCommand {
// Print speaker probabilities summary
print("\n--- Speaker Activity Summary ---")
let numSpeakers = 4
let numSpeakers = result.config.numSpeakers
var speakerActivity = [Float](repeating: 0, count: numSpeakers)
let predictions = result.finalizedPredictions
for frame in 0..<result.numFinalizedFrames {
for spk in 0..<numSpeakers {
let prob = result.finalizedPredictions[frame * numSpeakers + spk]
if prob > 0.5 {
let idx = frame * numSpeakers + spk
if idx < predictions.count, predictions[idx] > 0.5 {
speakerActivity[spk] += result.config.frameDurationSeconds
}
}
@@ -233,7 +234,7 @@ enum SortformerCommand {
}
private static func printUsage() {
logger.info(
print(
"""
Sortformer Command Usage:
@@ -259,8 +260,7 @@ enum SortformerCommand {
# Save results to file
fluidaudio sortformer audio.wav --output results.json
"""
)
""")
}
}
#endif
@@ -3,13 +3,13 @@ import XCTest
@testable import FluidAudio
final class NeMoMelSpectrogramTests: XCTestCase {
final class AudioMelSpectrogramTests: XCTestCase {
private var mel: NeMoMelSpectrogram!
private var mel: AudioMelSpectrogram!
override func setUp() {
super.setUp()
mel = NeMoMelSpectrogram()
mel = AudioMelSpectrogram()
}
override func tearDown() {
@@ -0,0 +1,136 @@
import AVFoundation
import Foundation
import XCTest
@testable import FluidAudio
/// Shared fixture infrastructure for diarization tests.
///
/// Generates a deterministic multi-segment waveform with silence gaps, writes it to a
/// temporary WAV file, and caches it for reuse across tests within the same process.
enum DiarizationTestFixtures {
static let fixtureSampleRate = 16_000
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
/// Returns a cached URL to the fixture WAV file, creating it on first access.
static func fixtureAudioFileURL() throws -> URL {
if let cached = cachedFixtureAudioURL,
FileManager.default.fileExists(atPath: cached.path)
{
return cached
}
let url = FileManager.default.temporaryDirectory
.appendingPathComponent("diarization-fixture-\(UUID().uuidString)")
.appendingPathExtension("wav")
try writeFixtureAudio(to: url)
cachedFixtureAudioURL = url
return url
}
/// Loads fixture audio resampled to the given sample rate, optionally limited to a duration.
static func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] {
let converter = AudioConverter(sampleRate: Double(sampleRate))
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
guard let limitSeconds else {
return audio
}
let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate)))
return Array(audio.prefix(sampleCount))
}
/// Loads a slice of fixture audio at the given sample rate.
static func fixtureAudio(
sampleRate: Int, startSeconds: Double, durationSeconds: Double
) throws -> [Float] {
let converter = AudioConverter(sampleRate: Double(sampleRate))
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
let startSample = min(audio.count, Int(startSeconds * Double(sampleRate)))
let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate)))
return Array(audio[startSample..<endSample])
}
/// Splits samples into chunks with rotating sizes.
static func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
var chunks: [[Float]] = []
var start = 0
var index = 0
while start < samples.count {
let size = sizes[index % sizes.count]
let stop = min(samples.count, start + size)
chunks.append(Array(samples[start..<stop]))
start = stop
index += 1
}
return chunks
}
// MARK: - Private
private static func writeFixtureAudio(to url: URL) throws {
let sampleRate = Double(fixtureSampleRate)
let samples = makeFixtureSamples(sampleRate: sampleRate)
let format = AVAudioFormat(
commonFormat: .pcmFormatFloat32,
sampleRate: sampleRate,
channels: 1,
interleaved: false
)!
guard
let buffer = AVAudioPCMBuffer(
pcmFormat: format,
frameCapacity: AVAudioFrameCount(samples.count)
)
else {
XCTFail("Failed to allocate fixture audio buffer")
return
}
buffer.frameLength = AVAudioFrameCount(samples.count)
samples.withUnsafeBufferPointer { source in
guard let destination = buffer.floatChannelData?[0] else { return }
destination.update(from: source.baseAddress!, count: samples.count)
}
let file = try AVAudioFile(
forWriting: url,
settings: format.settings,
commonFormat: .pcmFormatFloat32,
interleaved: false
)
try file.write(from: buffer)
}
private static func makeFixtureSamples(sampleRate: Double) -> [Float] {
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
(1.0, 0.20, 220),
(0.35, 0.00, 0),
(1.1, 0.32, 330),
(0.25, 0.00, 0),
(1.0, 0.28, 180),
(0.40, 0.00, 0),
(1.3, 0.36, 260),
(0.30, 0.00, 0),
(1.1, 0.24, 410),
]
var output: [Float] = []
for (duration, amplitude, frequency) in segments {
let frameCount = Int(duration * sampleRate)
guard amplitude > 0, frequency > 0 else {
output.append(contentsOf: repeatElement(0, count: frameCount))
continue
}
for frame in 0..<frameCount {
let time = Double(frame) / sampleRate
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
let carrier = sin(2.0 * Double.pi * frequency * time)
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
}
}
return output
}
}
@@ -1,4 +1,3 @@
import AVFoundation
import CoreML
import Foundation
import XCTest
@@ -11,8 +10,6 @@ final class LSEENDIntegrationTests: XCTestCase {
let meanAbs: Double
}
private static let fixtureSampleRate = 16_000
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
nonisolated(unsafe) private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:]
func testVariantRegistryResolvesAllExportedArtifacts() async throws {
@@ -39,7 +36,8 @@ final class LSEENDIntegrationTests: XCTestCase {
func testOfflineInferenceProducesConsistentShapesAcrossVariants() async throws {
for variant in LSEENDVariant.allCases {
let engine = try await makeEngine(variant: variant)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 2.0)
let samples = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, limitSeconds: 2.0)
let result = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
try assertResultInvariants(
@@ -53,8 +51,8 @@ final class LSEENDIntegrationTests: XCTestCase {
func testAudioFileInferenceMatchesInferenceOnResampledFixtureSamples() async throws {
let engine = try await makeEngine(variant: .dihard3)
let fileResult = try engine.infer(audioFileURL: try fixtureAudioFileURL())
let resampled = try fixtureAudio(sampleRate: engine.targetSampleRate)
let fileResult = try engine.infer(audioFileURL: try DiarizationTestFixtures.fixtureAudioFileURL())
let resampled = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate)
let sampleResult = try engine.infer(samples: resampled, sampleRate: engine.targetSampleRate)
assertMatrixClose(fileResult.logits, sampleResult.logits, maxAbs: 1e-6, meanAbs: 1e-7)
@@ -65,7 +63,7 @@ final class LSEENDIntegrationTests: XCTestCase {
func testStreamingSessionMatchesOfflineInferenceOnRealFixtureAudio() async throws {
let engine = try await makeEngine(variant: .dihard3)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let offline = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate)
@@ -115,7 +113,7 @@ final class LSEENDIntegrationTests: XCTestCase {
func testStreamingSimulationMatchesOfflineInferenceAndReportsMonotonicProgress() async throws {
let engine = try await makeEngine(variant: .dihard3)
let fixtureURL = try fixtureAudioFileURL()
let fixtureURL = try DiarizationTestFixtures.fixtureAudioFileURL()
let offline = try engine.infer(audioFileURL: fixtureURL)
let simulation = try engine.simulateStreaming(audioFileURL: fixtureURL, chunkSeconds: 0.37)
@@ -142,7 +140,7 @@ final class LSEENDIntegrationTests: XCTestCase {
func testDiarizerProcessCompleteMatchesEngineInference() async throws {
let engine = try await makeEngine(variant: .dihard3)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let expected = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
@@ -159,13 +157,13 @@ final class LSEENDIntegrationTests: XCTestCase {
func testDiarizerStreamingFinalizeMatchesProcessComplete() async throws {
let engine = try await makeEngine(variant: .dihard3)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
let expected = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
for chunk in chunk(samples, sizes: [701, 977, 1153]) {
for chunk in DiarizationTestFixtures.chunk(samples, sizes: [701, 977, 1153]) {
let _ = try diarizer.process(samples: chunk)
}
let _ = try diarizer.finalizeSession()
@@ -195,7 +193,7 @@ final class LSEENDIntegrationTests: XCTestCase {
func testEnrollSpeakerResetsVisibleTimelineAndAllowsStreaming() async throws {
let engine = try await makeEngine(variant: .dihard3)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
let enrollmentCount = min(samples.count / 2, engine.targetSampleRate * 2)
let enrollment = Array(samples.prefix(enrollmentCount))
let live = Array(samples.dropFirst(enrollmentCount))
@@ -212,7 +210,7 @@ final class LSEENDIntegrationTests: XCTestCase {
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
var firstUpdate: DiarizerTimelineUpdate?
for chunk in chunk(live, sizes: [977, 1231, 1607]) {
for chunk in DiarizationTestFixtures.chunk(live, sizes: [977, 1231, 1607]) {
if let update = try diarizer.process(samples: chunk) {
firstUpdate = update
break
@@ -231,7 +229,7 @@ final class LSEENDIntegrationTests: XCTestCase {
func testProcessCompleteKeepsPrimedSessionOnlyWhenRequested() async throws {
let engine = try await makeEngine(variant: .dihard3)
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
let enrollmentSampleCount = engine.targetSampleRate * 2
let enrollment = Array(samples.prefix(enrollmentSampleCount))
let complete = Array(samples.dropFirst(enrollmentSampleCount).prefix(enrollmentSampleCount))
@@ -240,13 +238,13 @@ final class LSEENDIntegrationTests: XCTestCase {
diarizer.initialize(engine: engine)
_ = try diarizer.enrollSpeaker(withSamples: enrollment, named: "Alice")
XCTAssertTrue(hasActiveSession(diarizer))
XCTAssertTrue(diarizer.hasActiveSession)
_ = try diarizer.processComplete(complete, keepingEnrolledSpeakers: true)
XCTAssertFalse(hasActiveSession(diarizer))
XCTAssertFalse(diarizer.hasActiveSession)
_ = try diarizer.processComplete(complete, keepingEnrolledSpeakers: false)
XCTAssertFalse(hasActiveSession(diarizer))
XCTAssertFalse(diarizer.hasActiveSession)
}
private func makeEngine(variant: LSEENDVariant) async throws -> LSEENDInferenceHelper {
@@ -259,115 +257,10 @@ final class LSEENDIntegrationTests: XCTestCase {
return engine
}
private func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] {
let converter = AudioConverter(sampleRate: Double(sampleRate))
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
guard let limitSeconds else {
return audio
}
let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate)))
return Array(audio.prefix(sampleCount))
}
private func fixtureAudioFileURL() throws -> URL {
if let cached = Self.cachedFixtureAudioURL,
FileManager.default.fileExists(atPath: cached.path)
{
return cached
}
let url = FileManager.default.temporaryDirectory
.appendingPathComponent("lseend-fixture-\(UUID().uuidString)")
.appendingPathExtension("wav")
try writeFixtureAudio(to: url)
Self.cachedFixtureAudioURL = url
return url
}
private func writeFixtureAudio(to url: URL) throws {
let sampleRate = Double(Self.fixtureSampleRate)
let samples = makeFixtureSamples(sampleRate: sampleRate)
let format = AVAudioFormat(
commonFormat: .pcmFormatFloat32,
sampleRate: sampleRate,
channels: 1,
interleaved: false
)!
guard
let buffer = AVAudioPCMBuffer(
pcmFormat: format,
frameCapacity: AVAudioFrameCount(samples.count)
)
else {
XCTFail("Failed to allocate fixture audio buffer")
return
}
buffer.frameLength = AVAudioFrameCount(samples.count)
samples.withUnsafeBufferPointer { source in
guard let destination = buffer.floatChannelData?[0] else { return }
destination.update(from: source.baseAddress!, count: samples.count)
}
let file = try AVAudioFile(
forWriting: url,
settings: format.settings,
commonFormat: .pcmFormatFloat32,
interleaved: false
)
try file.write(from: buffer)
}
private func makeFixtureSamples(sampleRate: Double) -> [Float] {
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
(1.0, 0.20, 220),
(0.35, 0.00, 0),
(1.1, 0.32, 330),
(0.25, 0.00, 0),
(1.0, 0.28, 180),
(0.40, 0.00, 0),
(1.3, 0.36, 260),
(0.30, 0.00, 0),
(1.1, 0.24, 410),
]
var output: [Float] = []
for (duration, amplitude, frequency) in segments {
let frameCount = Int(duration * sampleRate)
guard amplitude > 0, frequency > 0 else {
output.append(contentsOf: repeatElement(0, count: frameCount))
continue
}
for frame in 0..<frameCount {
let time = Double(frame) / sampleRate
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
let carrier = sin(2.0 * Double.pi * frequency * time)
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
}
}
return output
}
private func duration(of samples: [Float], sampleRate: Int) -> Double {
Double(samples.count) / Double(sampleRate)
}
private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
var chunks: [[Float]] = []
var start = 0
var index = 0
while start < samples.count {
let size = sizes[index % sizes.count]
let stop = min(samples.count, start + size)
chunks.append(Array(samples[start..<stop]))
start = stop
index += 1
}
return chunks
}
private func assertResultInvariants(
_ result: LSEENDInferenceResult,
engine: LSEENDInferenceHelper,
@@ -433,15 +326,4 @@ final class LSEENDIntegrationTests: XCTestCase {
)
}
private func hasActiveSession(_ diarizer: LSEENDDiarizer) -> Bool {
let mirror = Mirror(reflecting: diarizer)
guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else {
XCTFail("Expected LS-EEND diarizer to expose _session via reflection")
return false
}
let optionalMirror = Mirror(reflecting: sessionValue)
XCTAssertEqual(optionalMirror.displayStyle, .optional)
return optionalMirror.children.count == 1
}
}
@@ -0,0 +1,326 @@
import XCTest
@testable import FluidAudio
final class LSEENDMatrixTests: XCTestCase {
// MARK: - Init (validated)
func testInitWithMatchingDimensions() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
XCTAssertEqual(m.rows, 2)
XCTAssertEqual(m.columns, 3)
XCTAssertEqual(m.values, [1, 2, 3, 4, 5, 6])
}
func testInitThrowsOnCountMismatch() {
XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3])) { error in
guard case LSEENDError.invalidMatrixShape = error else {
return XCTFail("Expected invalidMatrixShape, got \(error)")
}
}
}
func testInitThrowsOnNegativeRows() {
XCTAssertThrowsError(try LSEENDMatrix(rows: -1, columns: 3, values: [])) { error in
guard case LSEENDError.invalidMatrixShape = error else {
return XCTFail("Expected invalidMatrixShape, got \(error)")
}
}
}
func testInitThrowsOnNegativeColumns() {
XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: -1, values: [])) { error in
guard case LSEENDError.invalidMatrixShape = error else {
return XCTFail("Expected invalidMatrixShape, got \(error)")
}
}
}
func testInitWithZeroDimensions() throws {
let m = try LSEENDMatrix(rows: 0, columns: 5, values: [])
XCTAssertEqual(m.rows, 0)
XCTAssertEqual(m.columns, 5)
XCTAssertTrue(m.isEmpty)
}
// MARK: - Factory Methods
func testZeros() {
let m = LSEENDMatrix.zeros(rows: 3, columns: 2)
XCTAssertEqual(m.rows, 3)
XCTAssertEqual(m.columns, 2)
XCTAssertEqual(m.values, [Float](repeating: 0, count: 6))
}
func testEmpty() {
let m = LSEENDMatrix.empty(columns: 4)
XCTAssertEqual(m.rows, 0)
XCTAssertEqual(m.columns, 4)
XCTAssertTrue(m.isEmpty)
}
// MARK: - isEmpty
func testIsEmptyZeroRows() {
XCTAssertTrue(LSEENDMatrix.empty(columns: 3).isEmpty)
}
func testIsEmptyZeroColumns() {
let m = LSEENDMatrix(validatingRows: 3, columns: 0, values: [])
XCTAssertTrue(m.isEmpty)
}
func testIsEmptyFalseForPopulatedMatrix() throws {
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [42])
XCTAssertFalse(m.isEmpty)
}
// MARK: - Subscript
func testSubscriptGet() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [10, 20, 30, 40, 50, 60])
XCTAssertEqual(m[0, 0], 10)
XCTAssertEqual(m[0, 2], 30)
XCTAssertEqual(m[1, 0], 40)
XCTAssertEqual(m[1, 2], 60)
}
func testSubscriptSet() throws {
var m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
m[1, 0] = 99
XCTAssertEqual(m[1, 0], 99)
XCTAssertEqual(m.values, [1, 2, 99, 4])
}
// MARK: - row()
func testRow() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
XCTAssertEqual(Array(m.row(0)), [1, 2])
XCTAssertEqual(Array(m.row(1)), [3, 4])
XCTAssertEqual(Array(m.row(2)), [5, 6])
}
// MARK: - prefixingColumns
func testPrefixingColumns() throws {
let m = try LSEENDMatrix(rows: 2, columns: 4, values: [1, 2, 3, 4, 5, 6, 7, 8])
let prefix = m.prefixingColumns(2)
XCTAssertEqual(prefix.rows, 2)
XCTAssertEqual(prefix.columns, 2)
XCTAssertEqual(prefix.values, [1, 2, 5, 6])
}
func testPrefixingColumnsEqualToWidth() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
let same = m.prefixingColumns(3)
XCTAssertEqual(same, m)
}
func testPrefixingColumnsGreaterThanWidth() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
let same = m.prefixingColumns(10)
XCTAssertEqual(same, m)
}
func testPrefixingColumnsZero() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
let empty = m.prefixingColumns(0)
XCTAssertTrue(empty.isEmpty)
}
func testPrefixingColumnsOnEmptyMatrix() {
let m = LSEENDMatrix.empty(columns: 4)
let result = m.prefixingColumns(2)
XCTAssertTrue(result.isEmpty)
XCTAssertEqual(result.columns, 2)
}
// MARK: - rowMajorRows
func testRowMajorRows() throws {
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
let rows = m.rowMajorRows()
XCTAssertEqual(rows, [[1, 2, 3], [4, 5, 6]])
}
func testRowMajorRowsEmpty() {
let m = LSEENDMatrix.empty(columns: 3)
XCTAssertEqual(m.rowMajorRows(), [])
}
// MARK: - appendingRows
func testAppendingRows() throws {
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let b = try LSEENDMatrix(rows: 1, columns: 2, values: [5, 6])
let result = a.appendingRows(b)
XCTAssertEqual(result.rows, 3)
XCTAssertEqual(result.columns, 2)
XCTAssertEqual(result.values, [1, 2, 3, 4, 5, 6])
}
func testAppendingRowsToEmpty() throws {
let a = LSEENDMatrix.empty(columns: 2)
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
XCTAssertEqual(a.appendingRows(b), b)
}
func testAppendingEmptyRows() throws {
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let b = LSEENDMatrix.empty(columns: 2)
XCTAssertEqual(a.appendingRows(b), a)
}
// MARK: - droppingFirstRows
func testDroppingFirstRows() throws {
let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8])
let dropped = m.droppingFirstRows(2)
XCTAssertEqual(dropped.rows, 2)
XCTAssertEqual(dropped.columns, 2)
XCTAssertEqual(dropped.values, [5, 6, 7, 8])
}
func testDroppingAllRows() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let dropped = m.droppingFirstRows(3)
XCTAssertEqual(dropped.rows, 0)
XCTAssertTrue(dropped.isEmpty)
}
func testDroppingMoreThanTotalRows() throws {
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let dropped = m.droppingFirstRows(100)
XCTAssertEqual(dropped.rows, 0)
XCTAssertTrue(dropped.isEmpty)
}
func testDroppingZeroRows() throws {
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let same = m.droppingFirstRows(0)
XCTAssertEqual(same, m)
}
func testDroppingNegativeCount() throws {
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let same = m.droppingFirstRows(-5)
XCTAssertEqual(same, m)
}
// MARK: - slicingRows
func testSlicingRows() throws {
let m = try LSEENDMatrix(rows: 5, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let slice = m.slicingRows(start: 1, end: 4)
XCTAssertEqual(slice.rows, 3)
XCTAssertEqual(slice.values, [3, 4, 5, 6, 7, 8])
}
func testSlicingRowsFullRange() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let slice = m.slicingRows(start: 0, end: 3)
XCTAssertEqual(slice, m)
}
func testSlicingRowsEmptyRange() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let slice = m.slicingRows(start: 2, end: 2)
XCTAssertTrue(slice.isEmpty)
XCTAssertEqual(slice.columns, 2)
}
func testSlicingRowsClampsOutOfBounds() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let slice = m.slicingRows(start: -5, end: 100)
XCTAssertEqual(slice, m)
}
func testSlicingRowsInvertedRange() throws {
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let slice = m.slicingRows(start: 3, end: 1)
XCTAssertTrue(slice.isEmpty)
}
// MARK: - applyingSigmoid
func testSigmoidZero() throws {
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [0])
let s = m.applyingSigmoid()
XCTAssertEqual(s[0, 0], 0.5, accuracy: 1e-6)
}
func testSigmoidLargePositive() throws {
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [20])
let s = m.applyingSigmoid()
XCTAssertEqual(s[0, 0], 1.0, accuracy: 1e-5)
}
func testSigmoidLargeNegative() throws {
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [-20])
let s = m.applyingSigmoid()
XCTAssertEqual(s[0, 0], 0.0, accuracy: 1e-5)
}
func testSigmoidPreservesShape() throws {
let m = try LSEENDMatrix(rows: 3, columns: 4, values: [Float](repeating: 0, count: 12))
let s = m.applyingSigmoid()
XCTAssertEqual(s.rows, 3)
XCTAssertEqual(s.columns, 4)
XCTAssertEqual(s.values.count, 12)
}
func testSigmoidDoesNotMutateOriginal() throws {
let m = try LSEENDMatrix(rows: 1, columns: 2, values: [0, 0])
_ = m.applyingSigmoid()
XCTAssertEqual(m.values, [0, 0])
}
func testSigmoidOnEmpty() {
let m = LSEENDMatrix.empty(columns: 3)
let s = m.applyingSigmoid()
XCTAssertTrue(s.isEmpty)
}
// MARK: - Equatable
func testEqualMatrices() throws {
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
XCTAssertEqual(a, b)
}
func testUnequalValues() throws {
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 5])
XCTAssertNotEqual(a, b)
}
// MARK: - Roundtrip: append then drop
func testAppendThenDropRecoversOriginal() throws {
let original = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
let extra = try LSEENDMatrix(rows: 2, columns: 2, values: [7, 8, 9, 10])
let combined = original.appendingRows(extra)
let recovered = combined.slicingRows(start: 0, end: 3)
XCTAssertEqual(recovered, original)
}
func testSliceThenAppendRecombines() throws {
let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8])
let head = m.slicingRows(start: 0, end: 2)
let tail = m.slicingRows(start: 2, end: 4)
let recombined = head.appendingRows(tail)
XCTAssertEqual(recombined, m)
}
func testDropThenPrefixColumnsCommutes() throws {
let m = try LSEENDMatrix(rows: 4, columns: 4, values: (0..<16).map { Float($0) })
let dropFirst = m.droppingFirstRows(2).prefixingColumns(2)
let prefixFirst = m.prefixingColumns(2).droppingFirstRows(2)
XCTAssertEqual(dropFirst, prefixFirst)
}
}
@@ -12,7 +12,7 @@ final class SortformerTests: XCTestCase {
// Create 5 seconds of deterministic random audio
let sampleRate = 16000
let audioCount = sampleRate * 5
srand48(Int(Date().timeIntervalSince1970 * 1e6))
srand48(42)
let audio = (0..<audioCount).map { _ in Float(drand48() - 0.5) }
// 1. Get chunks from Batch Feature Provider
@@ -36,7 +36,7 @@ final class SortformerTests: XCTestCase {
// 3. Compare
XCTAssertEqual(batchChunks.count, streamingChunks.count, "Chunk count mismatch")
for i in 0..<min(batchChunks.count, streamingChunks.count) {
for i in 0..<batchChunks.count {
let batch = batchChunks[i]
let stream = streamingChunks[i]
@@ -140,18 +140,15 @@ final class SortformerTests: XCTestCase {
func testBufferBounds() throws {
var config = DiarizerTimelineConfig.sortformerDefault
let numSpeakers = config.numSpeakers
config.maxStoredFrames = 50
let maxFrames = 50
config.maxStoredFrames = maxFrames
// Create timeline with maxFrames limit
let timeline = DiarizerTimeline(config: config)
// Feed 200 frames of predictions (way more than maxFrames)
let totalFrames = 200
for frameOffset in stride(from: 0, to: totalFrames, by: 10) {
var chunkPreds: [Float] = []
for _ in 0..<10 {
chunkPreds.append(contentsOf: [Float](repeating: 0.5, count: numSpeakers))
}
for frameOffset in stride(from: 0, to: 200, by: 10) {
let chunkPreds = [Float](repeating: 0.5, count: 10 * numSpeakers)
let chunk = DiarizerChunkResult(
startFrame: frameOffset,
@@ -165,16 +162,25 @@ final class SortformerTests: XCTestCase {
// Verify framePredictions is bounded to maxFrames
XCTAssertLessThanOrEqual(
timeline.finalizedPredictions.count, config.maxStoredFrames! * config.numSpeakers,
timeline.finalizedPredictions.count, maxFrames * numSpeakers,
"framePredictions should be bounded to maxFrames")
// Verify we still have some predictions (not all trimmed)
XCTAssertGreaterThan(timeline.numFinalizedFrames, 0, "Should have some predictions")
// Verify probability() returns valid values for stored frames and NaN for trimmed frames
let storedFrames = timeline.finalizedPredictions.count / numSpeakers
let firstStoredFrame = timeline.numFinalizedFrames - storedFrames
XCTAssertFalse(
timeline.probability(speaker: 0, frame: firstStoredFrame).isNaN,
"First stored frame should have a valid probability")
XCTAssertTrue(
timeline.probability(speaker: 0, frame: firstStoredFrame - 1).isNaN,
"Frame before stored range should return NaN")
}
func testSegmentExtraction() throws {
let config = SortformerConfig.default
let config = DiarizerTimelineConfig.sortformerDefault
let numSpeakers = config.numSpeakers
// Create predictions with clear speaker pattern:
@@ -198,12 +204,13 @@ final class SortformerTests: XCTestCase {
let timeline = try DiarizerTimeline(
allPredictions: predictions,
config: .sortformerDefault,
config: config,
isComplete: true
)
// Check that we have segments
XCTAssertGreaterThan(timeline.speakers.count, 0, "Should have extracted segments")
// Check that segments were actually extracted (not just that the speakers dict exists)
let totalSegments = timeline.speakers.values.reduce(0) { $0 + $1.finalizedSegmentCount }
XCTAssertGreaterThan(totalSegments, 0, "Should have extracted at least one segment")
// Verify segment speakers are valid
for speaker in timeline.speakers.values {
@@ -216,10 +223,10 @@ final class SortformerTests: XCTestCase {
}
func testReset() throws {
let config = SortformerConfig.default
let config = DiarizerTimelineConfig.sortformerDefault
let numSpeakers = config.numSpeakers
let timeline = DiarizerTimeline(config: .sortformerDefault)
let timeline = DiarizerTimeline(config: config)
// Add some data
let chunk = DiarizerChunkResult(
@@ -58,9 +58,9 @@ final class SortformerTimelineTests: XCTestCase {
XCTAssertEqual(timeline.numFinalizedFrames, 18, "3 chunks of 6 frames = 18")
}
// MARK: - Segment Generation
// MARK: - Prediction Storage
func testHighProbabilityUpdatesFramePredictions() throws {
func testHighProbabilityStoresPredictions() throws {
let timeline = DiarizerTimeline(config: .sortformerDefault)
let numSpeakers = 4
let frameCount = 12
@@ -139,7 +139,6 @@ final class SortformerTimelineTests: XCTestCase {
// MARK: - Probability Access
func testProbabilityAccess() throws {
let numSpeakers = 4
// [f0s0=0.1, f0s1=0.2, f0s2=0.3, f0s3=0.4, f1s0=0.5, ...]
let predictions: [Float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
let timeline = try DiarizerTimeline(
@@ -155,7 +154,6 @@ final class SortformerTimelineTests: XCTestCase {
timeline.probability(speaker: 0, frame: 999).isNaN,
"Out of range should return NaN"
)
_ = numSpeakers // used above
}
// MARK: - SortformerSegment
@@ -172,9 +170,11 @@ final class SortformerTimelineTests: XCTestCase {
let a = DiarizerSegment(speakerIndex: 0, startFrame: 0, endFrame: 10, frameDurationSeconds: 0.08)
let b = DiarizerSegment(speakerIndex: 0, startFrame: 5, endFrame: 15, frameDurationSeconds: 0.08)
let c = DiarizerSegment(speakerIndex: 0, startFrame: 11, endFrame: 20, frameDurationSeconds: 0.08)
let d = DiarizerSegment(speakerIndex: 0, startFrame: 10, endFrame: 20, frameDurationSeconds: 0.08)
XCTAssertTrue(a.overlaps(with: b), "Overlapping segments")
XCTAssertFalse(a.overlaps(with: c), "Non-overlapping segments")
XCTAssertFalse(a.overlaps(with: c), "Non-overlapping segments (gap of 1)")
XCTAssertTrue(a.overlaps(with: d), "Touching segments (endFrame == startFrame) count as overlapping")
}
func testSegmentAbsorb() {
@@ -60,7 +60,7 @@ final class SortformerTypesTests: XCTestCase {
func testConfigIncompatibleWithDifferentShape() {
let a = SortformerConfig.default
let b = SortformerConfig.nvidiaHighLatencyV2_1
let b = SortformerConfig.highContextV2_1
XCTAssertFalse(a.isCompatible(with: b))
}
@@ -1,4 +1,3 @@
import AVFoundation
import Foundation
import XCTest
@@ -9,8 +8,6 @@ import XCTest
/// - `SortformerDiarizer.enrollSpeaker(withAudio:named:)`
/// - `LSEENDDiarizer.enrollSpeaker(withSamples:named:)`
final class SpeakerEnrollmentTests: XCTestCase {
private static let fixtureSampleRate = 16_000
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
nonisolated(unsafe) private static var cachedLseendEngine: LSEENDInferenceHelper?
private func loadSortformerModelsForTest(config: SortformerConfig) async throws -> SortformerModels {
@@ -163,11 +160,13 @@ final class SpeakerEnrollmentTests: XCTestCase {
let diarizer = SortformerDiarizer(config: config)
let models = try await loadSortformerModelsForTest(config: config)
diarizer.initialize(models: models)
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
XCTAssertNotNil(speaker)
try XCTSkipIf(
speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
XCTAssertEqual(speaker?.name, "Alice")
XCTAssertEqual(diarizer.numFramesProcessed, 0)
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
@@ -184,14 +183,17 @@ final class SpeakerEnrollmentTests: XCTestCase {
let diarizer = SortformerDiarizer(config: config)
let models = try await loadSortformerModelsForTest(config: config)
diarizer.initialize(models: models)
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let liveAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0)
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let liveAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0)
let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
XCTAssertNotNil(speaker)
try XCTSkipIf(
speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
var update: DiarizerTimelineUpdate?
for chunk in chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) {
for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) {
diarizer.addAudio(chunk)
if let next = try diarizer.process() {
update = next
@@ -217,17 +219,21 @@ final class SpeakerEnrollmentTests: XCTestCase {
let diarizer = SortformerDiarizer(config: config)
let models = try await loadSortformerModelsForTest(config: config)
diarizer.initialize(models: models)
let speakerAAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let speakerBAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0)
let speakerAAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let speakerBAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0)
let speakerA = try diarizer.enrollSpeaker(withAudio: speakerAAudio, named: "Alice")
XCTAssertNotNil(speakerA)
try XCTSkipIf(
speakerA == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
let stateAfterA = diarizer.state
let cachedLengthAfterA = stateAfterA.spkcacheLength + stateAfterA.fifoLength
let speakerB = try diarizer.enrollSpeaker(withAudio: speakerBAudio, named: "Bob")
XCTAssertNotNil(speakerB)
try XCTSkipIf(
speakerB == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
let stateAfterB = diarizer.state
XCTAssertGreaterThanOrEqual(
@@ -250,9 +256,12 @@ final class SpeakerEnrollmentTests: XCTestCase {
let diarizer = SortformerDiarizer(config: config)
let models = try await loadSortformerModelsForTest(config: config)
diarizer.initialize(models: models)
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
let firstSpeaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
try XCTSkipIf(
firstSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
let secondSpeaker = try diarizer.enrollSpeaker(
withAudio: enrollmentAudio,
named: "Bob",
@@ -287,7 +296,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
let engine = try await loadLseendEngineForTest()
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
let enrollmentAudio = try fixtureAudio(
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
@@ -298,7 +307,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
XCTAssertEqual(diarizer.numFramesProcessed, 0)
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 })
XCTAssertTrue(hasActiveLseendSession(diarizer))
XCTAssertTrue(diarizer.hasActiveSession)
}
func testLseendEnrollSpeakerFollowedByStreamingProcessingStartsAtFrameZero() async throws {
@@ -307,14 +316,15 @@ final class SpeakerEnrollmentTests: XCTestCase {
let engine = try await loadLseendEngineForTest()
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
let enrollmentAudio = try fixtureAudio(
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let liveAudio = try fixtureAudio(sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
let liveAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
var firstUpdate: DiarizerTimelineUpdate?
for chunk in chunk(liveAudio, sizes: [977, 1231, 1607]) {
for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [977, 1231, 1607]) {
if let update = try diarizer.process(samples: chunk) {
firstUpdate = update
break
@@ -338,9 +348,9 @@ final class SpeakerEnrollmentTests: XCTestCase {
let engine = try await loadLseendEngineForTest()
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
let speakerAAudio = try fixtureAudio(
let speakerAAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let speakerBAudio = try fixtureAudio(
let speakerBAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
let speakerA = try diarizer.enrollSpeaker(withSamples: speakerAAudio, named: "Alice")
@@ -348,7 +358,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
XCTAssertEqual(diarizer.numFramesProcessed, 0)
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
XCTAssertTrue(hasActiveLseendSession(diarizer))
XCTAssertTrue(diarizer.hasActiveSession)
let expectedNames = Set([speakerA?.name, speakerB?.name].compactMap { $0 })
XCTAssertEqual(Set(namedSpeakerNames(in: diarizer.timeline)), expectedNames)
}
@@ -359,11 +369,12 @@ final class SpeakerEnrollmentTests: XCTestCase {
let engine = try await loadLseendEngineForTest()
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
diarizer.initialize(engine: engine)
let enrollmentAudio = try fixtureAudio(
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
let firstSpeaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
try XCTSkipIf(firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.")
try XCTSkipIf(
firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.")
let secondSpeaker = try diarizer.enrollSpeaker(
withSamples: enrollmentAudio,
named: "Bob",
@@ -375,95 +386,6 @@ final class SpeakerEnrollmentTests: XCTestCase {
XCTAssertEqual(namedSpeakerNames(in: diarizer.timeline), ["Alice"])
}
private func fixtureAudio(sampleRate: Int, startSeconds: Double = 0.0, durationSeconds: Double) throws -> [Float] {
let converter = AudioConverter(sampleRate: Double(sampleRate))
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
let startSample = min(audio.count, Int(startSeconds * Double(sampleRate)))
let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate)))
return Array(audio[startSample..<endSample])
}
private func fixtureAudioFileURL() throws -> URL {
if let cached = Self.cachedFixtureAudioURL,
FileManager.default.fileExists(atPath: cached.path)
{
return cached
}
let url = FileManager.default.temporaryDirectory
.appendingPathComponent("speaker-enrollment-fixture-\(UUID().uuidString)")
.appendingPathExtension("wav")
try writeFixtureAudio(to: url)
Self.cachedFixtureAudioURL = url
return url
}
private func writeFixtureAudio(to url: URL) throws {
let sampleRate = Double(Self.fixtureSampleRate)
let samples = makeFixtureSamples(sampleRate: sampleRate)
let format = AVAudioFormat(
commonFormat: .pcmFormatFloat32,
sampleRate: sampleRate,
channels: 1,
interleaved: false
)!
guard
let buffer = AVAudioPCMBuffer(
pcmFormat: format,
frameCapacity: AVAudioFrameCount(samples.count)
)
else {
XCTFail("Failed to allocate fixture audio buffer")
return
}
buffer.frameLength = AVAudioFrameCount(samples.count)
samples.withUnsafeBufferPointer { source in
guard let destination = buffer.floatChannelData?[0] else { return }
destination.update(from: source.baseAddress!, count: samples.count)
}
let file = try AVAudioFile(
forWriting: url,
settings: format.settings,
commonFormat: .pcmFormatFloat32,
interleaved: false
)
try file.write(from: buffer)
}
private func makeFixtureSamples(sampleRate: Double) -> [Float] {
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
(1.0, 0.20, 220),
(0.35, 0.00, 0),
(1.1, 0.32, 330),
(0.25, 0.00, 0),
(1.0, 0.28, 180),
(0.40, 0.00, 0),
(1.3, 0.36, 260),
(0.30, 0.00, 0),
(1.1, 0.24, 410),
]
var output: [Float] = []
for (duration, amplitude, frequency) in segments {
let frameCount = Int(duration * sampleRate)
guard amplitude > 0, frequency > 0 else {
output.append(contentsOf: repeatElement(0, count: frameCount))
continue
}
for frame in 0..<frameCount {
let time = Double(frame) / sampleRate
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
let carrier = sin(2.0 * Double.pi * frequency * time)
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
}
}
return output
}
private func namedSpeakerIndices(in timeline: DiarizerTimeline) -> [Int] {
timeline.speakers.values
.filter { $0.name != nil }
@@ -476,29 +398,4 @@ final class SpeakerEnrollmentTests: XCTestCase {
.compactMap(\.name)
.sorted()
}
private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
var chunks: [[Float]] = []
var start = 0
var index = 0
while start < samples.count {
let size = sizes[index % sizes.count]
let stop = min(samples.count, start + size)
chunks.append(Array(samples[start..<stop]))
start = stop
index += 1
}
return chunks
}
private func hasActiveLseendSession(_ diarizer: LSEENDDiarizer) -> Bool {
let mirror = Mirror(reflecting: diarizer)
guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else {
XCTFail("Expected LS-EEND diarizer to expose _session via reflection")
return false
}
let optionalMirror = Mirror(reflecting: sessionValue)
return optionalMirror.displayStyle == .optional && optionalMirror.children.count == 1
}
}