mirror of
https://github.com/FluidInference/FluidAudio.git
synced 2026-05-12 20:20:36 +00:00
fix: clean up diarization test infrastructure (#395)
## Summary - Extract shared fixture helpers into `DiarizationTestFixtures` enum, removing ~200 lines of duplicate code across `LSEENDIntegrationTests` and `SpeakerEnrollmentTests` - Replace fragile `Mirror`-based private state inspection with `internal` `hasActiveSession` property on `LSEENDDiarizerAPI` - Fix non-deterministic `srand48` seed in `SortformerTests` (use constant `42` instead of time-based seed) - Fix asymmetric skip guards in Sortformer enrollment tests (`XCTSkipIf` instead of `XCTAssertNotNil` for host-dependent segments) ## Test plan - [x] `swift build --build-tests` passes - [ ] `swift test --filter SortformerTests` passes - [ ] `swift test --filter LSEENDIntegrationTests` passes - [ ] `swift test --filter SpeakerEnrollmentTests` passes <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/fluidinference/fluidaudio/pull/395" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
This commit is contained in:
+13
-123
@@ -75,161 +75,51 @@ Use `OfflineDiarizerManager` when you need offline DER parity or want to run the
|
||||
|
||||
### Diarizer Protocol
|
||||
|
||||
`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, which provides a unified streaming and offline API.
|
||||
`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, providing a unified streaming and offline API.
|
||||
|
||||
**Protocol Properties:**
|
||||
- `isAvailable: Bool` — Whether the model is loaded and ready
|
||||
- `numFramesProcessed: Int` — Confirmed frames processed so far
|
||||
- `targetSampleRate: Int?` — Model's expected audio sample rate in Hz
|
||||
- `modelFrameHz: Double?` — Output frame rate in Hz (frames per second)
|
||||
- `numSpeakers: Int?` — Number of real speaker output tracks
|
||||
- `timeline: DiarizerTimeline` — Accumulated diarization results
|
||||
**Streaming:** `addAudio(_:sourceSampleRate:)` → `process()` → read `timeline`. Convenience `process(samples:sourceSampleRate:)` combines both steps. Returns `DiarizerTimelineUpdate?` (`nil` when not enough audio has accumulated).
|
||||
|
||||
**Streaming:**
|
||||
- `addAudio<C: Collection>(_ samples: C, sourceSampleRate: Double?) throws` — Buffer audio for processing; pass a non-nil `sourceSampleRate` to resample on the fly
|
||||
- `process() throws -> DiarizerTimelineUpdate?` — Run inference on buffered audio; returns `nil` if not enough audio has accumulated
|
||||
- `process<C: Collection>(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate?` — Convenience combining `addAudio` + `process` in one call
|
||||
**Offline:** `processComplete(_:sourceSampleRate:...)` or `processComplete(audioFileURL:...)` to process a full recording in one call.
|
||||
|
||||
**Offline:**
|
||||
- `processComplete<C: Collection>(_ samples: C, sourceSampleRate:, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Process a complete audio buffer in one call
|
||||
- `processComplete(audioFileURL: URL, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Read, resample, and process an audio file end-to-end
|
||||
**Speaker Enrollment:** `enrollSpeaker(withAudio:sourceSampleRate:named:...)` feeds known-speaker audio before streaming to label a slot.
|
||||
|
||||
**Speaker Enrollment:**
|
||||
- `enrollSpeaker<C: Collection>(withAudio samples: C, sourceSampleRate:, named:, overwritingAssignedSpeakerName:) throws -> DiarizerSpeaker?` — Feed audio of a known speaker before streaming begins; warms model state and labels that speaker's slot for subsequent `process()` calls
|
||||
|
||||
**Lifecycle:**
|
||||
- `reset()` — Clear all streaming state (session, buffers, timeline) while keeping the model loaded
|
||||
- `cleanup()` — Release all resources including the loaded model
|
||||
**Lifecycle:** `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything.
|
||||
|
||||
---
|
||||
|
||||
### DiarizerTimeline
|
||||
### DiarizerTimeline & DiarizerSpeaker
|
||||
|
||||
Holds accumulated streaming predictions and derived speaker segments. Returned by `Diarizer.timeline` and `processComplete(...)`.
|
||||
`DiarizerTimeline` accumulates per-frame speaker probabilities and derives `DiarizerSpeaker` segments. Each speaker has `finalizedSegments` (confirmed) and `tentativeSegments` (may be revised). Segments expose `startTime`, `endTime`, `duration`, and `isFinalized`.
|
||||
|
||||
**Key Properties:**
|
||||
- `config: DiarizerTimelineConfig` — Post-processing configuration used to build segments
|
||||
- `speakers: [Int: DiarizerSpeaker]` — Speaker slots keyed by output track index
|
||||
- `finalizedPredictions: [Float]` — Flat `[frames × numSpeakers]` array of finalized per-frame probabilities
|
||||
- `tentativePredictions: [Float]` — Same layout; frames still within the right-context window that may be revised
|
||||
- `numFinalizedFrames: Int` — Count of finalized frames
|
||||
- `numTentativeFrames: Int` — Count of tentative frames
|
||||
- `finalizedDuration: Float` — Duration in seconds of finalized audio
|
||||
- `hasSegments: Bool` — Whether any speaker has at least one segment
|
||||
|
||||
**Mutation:**
|
||||
- `addChunk(_ chunk: DiarizerChunkResult) throws -> DiarizerTimelineUpdate` — Append new predictions and rebuild segments; called internally by the diarizer
|
||||
- `rebuild(finalizedPredictions:tentativePredictions:keepingSpeakers:isComplete:) throws` — Replace all predictions from scratch (used by offline processing)
|
||||
- `reset(keepingSpeakers:)` / `reset(keepingSpeakersWhere:)` — Clear segments and optionally preserve named speakers or speaker metadata
|
||||
- `finalize()` — Promote all tentative segments to finalized
|
||||
|
||||
**`DiarizerTimelineConfig`** — Shared configuration used by both diarizers:
|
||||
| Parameter | Default | Description |
|
||||
|---|---|---|
|
||||
| `numSpeakers` | model-specific | Number of speaker output tracks |
|
||||
| `frameDurationSeconds` | model-specific | Duration of one output frame |
|
||||
| `onsetThreshold` | 0.5 | Probability threshold to begin a speech segment |
|
||||
| `offsetThreshold` | 0.5 | Probability threshold to end a speech segment |
|
||||
| `onsetPadFrames` | 0 | Frames prepended to each segment onset |
|
||||
| `offsetPadFrames` | 0 | Frames appended to each segment offset |
|
||||
| `minFramesOn` | 0 | Minimum segment length; shorter segments are dropped |
|
||||
| `minFramesOff` | 0 | Minimum gap; shorter silences are closed |
|
||||
| `maxStoredFrames` | nil | Rolling window cap on finalized frames (nil = unlimited) |
|
||||
|
||||
---
|
||||
|
||||
### DiarizerSpeaker
|
||||
|
||||
Represents a single speaker track within a `DiarizerTimeline`.
|
||||
|
||||
**Key Properties:**
|
||||
- `id: UUID` — Stable identity across resets
|
||||
- `index: Int` — Slot index in the diarizer output (0-based)
|
||||
- `name: String?` — Optional display name (set via enrollment or manually)
|
||||
- `finalizedSegments: [DiarizerSegment]` — Confirmed speech segments
|
||||
- `tentativeSegments: [DiarizerSegment]` — Speculative segments within the right-context window
|
||||
- `hasSegments: Bool` — Whether any finalized or tentative segments exist
|
||||
- `numSpeechFrames: Int` — Total frames spanned by all segments (finalized + tentative)
|
||||
- `speechDuration: Float` — Total speech duration in seconds
|
||||
|
||||
**`DiarizerSegment`** — A single time-range for one speaker:
|
||||
- `startFrame / endFrame: Int` — Frame indices (convert using `frameDurationSeconds`)
|
||||
- `startTime / endTime: Float` — Seconds
|
||||
- `duration: Float` — Segment length in seconds
|
||||
- `isFinalized: Bool` — Whether the segment has been confirmed
|
||||
**`DiarizerTimelineConfig`** controls post-processing (onset/offset thresholds default to 0.5, min segment/gap duration, optional rolling window cap). Both diarizers accept this at init.
|
||||
|
||||
---
|
||||
|
||||
### SortformerDiarizer
|
||||
|
||||
End-to-end streaming diarization using NVIDIA's Sortformer model. Tracks **4 fixed speaker slots**.
|
||||
Streaming diarization using NVIDIA's Sortformer. 4 fixed speaker slots, 16 kHz input, 80 ms frame duration.
|
||||
|
||||
- **Sample rate:** 16 kHz
|
||||
- **Frame duration:** 80 ms (12.5 Hz output)
|
||||
- **Streaming latency:** ~0.64 s (`default` config) or ~1.04 s (`nvidiaLowLatency` configs)
|
||||
- **Accuracy:** 31.7% DER on AMI SDM (`nvidiaHighLatencyV2_1`; other configs untested)
|
||||
|
||||
**Initialization:**
|
||||
```swift
|
||||
// Preferred: download and compile model automatically
|
||||
let diarizer = SortformerDiarizer(config: .default, timelineConfig: .sortformerDefault)
|
||||
try await diarizer.initialize(mainModelPath: modelURL)
|
||||
|
||||
// Or with pre-loaded models
|
||||
diarizer.initialize(models: sortformerModels)
|
||||
```
|
||||
|
||||
**`SortformerConfig` Presets:**
|
||||
|
||||
| Preset | Latency | Notes |
|
||||
|---|---|---|
|
||||
| `.default` / `.fastestV2_1` | 1.04 s | Fastest inference speed and update rate. |
|
||||
| `.fastestV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights. |
|
||||
| `.nvidiaLowLatencyV2_1` | 1.04 s | Uses more context than `fastest`; Improvement is minimal |
|
||||
| `.nvidiaLowLatencyV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights |
|
||||
| `.nvidiaHighLatencyV2_1` | 30.4 s | 31.7% DER on AMI SDM |
|
||||
| `.nvidiaHighLatencyV2_0` | 30.4 s | Uses NVIDIA's Sortformer v2 weights |
|
||||
|
||||
All streaming methods are defined by the `Diarizer` protocol above. Additionally:
|
||||
- `state: SortformerStreamingState` — Live speaker cache and FIFO queue state (for diagnostics)
|
||||
- `config: SortformerConfig` — The configuration this instance was created with
|
||||
**Config presets:** `.default` / `.fastV2_1` (1.04 s latency), `.balancedV2_1` (1.04 s, 20.6% DER on AMI SDM), `.highContextV2_1` (30.4 s latency). v2 variants also available.
|
||||
|
||||
---
|
||||
|
||||
### LSEENDDiarizer
|
||||
|
||||
End-to-end streaming diarization using LS-EEND (Linear Streaming End-to-End Neural Diarization). Supports a **variable number of speaker slots** depending on the model variant.
|
||||
Streaming diarization using LS-EEND. Variable speaker slots, 8 kHz input, 100 ms frame duration, 20.7% DER on AMI SDM.
|
||||
|
||||
- **Sample rate:** 8 kHz
|
||||
- **Frame duration:** 100 ms (10 Hz output)
|
||||
- **Accuracy:** 20.7% DER on AMI SDM (AMI variant)
|
||||
- **Variants:** `LSEENDVariant` (`LSEENDModelDescriptor.LSEENDVariant`)
|
||||
|
||||
**Initialization:**
|
||||
```swift
|
||||
// Auto-download from HuggingFace
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
try await diarizer.initialize(variant: .dihard3)
|
||||
|
||||
// Or with an explicit descriptor
|
||||
let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: .dihard3)
|
||||
try diarizer.initialize(descriptor: descriptor)
|
||||
```
|
||||
|
||||
**LS-EEND–Specific Properties:**
|
||||
- `computeUnits: MLComputeUnits` — CoreML compute target (`.cpuOnly` is typically fastest)
|
||||
- `streamingLatencySeconds: Double?` — Minimum audio required before first output frame
|
||||
- `decodeMaxSpeakers: Int?` — Total output slots including internal boundary tracks
|
||||
- `timelineConfig: DiarizerTimelineConfig` — Active post-processing configuration
|
||||
**Variants:** ami, callhome, dihard2, dihard3 (via `LSEENDModelDescriptor.loadFromHuggingFace(variant:)`).
|
||||
|
||||
**Additional Method:**
|
||||
- `finalizeSession() throws -> DiarizerChunkResult?` — Flush pending audio and finalize the timeline; call at end of a stream before reading the final timeline
|
||||
|
||||
**`LSEENDModelDescriptor`:**
|
||||
- `LSEENDModelDescriptor.loadFromHuggingFace(variant:cacheDirectory:computeUnits:) async throws -> LSEENDModelDescriptor` — Download and cache all model files; returns a descriptor ready for `initialize(descriptor:)`
|
||||
- `init(variant:modelURL:metadataURL:)` — Construct from local paths if already cached
|
||||
|
||||
All streaming and offline methods are defined by the `Diarizer` protocol above.
|
||||
Call `finalizeSession()` at end-of-stream to flush pending audio before reading the final timeline.
|
||||
|
||||
## Voice Activity Detection
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ let result: LSEENDInferenceResult = try engine.infer(audioFileURL: url)
|
||||
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate)
|
||||
|
||||
// Or with a caller-owned mel spectrogram (for thread isolation)
|
||||
let mel = NeMoMelSpectrogram(...)
|
||||
let mel = AudioMelSpectrogram(...)
|
||||
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate, melSpectrogram: mel)
|
||||
```
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ Audio (16kHz) → Mel Spectrogram → CoreML Model → Speaker Probabilities
|
||||
|
||||
The pipeline consists of:
|
||||
|
||||
1. **Mel Spectrogram** (`NeMoMelSpectrogram`): Converts raw audio to 128-bin mel features
|
||||
1. **Mel Spectrogram** (`AudioMelSpectrogram`): Converts raw audio to 128-bin mel features
|
||||
2. **CoreML Model** (`DiarizerInference`): Combined encoder + attention + head
|
||||
3. **Streaming State** (`SortformerStreamingState`): Maintains speaker cache and FIFO queue
|
||||
4. **Post-processing** (`SortformerTimeline`): Converts probabilities to speaker segments
|
||||
@@ -79,8 +79,8 @@ FIFO Queue Role:
|
||||
| Config | `fifoLen` | Effect |
|
||||
|--------|-----------|--------|
|
||||
| Default | 40 | Smaller memory, faster compression cycles |
|
||||
| NVIDIA Low | 188 | Larger context before compression |
|
||||
| NVIDIA High | 40 | Same as default |
|
||||
| Balanced | 188 | Larger context before compression |
|
||||
| High Context | 40 | Same as default |
|
||||
|
||||
When `fifoLen + newChunkFrames > fifoLen` capacity, frames are popped from FIFO and either:
|
||||
1. Added to speaker cache (if speaker was active)
|
||||
@@ -105,8 +105,8 @@ Chunk with Context:
|
||||
| Config | `rightContext` | Look-ahead | Latency Impact |
|
||||
|--------|----------------|------------|----------------|
|
||||
| Default | 7 | 7 × 80ms = 560ms | Low latency |
|
||||
| NVIDIA Low | 7 | 7 × 80ms = 560ms | Low latency |
|
||||
| NVIDIA High | 40 | 40 × 80ms = 3.2s | High latency, better quality |
|
||||
| Balanced | 7 | 7 × 80ms = 560ms | Low latency |
|
||||
| High Context | 40 | 40 × 80ms = 3.2s | High latency, better quality |
|
||||
|
||||
**Why Right Context Matters:**
|
||||
|
||||
@@ -142,7 +142,7 @@ Default config:
|
||||
= 13 × 8 × 0.01
|
||||
= 1.04 seconds
|
||||
|
||||
NVIDIA High Latency config:
|
||||
High Context config:
|
||||
= (340 + 40) × 8 × 160 / 16000
|
||||
= 380 × 8 × 0.01
|
||||
= 30.4 seconds
|
||||
@@ -191,11 +191,11 @@ Defines streaming parameters that must match the CoreML model's static shapes:
|
||||
// Default (~1.04s latency, lowest latency)
|
||||
SortformerConfig.default
|
||||
|
||||
// NVIDIA High Latency (30.4s latency, best quality)
|
||||
SortformerConfig.nvidiaHighLatency
|
||||
// Balanced (1.04s latency, best quality on AMI SDM)
|
||||
SortformerConfig.balancedV2_1
|
||||
|
||||
// NVIDIA Low Latency (1.04s latency)
|
||||
SortformerConfig.nvidiaLowLatency
|
||||
// High Context (30.4s latency, most context)
|
||||
SortformerConfig.highContextV2_1
|
||||
```
|
||||
|
||||
### Pipeline.swift
|
||||
@@ -375,9 +375,11 @@ public struct SortformerSegment {
|
||||
|
||||
| Config | Chunk Size | Latency | Quality |
|
||||
|--------|------------|---------|---------|
|
||||
| `default` | 6 frames | ~1.04s | Good |
|
||||
| `nvidiaLowLatency` | 6 frames | ~1.04s | Better |
|
||||
| `nvidiaHighLatency` | 340 frames | ~30.4s | Best |
|
||||
| `default` / `fastV2_1` | 6 frames | ~1.04s | Good |
|
||||
| `balancedV2_1` | 6 frames | ~1.04s | Best (20.6% DER on AMI SDM) |
|
||||
| `highContextV2_1` | 340 frames | ~30.4s | Good (31.7% DER on AMI SDM) |
|
||||
|
||||
> **Note:** v2.1 variants may degrade when many speakers are talking simultaneously. v2 variants (`fastV2`, `balancedV2`, `highContextV2`) are available as alternatives.
|
||||
|
||||
Latency is determined by:
|
||||
- `chunkLen * subsamplingFactor * melStride / sampleRate`
|
||||
@@ -396,10 +398,10 @@ This preserves the most informative historical context while bounding memory usa
|
||||
|
||||
## Post-Processing
|
||||
|
||||
`SortformerPostProcessingConfig` controls segment extraction:
|
||||
`DiarizerTimelineConfig` controls segment extraction:
|
||||
|
||||
```swift
|
||||
let config = SortformerPostProcessingConfig(
|
||||
let config = DiarizerTimelineConfig(
|
||||
onsetThreshold: 0.5, // Probability to start speech
|
||||
offsetThreshold: 0.5, // Probability to end speech
|
||||
minDurationOn: 0.25, // Min speech segment (seconds)
|
||||
@@ -414,8 +416,8 @@ Three CoreML models are available on HuggingFace:
|
||||
| Variant | File | Config |
|
||||
|---------|------|--------|
|
||||
| Default | `Sortformer.mlmodelc` | `SortformerConfig.default` |
|
||||
| NVIDIA Low | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.nvidiaLowLatency` |
|
||||
| NVIDIA High | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.nvidiaHighLatency` |
|
||||
| Balanced | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.balancedV2_1` |
|
||||
| High Context | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.highContextV2_1` |
|
||||
|
||||
**Important:** Each model has baked-in static shapes. You must use the matching configuration.
|
||||
|
||||
@@ -447,8 +449,8 @@ audioEngine.installTap { buffer in
|
||||
### Batch Processing
|
||||
|
||||
```swift
|
||||
let diarizer = SortformerDiarizer(config: .nvidiaHighLatency)
|
||||
let models = try await SortformerModels.loadFromHuggingFace(config: .nvidiaHighLatency)
|
||||
let diarizer = SortformerDiarizer(config: .highContextV2_1)
|
||||
let models = try await SortformerModels.loadFromHuggingFace(config: .highContextV2_1)
|
||||
diarizer.initialize(models: models)
|
||||
|
||||
let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_000)
|
||||
@@ -457,9 +459,9 @@ let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_0
|
||||
let fileTimeline = try diarizer.processComplete(audioFileURL: audioURL)
|
||||
|
||||
// Get segments per speaker
|
||||
for (speakerIndex, segments) in timeline.segments.enumerated() {
|
||||
for segment in segments {
|
||||
print("Speaker \(speakerIndex): \(segment.startTime)s - \(segment.endTime)s")
|
||||
for (index, speaker) in timeline.speakers {
|
||||
for segment in speaker.finalizedSegments {
|
||||
print("Speaker \(index): \(segment.startTime)s - \(segment.endTime)s")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -335,25 +335,7 @@ swift run fluidaudiocli diarization-benchmark --mode offline --auto-download \
|
||||
|
||||
### LS-EEND (LongForm Streaming End-to-End Neural Diarization)
|
||||
|
||||
State-of-the-art end-to-end streaming diarization with fully local CoreML inference. This is the best default option for online and streaming diarization when you want low-latency speaker activity updates from a single model without a separate clustering pipeline. Multiple exported variants are available for different benchmark domains, and the diarizer supports both streaming and complete-buffer processing.
|
||||
|
||||
Why use LS-EEND:
|
||||
- Robust to noise and high speaker overlap
|
||||
- Supports up to 10 speakers in a session
|
||||
- Generally achieves better benchmark results than Sortformer on CALLHOME, AMI, and DIHARD III
|
||||
- Frame-by-frame inference with 100ms frames, and 900ms of tentative frames.
|
||||
- Much more lightweight than Sortformer, and can be faster on CPU than Sortformer on ANE
|
||||
- Does not suffer from the missed-quiet-speech issues that Sortformer can show
|
||||
- End-to-end model without separate segmentation, embedding extraction, and clustering stages
|
||||
- Also works well for complete-buffer inference when you want the same model in offline mode
|
||||
|
||||
Tradeoffs:
|
||||
- Can pick up background speakers more readily than Sortformer
|
||||
- Speaker identity stability is usually somewhat weaker than Sortformer
|
||||
- Speaker indices are expected to follow chronological arrival order, but Sortformer tends to maintain that ordering more reliably
|
||||
- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline
|
||||
|
||||
See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for model loading and integration details.
|
||||
End-to-end streaming diarization with CoreML inference. Default choice for online diarization — single model, no clustering pipeline, up to 10 speakers, 100ms frame updates with 900ms tentative preview. Supports both streaming and complete-buffer processing. See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for details.
|
||||
|
||||
```swift
|
||||
import FluidAudio
|
||||
@@ -373,26 +355,9 @@ Task {
|
||||
|
||||
### Sortformer (End-to-End Neural Diarization)
|
||||
|
||||
End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). This is the secondary streaming/online diarizer behind LS-EEND. It offers nearly the same live-update workflow and API flexibility, but trades LS-EEND's stronger benchmark results and higher speaker capacity for better identity stability. No separate VAD, segmentation, or clustering needed. Limited to 4 speakers and does not remember speakers across recordings. Licensed under NVIDIA Open Model License (no restrictions).
|
||||
End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). Secondary streaming diarizer — trades LS-EEND's higher speaker capacity and benchmark results for better speaker identity stability. Limited to 4 speakers. No separate VAD, segmentation, or clustering needed. Licensed under NVIDIA Open Model License.
|
||||
|
||||
Why use Sortformer:
|
||||
- Simple streaming pipeline with low latency
|
||||
- Supports live-update flexibility like LS-EEND
|
||||
- Handles overlapping speech without an external clustering stage
|
||||
- Better speaker identity stability than LS-EEND
|
||||
- Ignores speech from background conversations
|
||||
- Good choice when the 4-speaker limit is acceptable
|
||||
|
||||
Tradeoffs:
|
||||
- Limited to 4 unique speakers
|
||||
- Updates are performed less frequently than for LS-EEND
|
||||
- Can get overwhelmed when there are too many background speakers
|
||||
- Can miss speech when the audio is too quiet
|
||||
- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline
|
||||
|
||||
Like LS-EEND, Sortformer supports ultra-low-latency inference with 0.5s updates, 0.5s preview frames, and a fixed 1.0s latency before finalized predictions. Both models emit their results into a `DiarizerTimeline`, so you do not need to manage incremental diarization state externally.
|
||||
|
||||
See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage, comparison with Pyannote, streaming config, and architecture details.
|
||||
Both LS-EEND and Sortformer emit results into a `DiarizerTimeline` with ultra-low-latency updates. See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage and comparison.
|
||||
|
||||
### Streaming/Online Speaker Diarization (Pyannote)
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ HF_TOKEN="your_token" python nemo_ami_benchmark.py --output results.json
|
||||
|
||||
### High-Latency Streaming Config
|
||||
|
||||
These settings match the Swift `SortformerConfig.nvidiaHighLatency`:
|
||||
These settings match the Swift `SortformerConfig.highContextV2_1`:
|
||||
|
||||
| Parameter | Value | Description |
|
||||
|-----------|-------|-------------|
|
||||
|
||||
@@ -172,7 +172,7 @@ public actor StreamingEouAsrManager {
|
||||
private var rnntDecoder: RnntDecoder?
|
||||
private let audioConverter = AudioConverter()
|
||||
private var tokenizer: Tokenizer?
|
||||
private let melProcessor = NeMoMelSpectrogram() // Native Swift mel spectrogram
|
||||
private let melProcessor = AudioMelSpectrogram() // Native Swift mel spectrogram
|
||||
|
||||
// Configuration - now based on chunkSize
|
||||
public let chunkSize: StreamingChunkSize
|
||||
@@ -247,7 +247,7 @@ public actor StreamingEouAsrManager {
|
||||
public func loadModels(modelDir: URL) async throws {
|
||||
logger.info("Loading CoreML models from \(modelDir.path)...")
|
||||
|
||||
// No longer loading preprocessor - using native Swift NeMoMelSpectrogram instead
|
||||
// No longer loading preprocessor - using native Swift AudioMelSpectrogram instead
|
||||
self.streamingEncoder = try await MLModel.load(
|
||||
contentsOf: modelDir.appendingPathComponent("streaming_encoder.mlmodelc"), configuration: self.configuration
|
||||
)
|
||||
@@ -450,7 +450,7 @@ public actor StreamingEouAsrManager {
|
||||
let mel = try MLMultiArray(shape: [1, 128, NSNumber(value: numFrames)], dataType: .float32)
|
||||
let melPtr = mel.dataPointer.bindMemory(to: Float.self, capacity: mel.count)
|
||||
|
||||
// NeMoMelSpectrogram returns [nMels, T] row-major (mel bin, then time)
|
||||
// AudioMelSpectrogram returns [nMels, T] row-major (mel bin, then time)
|
||||
// CoreML expects [1, 128, T] which is the same layout
|
||||
melPtr.update(from: melFlat, count: melFlat.count)
|
||||
|
||||
|
||||
@@ -648,16 +648,6 @@ public final class DiarizerTimeline {
|
||||
queue.sync { _tentativePredictions }
|
||||
}
|
||||
|
||||
/// Total number of frames (finalized + tentative)
|
||||
@available(
|
||||
*, deprecated,
|
||||
message: "`numFrames` now includes tentative frames. Use 'numFinalizedFrames' for only finalized frames.",
|
||||
renamed: "numFinalizedFrames"
|
||||
)
|
||||
public var numFrames: Int {
|
||||
queue.sync { _numFinalizedFrames + _tentativePredictions.count / config.numSpeakers }
|
||||
}
|
||||
|
||||
/// Total number of finalized frames
|
||||
public var numFinalizedFrames: Int {
|
||||
queue.sync { _numFinalizedFrames }
|
||||
@@ -677,31 +667,11 @@ public final class DiarizerTimeline {
|
||||
speakers.values.contains(where: \.hasSegments)
|
||||
}
|
||||
|
||||
/// Duration of all predictions in seconds
|
||||
@available(
|
||||
*, deprecated,
|
||||
message: "`duration` now includes tentative frames. Use 'finalizedDuration' for only finalized frames.",
|
||||
renamed: "finalizedDuration"
|
||||
)
|
||||
public var duration: Float {
|
||||
Float(numFrames) * config.frameDurationSeconds
|
||||
}
|
||||
|
||||
/// Duration of finalized predictions in seconds
|
||||
public var finalizedDuration: Float {
|
||||
Float(numFinalizedFrames) * config.frameDurationSeconds
|
||||
}
|
||||
|
||||
/// Duration of tentative predictions in seconds
|
||||
@available(
|
||||
*, deprecated,
|
||||
message: "tentativeDuration now excludes finalized frames. Use 'duration' for the full timeline duration.",
|
||||
renamed: "duration"
|
||||
)
|
||||
public var tentativeDuration: Float {
|
||||
Float(numTentativeFrames) * config.frameDurationSeconds
|
||||
}
|
||||
|
||||
private var _finalizedPredictions: [Float] = []
|
||||
private var _tentativePredictions: [Float] = []
|
||||
private var _speakers: [Int: DiarizerSpeaker] = [:]
|
||||
@@ -1106,42 +1076,6 @@ public final class DiarizerTimeline {
|
||||
}
|
||||
}
|
||||
|
||||
extension DiarizerTimeline {
|
||||
@available(*, deprecated, renamed: "finalizedPredictions")
|
||||
public var framePredictions: [Float] { finalizedPredictions }
|
||||
|
||||
@available(
|
||||
*, deprecated,
|
||||
message: "Use Timeline.speakers[index].confirmedSegments to access a speaker's confirmed segments."
|
||||
)
|
||||
public var segments: [[DiarizerSegment]] {
|
||||
queue.sync {
|
||||
var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers)
|
||||
for (index, speaker) in _speakers {
|
||||
result[index] = speaker.finalizedSegments
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
@available(
|
||||
*, deprecated,
|
||||
message: "Use Timeline.speakers[index].tentativeSegments to access a speaker's tentative segments."
|
||||
)
|
||||
public var tentativeSegments: [[DiarizerSegment]] {
|
||||
queue.sync {
|
||||
var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers)
|
||||
for (index, speaker) in _speakers {
|
||||
result[index] = speaker.tentativeSegments
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
@available(*, deprecated, renamed: "numTentativeFrames")
|
||||
public var numTentative: Int { numTentativeFrames }
|
||||
}
|
||||
|
||||
// MARK: - Timeline Update
|
||||
|
||||
public struct DiarizerTimelineUpdate: Sendable {
|
||||
|
||||
@@ -80,11 +80,18 @@ public final class LSEENDDiarizer: Diarizer {
|
||||
return _engine?.decodeMaxSpeakers
|
||||
}
|
||||
|
||||
/// Whether a streaming session is currently active.
|
||||
var hasActiveSession: Bool {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
return _session != nil
|
||||
}
|
||||
|
||||
// MARK: - Private State
|
||||
|
||||
private var _engine: LSEENDInferenceHelper?
|
||||
private var _session: LSEENDStreamingSession?
|
||||
private var _melSpectrogram: NeMoMelSpectrogram?
|
||||
private var _melSpectrogram: AudioMelSpectrogram?
|
||||
private var _timeline: DiarizerTimeline
|
||||
private var _numFramesProcessed: Int = 0
|
||||
private var _timelineConfig: DiarizerTimelineConfig
|
||||
@@ -737,8 +744,8 @@ public final class LSEENDDiarizer: Diarizer {
|
||||
}
|
||||
|
||||
/// Create a new mel spectrogram instance owned by this diarizer.
|
||||
private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> NeMoMelSpectrogram {
|
||||
NeMoMelSpectrogram(
|
||||
private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> AudioMelSpectrogram {
|
||||
AudioMelSpectrogram(
|
||||
sampleRate: featureConfig.sampleRate,
|
||||
nMels: featureConfig.nMels,
|
||||
nFFT: featureConfig.nFFT,
|
||||
|
||||
@@ -54,7 +54,7 @@ private final class LSEENDInferenceSharedResources {
|
||||
let modelFrameHz: Double
|
||||
let streamingLatencySeconds: Double
|
||||
let decodeMaxSpeakers: Int
|
||||
let melSpectrogram: NeMoMelSpectrogram
|
||||
let melSpectrogram: AudioMelSpectrogram
|
||||
let offlineFeatureExtractor: LSEENDOfflineFeatureExtractor
|
||||
|
||||
// Preallocated ANE-aligned input arrays reused across predictStep calls
|
||||
@@ -77,7 +77,7 @@ private final class LSEENDInferenceSharedResources {
|
||||
modelFrameHz = metadata.frameHz
|
||||
streamingLatencySeconds = metadata.streamingLatencySeconds
|
||||
decodeMaxSpeakers = metadata.maxNspks
|
||||
melSpectrogram = NeMoMelSpectrogram(
|
||||
melSpectrogram = AudioMelSpectrogram(
|
||||
sampleRate: featureConfig.sampleRate,
|
||||
nMels: featureConfig.nMels,
|
||||
nFFT: featureConfig.nFFT,
|
||||
@@ -150,7 +150,7 @@ public final class LSEENDInferenceHelper {
|
||||
/// Maximum number of speaker slots in the model output (including boundary tracks).
|
||||
public var decodeMaxSpeakers: Int { sharedResources.decodeMaxSpeakers }
|
||||
|
||||
fileprivate var melSpectrogram: NeMoMelSpectrogram { sharedResources.melSpectrogram }
|
||||
fileprivate var melSpectrogram: AudioMelSpectrogram { sharedResources.melSpectrogram }
|
||||
private var offlineFeatureExtractor: LSEENDOfflineFeatureExtractor { sharedResources.offlineFeatureExtractor }
|
||||
|
||||
private let lock = NSLock()
|
||||
@@ -191,8 +191,9 @@ public final class LSEENDInferenceHelper {
|
||||
/// - inputSampleRate: Must match ``targetSampleRate``.
|
||||
/// - melSpectrogram: A mel spectrogram instance owned by the caller.
|
||||
/// - Returns: A session that accepts audio via ``LSEENDStreamingSession/pushAudio(_:)``.
|
||||
public func createSession(inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram) throws -> LSEENDStreamingSession
|
||||
{
|
||||
public func createSession(
|
||||
inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram
|
||||
) throws -> LSEENDStreamingSession {
|
||||
try LSEENDStreamingSession(engine: self, inputSampleRate: inputSampleRate, melSpectrogram: melSpectrogram)
|
||||
}
|
||||
|
||||
@@ -472,7 +473,7 @@ public final class LSEENDStreamingSession {
|
||||
fileprivate var emittedFrames = 0
|
||||
|
||||
fileprivate init(
|
||||
engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram? = nil
|
||||
engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram? = nil
|
||||
) throws {
|
||||
guard inputSampleRate == engine.targetSampleRate else {
|
||||
throw LSEENDError.unsupportedAudio(
|
||||
|
||||
@@ -45,8 +45,8 @@ public struct LSEENDFeatureConfig: Sendable, Hashable {
|
||||
}
|
||||
}
|
||||
|
||||
private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpectrogram {
|
||||
NeMoMelSpectrogram(
|
||||
private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> AudioMelSpectrogram {
|
||||
AudioMelSpectrogram(
|
||||
sampleRate: config.sampleRate,
|
||||
nMels: config.nMels,
|
||||
nFFT: config.nFFT,
|
||||
@@ -70,14 +70,14 @@ private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpe
|
||||
/// For incremental processing, use ``LSEENDStreamingFeatureExtractor`` instead.
|
||||
public final class LSEENDOfflineFeatureExtractor {
|
||||
private let config: LSEENDFeatureConfig
|
||||
private let spectrogram: NeMoMelSpectrogram
|
||||
private let spectrogram: AudioMelSpectrogram
|
||||
|
||||
/// Creates an offline feature extractor.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - metadata: Model metadata from which feature parameters are derived.
|
||||
/// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`.
|
||||
public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) {
|
||||
public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) {
|
||||
let featureConfig = LSEENDFeatureConfig(metadata: metadata)
|
||||
config = featureConfig
|
||||
self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig)
|
||||
@@ -182,7 +182,7 @@ public final class LSEENDOfflineFeatureExtractor {
|
||||
/// - Important: This class is **not** thread-safe. All calls must be serialized externally.
|
||||
public final class LSEENDStreamingFeatureExtractor {
|
||||
private let config: LSEENDFeatureConfig
|
||||
private let spectrogram: NeMoMelSpectrogram
|
||||
private let spectrogram: AudioMelSpectrogram
|
||||
|
||||
private var audioBuffer: [Float] = []
|
||||
private var audioStartSample = 0
|
||||
@@ -201,7 +201,7 @@ public final class LSEENDStreamingFeatureExtractor {
|
||||
/// - Parameters:
|
||||
/// - metadata: Model metadata from which feature parameters are derived.
|
||||
/// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`.
|
||||
public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) {
|
||||
public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) {
|
||||
let featureConfig = LSEENDFeatureConfig(metadata: metadata)
|
||||
config = featureConfig
|
||||
self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig)
|
||||
|
||||
@@ -409,12 +409,6 @@ public struct OfflineDiarizerConfig: Sendable {
|
||||
set { embedding.excludeOverlap = newValue }
|
||||
}
|
||||
|
||||
@available(*, deprecated, renamed: "embeddingExcludeOverlap")
|
||||
public var shouldExcludeOverlaps: Bool {
|
||||
get { embeddingExcludeOverlap }
|
||||
set { embeddingExcludeOverlap = newValue }
|
||||
}
|
||||
|
||||
public var minSegmentDuration: Double {
|
||||
get { embedding.minSegmentDurationSeconds }
|
||||
set { embedding.minSegmentDurationSeconds = newValue }
|
||||
|
||||
@@ -65,7 +65,7 @@ public final class SortformerDiarizer: Diarizer {
|
||||
private var _models: SortformerModels?
|
||||
|
||||
// Native mel spectrogram (used when useNativePreprocessing is enabled)
|
||||
private let melSpectrogram = NeMoMelSpectrogram()
|
||||
private let melSpectrogram = AudioMelSpectrogram()
|
||||
|
||||
// Audio buffering
|
||||
private var audioBuffer: [Float] = []
|
||||
@@ -80,17 +80,6 @@ public final class SortformerDiarizer: Diarizer {
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
@available(
|
||||
*, deprecated, renamed: "init(config:timelineConfig:)",
|
||||
message: "Use `timelineConfig` instead of `postProcessingConfig`."
|
||||
)
|
||||
public convenience init(
|
||||
config: SortformerConfig = .default,
|
||||
postProcessingConfig: DiarizerTimelineConfig = .default(numSpeakers: 4, frameDurationSeconds: 0.08)
|
||||
) {
|
||||
self.init(config: config, timelineConfig: postProcessingConfig)
|
||||
}
|
||||
|
||||
public init(
|
||||
config: SortformerConfig = .default,
|
||||
timelineConfig: DiarizerTimelineConfig = .sortformerDefault
|
||||
@@ -103,12 +92,6 @@ public final class SortformerDiarizer: Diarizer {
|
||||
self._timeline = DiarizerTimeline(config: timelineConfig)
|
||||
}
|
||||
|
||||
/// Backward-compatible initializer accepting the legacy SortformerPostProcessingConfig.
|
||||
@available(*, deprecated, message: "Use init(config:timelineConfig:) with DiarizerTimelineConfig instead")
|
||||
public convenience init(config: SortformerConfig = .default, postProcessingConfig: SortformerPostProcessingConfig) {
|
||||
self.init(config: config, timelineConfig: postProcessingConfig.toDiarizerConfig())
|
||||
}
|
||||
|
||||
/// Initialize with CoreML models (combined pipeline mode).
|
||||
///
|
||||
/// - Parameters:
|
||||
@@ -399,18 +382,6 @@ public final class SortformerDiarizer: Diarizer {
|
||||
///
|
||||
/// Convenience method that combines `addAudio()` and `process()`.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - samples: Audio samples (16kHz mono)
|
||||
/// - sourceSampleRate: Source audio sample rate
|
||||
/// - Returns: New chunk results if enough audio was processed
|
||||
@available(*, deprecated, renamed: "process(samples:)")
|
||||
public func processSamples(
|
||||
_ samples: [Float],
|
||||
sourceSampleRate: Double? = nil
|
||||
) throws -> DiarizerTimelineUpdate? {
|
||||
return try process(samples: samples, sourceSampleRate: sourceSampleRate)
|
||||
}
|
||||
|
||||
/// Process a chunk of audio in one call.
|
||||
///
|
||||
/// Convenience method that combines `addAudio()` and `process()`.
|
||||
|
||||
@@ -121,9 +121,10 @@ public struct SortformerConfig: Sendable {
|
||||
spkcacheUpdatePeriod: 31
|
||||
)
|
||||
|
||||
/// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2 weights
|
||||
public static let `fastestV2` = SortformerConfig(
|
||||
modelVariant: .fastestV2,
|
||||
/// Fast config with Sortformer v2 weights (~1.04s latency, smallest context).
|
||||
/// May handle high-speaker-count scenarios better than v2.1 (v2.1 can degrade when many speakers overlap).
|
||||
public static let `fastV2` = SortformerConfig(
|
||||
modelVariant: .fastV2,
|
||||
chunkLen: 6,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 7,
|
||||
@@ -132,9 +133,10 @@ public struct SortformerConfig: Sendable {
|
||||
spkcacheUpdatePeriod: 31
|
||||
)
|
||||
|
||||
/// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2.1 weights
|
||||
public static let `fastestV2_1` = SortformerConfig(
|
||||
modelVariant: .fastestV2_1,
|
||||
/// Fast config with Sortformer v2.1 weights (~1.04s latency, smallest context).
|
||||
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
|
||||
public static let `fastV2_1` = SortformerConfig(
|
||||
modelVariant: .fastV2_1,
|
||||
chunkLen: 6,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 7,
|
||||
@@ -143,39 +145,10 @@ public struct SortformerConfig: Sendable {
|
||||
spkcacheUpdatePeriod: 31
|
||||
)
|
||||
|
||||
/// Backwards compatible alias for NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights
|
||||
@available(*, deprecated, renamed: "nvidiaHighLatencyV2_1")
|
||||
public static let nvidiaHighLatency = nvidiaHighLatencyV2_1
|
||||
|
||||
/// NVIDIA's 30.4s latency configuration with Sortformer v2 weights
|
||||
public static let nvidiaHighLatencyV2 = SortformerConfig(
|
||||
modelVariant: .nvidiaHighLatencyV2,
|
||||
chunkLen: 340,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 40,
|
||||
fifoLen: 40,
|
||||
spkcacheLen: 188,
|
||||
spkcacheUpdatePeriod: 300
|
||||
)
|
||||
|
||||
/// NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights
|
||||
public static let nvidiaHighLatencyV2_1 = SortformerConfig(
|
||||
modelVariant: .nvidiaHighLatencyV2_1,
|
||||
chunkLen: 340,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 40,
|
||||
fifoLen: 40,
|
||||
spkcacheLen: 188,
|
||||
spkcacheUpdatePeriod: 300
|
||||
)
|
||||
|
||||
/// Backwards compatible alias for NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights
|
||||
@available(*, deprecated, renamed: "nvidiaLowLatencyV2_1")
|
||||
public static let nvidiaLowLatency = nvidiaLowLatencyV2_1
|
||||
|
||||
/// NVIDIA's 1.04s latency configuration with Sortformer v2 weights
|
||||
public static let nvidiaLowLatencyV2 = SortformerConfig(
|
||||
modelVariant: .nvidiaLowLatencyV2,
|
||||
/// Balanced config with Sortformer v2 weights (~1.04s latency, larger FIFO for better quality).
|
||||
/// 20.57% DER on AMI SDM. May handle high-speaker-count scenarios better than v2.1.
|
||||
public static let balancedV2 = SortformerConfig(
|
||||
modelVariant: .balancedV2,
|
||||
chunkLen: 6,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 7,
|
||||
@@ -184,9 +157,11 @@ public struct SortformerConfig: Sendable {
|
||||
spkcacheUpdatePeriod: 144
|
||||
)
|
||||
|
||||
/// NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights (20.57% DER on AMI SDM)
|
||||
public static let nvidiaLowLatencyV2_1 = SortformerConfig(
|
||||
modelVariant: .nvidiaLowLatencyV2_1,
|
||||
/// Balanced config with Sortformer v2.1 weights (~1.04s latency, larger FIFO for better quality).
|
||||
/// 20.57% DER on AMI SDM.
|
||||
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
|
||||
public static let balancedV2_1 = SortformerConfig(
|
||||
modelVariant: .balancedV2_1,
|
||||
chunkLen: 6,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 7,
|
||||
@@ -195,9 +170,33 @@ public struct SortformerConfig: Sendable {
|
||||
spkcacheUpdatePeriod: 144
|
||||
)
|
||||
|
||||
/// High-context config with Sortformer v2 weights (~30.4s latency, most context window).
|
||||
/// May handle high-speaker-count scenarios better than v2.1.
|
||||
public static let highContextV2 = SortformerConfig(
|
||||
modelVariant: .highContextV2,
|
||||
chunkLen: 340,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 40,
|
||||
fifoLen: 40,
|
||||
spkcacheLen: 188,
|
||||
spkcacheUpdatePeriod: 300
|
||||
)
|
||||
|
||||
/// High-context config with Sortformer v2.1 weights (~30.4s latency, most context window).
|
||||
/// - Note: v2.1 may degrade when many speakers are talking simultaneously.
|
||||
public static let highContextV2_1 = SortformerConfig(
|
||||
modelVariant: .highContextV2_1,
|
||||
chunkLen: 340,
|
||||
chunkLeftContext: 1,
|
||||
chunkRightContext: 40,
|
||||
fifoLen: 40,
|
||||
spkcacheLen: 188,
|
||||
spkcacheUpdatePeriod: 300
|
||||
)
|
||||
|
||||
/// - Warning: If you don't use one of the default configurations, you must use a local model converted with that configuration.
|
||||
public init(
|
||||
modelVariant: ModelVariant? = .fastestV2_1,
|
||||
modelVariant: ModelVariant? = .fastV2_1,
|
||||
chunkLen: Int = 6,
|
||||
chunkLeftContext: Int = 1,
|
||||
chunkRightContext: Int = 7,
|
||||
@@ -239,131 +238,6 @@ public struct SortformerConfig: Sendable {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for post-processing Sortformer diarizer predictions
|
||||
@available(*, deprecated, message: "Use DiarizerTimelineConfig instead", renamed: "DiarizerTimelineConfig")
|
||||
public struct SortformerPostProcessingConfig {
|
||||
/// Onset threshold for detecting the beginning and end of a speech
|
||||
public var onsetThreshold: Float
|
||||
|
||||
/// Offset threshold for detecting the end of a speech
|
||||
public var offsetThreshold: Float
|
||||
|
||||
/// Adding frames before each speech segment
|
||||
public var onsetPadFrames: Int
|
||||
|
||||
/// Adding frames after each speech segment
|
||||
public var offsetPadFrames: Int
|
||||
|
||||
/// Threshold for short speech segment deletion in frames
|
||||
public var minFramesOn: Int
|
||||
|
||||
/// Threshold for small non-speech deletion in frames
|
||||
public var minFramesOff: Int
|
||||
|
||||
/// Adding durations before each speech segment
|
||||
public var onsetPadSeconds: Float {
|
||||
get { Float(onsetPadFrames) * frameDurationSeconds }
|
||||
set { onsetPadFrames = Int(round(newValue / frameDurationSeconds)) }
|
||||
}
|
||||
|
||||
/// Adding durations after each speech segment
|
||||
public var offsetPadSeconds: Float {
|
||||
get { Float(offsetPadFrames) * frameDurationSeconds }
|
||||
set { offsetPadFrames = Int(round(newValue / frameDurationSeconds)) }
|
||||
}
|
||||
|
||||
/// Threshold for short speech segment deletion (seconds)
|
||||
public var minDurationOn: Float {
|
||||
get { Float(minFramesOn) * frameDurationSeconds }
|
||||
set { minFramesOn = Int(round(newValue / frameDurationSeconds)) }
|
||||
}
|
||||
|
||||
/// Threshold for small non-speech deletion (seconds)
|
||||
public var minDurationOff: Float {
|
||||
get { Float(minFramesOff) * frameDurationSeconds }
|
||||
set { minFramesOff = Int(round(newValue / frameDurationSeconds)) }
|
||||
}
|
||||
|
||||
/// Maximum number of predictions to retain
|
||||
public var maxStoredFrames: Int? = nil
|
||||
|
||||
/// Number of speakers
|
||||
public let numSpeakers: Int = 4
|
||||
|
||||
/// Number of speakers
|
||||
public let frameDurationSeconds: Float = 0.08
|
||||
|
||||
/// Default configurations
|
||||
public static var `default`: SortformerPostProcessingConfig {
|
||||
SortformerPostProcessingConfig(
|
||||
onsetThreshold: 0.5,
|
||||
offsetThreshold: 0.5,
|
||||
onsetPadFrames: 0,
|
||||
offsetPadFrames: 0,
|
||||
minFramesOn: 0,
|
||||
minFramesOff: 0
|
||||
)
|
||||
}
|
||||
|
||||
public init(
|
||||
onsetThreshold: Float = 0.5,
|
||||
offsetThreshold: Float = 0.5,
|
||||
onsetPadSeconds: Float = 0,
|
||||
offsetPadSeconds: Float = 0,
|
||||
minDurationOn: Float = 0,
|
||||
minDurationOff: Float = 0,
|
||||
maxStoredFrames: Int? = nil
|
||||
) {
|
||||
self.onsetThreshold = onsetThreshold
|
||||
self.offsetThreshold = offsetThreshold
|
||||
self.onsetPadFrames = Int(round(onsetPadSeconds / frameDurationSeconds))
|
||||
self.offsetPadFrames = Int(round(offsetPadSeconds / frameDurationSeconds))
|
||||
self.minFramesOn = Int(round(minDurationOn / frameDurationSeconds))
|
||||
self.minFramesOff = Int(round(minDurationOff / frameDurationSeconds))
|
||||
self.maxStoredFrames = maxStoredFrames
|
||||
}
|
||||
|
||||
public init(
|
||||
onsetThreshold: Float = 0.5,
|
||||
offsetThreshold: Float = 0.5,
|
||||
onsetPadFrames: Int = 0,
|
||||
offsetPadFrames: Int = 0,
|
||||
minFramesOn: Int = 0,
|
||||
minFramesOff: Int = 0,
|
||||
maxStoredFrames: Int? = nil
|
||||
) {
|
||||
self.onsetThreshold = onsetThreshold
|
||||
self.offsetThreshold = offsetThreshold
|
||||
self.onsetPadFrames = onsetPadFrames
|
||||
self.offsetPadFrames = offsetPadFrames
|
||||
self.minFramesOn = minFramesOn
|
||||
self.minFramesOff = minFramesOff
|
||||
self.maxStoredFrames = maxStoredFrames
|
||||
}
|
||||
|
||||
/// Convert to the unified DiarizerPostProcessingConfig
|
||||
public func toDiarizerConfig() -> DiarizerTimelineConfig {
|
||||
DiarizerTimelineConfig(
|
||||
numSpeakers: numSpeakers,
|
||||
frameDurationSeconds: frameDurationSeconds,
|
||||
onsetThreshold: onsetThreshold,
|
||||
offsetThreshold: offsetThreshold,
|
||||
onsetPadFrames: onsetPadFrames,
|
||||
offsetPadFrames: offsetPadFrames,
|
||||
minFramesOn: minFramesOn,
|
||||
minFramesOff: minFramesOff,
|
||||
maxStoredFrames: maxStoredFrames
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Backward-compatible typealiases — SortformerTimeline is now DiarizerTimeline
|
||||
@available(*, deprecated, renamed: "DiarizerTimeline")
|
||||
public typealias SortformerTimeline = DiarizerTimeline
|
||||
|
||||
@available(*, deprecated, renamed: "DiarizerSegment")
|
||||
public typealias SortformerSegment = DiarizerSegment
|
||||
|
||||
// MARK: - Streaming State
|
||||
|
||||
/// State maintained across streaming chunks for Sortformer diarization.
|
||||
@@ -455,7 +329,7 @@ public struct SortformerFeatureLoader: Sendable {
|
||||
|
||||
self.startFeat = 0
|
||||
self.endFeat = 0
|
||||
(self.featSeq, self.featLength, self.featSeqLength) = NeMoMelSpectrogram().computeFlatTransposed(audio: audio)
|
||||
(self.featSeq, self.featLength, self.featSeqLength) = AudioMelSpectrogram().computeFlatTransposed(audio: audio)
|
||||
// numChunks accounts for right context requirement: need endFeat + rc <= featLength
|
||||
// Chunk n has endFeat = (n+1) * chunkLen, so valid when (n+1) * chunkLen + rc <= featLength
|
||||
// numChunks = floor((featLength - rc) / chunkLen)
|
||||
@@ -522,10 +396,6 @@ public struct StreamingUpdateResult: Sendable {
|
||||
}
|
||||
}
|
||||
|
||||
/// Backward-compatible typealias — SortformerChunkResult is now DiarizerChunkResult
|
||||
@available(*, deprecated, renamed: "DiarizerChunkResult")
|
||||
public typealias SortformerChunkResult = DiarizerChunkResult
|
||||
|
||||
// MARK: - Errors
|
||||
|
||||
public enum SortformerError: Error, LocalizedError {
|
||||
|
||||
@@ -381,14 +381,15 @@ public class DownloadUtils {
|
||||
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedFilePath)
|
||||
let request = authorizedRequest(url: fileURL)
|
||||
|
||||
let tempFileURL: URL
|
||||
let httpResponse: HTTPURLResponse
|
||||
|
||||
if let handler = progressHandler {
|
||||
let baseBytes = completedBytes
|
||||
let fileCount = filesToDownload.count
|
||||
let totalBytesSnapshot = totalBytes
|
||||
httpResponse = try await downloadWithProgress(
|
||||
(tempFileURL, httpResponse) = try await downloadWithProgress(
|
||||
request: request,
|
||||
destinationURL: destPath,
|
||||
onProgress: { bytesWritten, _ in
|
||||
guard totalBytesSnapshot > 0 else { return }
|
||||
let current = baseBytes + bytesWritten
|
||||
@@ -406,41 +407,11 @@ public class DownloadUtils {
|
||||
guard let resp = response as? HTTPURLResponse else {
|
||||
throw HuggingFaceDownloadError.invalidResponse
|
||||
}
|
||||
tempFileURL = url
|
||||
httpResponse = resp
|
||||
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
message: "Rate limited while downloading \(file.path)")
|
||||
}
|
||||
|
||||
guard (200..<300).contains(httpResponse.statusCode) else {
|
||||
throw HuggingFaceDownloadError.downloadFailed(
|
||||
path: file.path,
|
||||
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
|
||||
)
|
||||
}
|
||||
|
||||
// Remove existing file if present (handles parallel download race conditions)
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
try? FileManager.default.removeItem(at: destPath)
|
||||
}
|
||||
try FileManager.default.moveItem(at: url, to: destPath)
|
||||
completedBytes += Int64(max(0, file.size))
|
||||
if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 {
|
||||
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files")
|
||||
}
|
||||
|
||||
// Report completed-file progress
|
||||
progressHandler?(
|
||||
DownloadProgress(
|
||||
fractionCompleted: totalBytes > 0
|
||||
? 0.5 * Double(completedBytes) / Double(totalBytes)
|
||||
: 0.5 * Double(index + 1) / Double(filesToDownload.count),
|
||||
phase: .downloading(completedFiles: index + 1, totalFiles: filesToDownload.count)
|
||||
))
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
@@ -454,13 +425,18 @@ public class DownloadUtils {
|
||||
)
|
||||
}
|
||||
|
||||
// Move downloaded file to destination
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
try? FileManager.default.removeItem(at: destPath)
|
||||
}
|
||||
try FileManager.default.moveItem(at: tempFileURL, to: destPath)
|
||||
|
||||
completedBytes += Int64(max(0, file.size))
|
||||
|
||||
if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 {
|
||||
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files")
|
||||
}
|
||||
|
||||
// Report completed-file progress
|
||||
progressHandler?(
|
||||
DownloadProgress(
|
||||
fractionCompleted: totalBytes > 0
|
||||
@@ -485,16 +461,17 @@ public class DownloadUtils {
|
||||
|
||||
/// Download a single file using a delegate to get byte-level progress.
|
||||
///
|
||||
/// This is a pure transport helper — the caller is responsible for validating
|
||||
/// the HTTP status and moving the temporary file to its final destination.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - request: The URLRequest to download.
|
||||
/// - onProgress: Called with `(bytesWritten, totalBytesExpected)` as data arrives.
|
||||
/// - Parameter destinationURL: Final destination path for the downloaded file.
|
||||
/// - Returns: The HTTP response for the completed download.
|
||||
/// - onProgress: Called with `(totalBytesWritten, totalBytesExpected)` as data arrives.
|
||||
/// - Returns: The temporary file URL and HTTP response.
|
||||
private static func downloadWithProgress(
|
||||
request: URLRequest,
|
||||
destinationURL: URL,
|
||||
onProgress: @escaping @Sendable (Int64, Int64) -> Void
|
||||
) async throws -> HTTPURLResponse {
|
||||
) async throws -> (URL, HTTPURLResponse) {
|
||||
let delegate = DownloadProgressDelegate(onProgress: onProgress)
|
||||
// Dedicated session with delegate — one per download to avoid cross-talk.
|
||||
let session = URLSession(
|
||||
@@ -510,25 +487,107 @@ public class DownloadUtils {
|
||||
throw HuggingFaceDownloadError.invalidResponse
|
||||
}
|
||||
|
||||
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
message: "HTTP \(httpResponse.statusCode)"
|
||||
)
|
||||
return (tempURL, httpResponse)
|
||||
}
|
||||
|
||||
/// Download a specific subdirectory from a HuggingFace repository.
|
||||
///
|
||||
/// Use this for optional model components that aren't part of the required model set
|
||||
/// (e.g., the Mimi encoder for PocketTTS voice cloning).
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - repo: The HuggingFace repository.
|
||||
/// - subdirectory: Path within the repo to download (e.g. `"mimi_encoder.mlmodelc"`).
|
||||
/// - repoDirectory: Local directory corresponding to the repo root.
|
||||
/// Files are saved at `repoDirectory/<remote_path>`.
|
||||
public static func downloadSubdirectory(
|
||||
_ repo: Repo,
|
||||
subdirectory: String,
|
||||
to repoDirectory: URL
|
||||
) async throws {
|
||||
var filesToDownload: [(path: String, size: Int)] = []
|
||||
|
||||
func listFiles(at path: String) async throws {
|
||||
let dirURL = try ModelRegistry.apiModels(repo.remotePath, "tree/main/\(path)")
|
||||
let (dirData, response) = try await fetchWithAuth(from: dirURL)
|
||||
if let httpResponse = response as? HTTPURLResponse,
|
||||
httpResponse.statusCode == 429 || httpResponse.statusCode == 503
|
||||
{
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
message: "Rate limited while listing files in \(path)")
|
||||
}
|
||||
guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else {
|
||||
return
|
||||
}
|
||||
for item in items {
|
||||
guard let itemPath = item["path"] as? String,
|
||||
let itemType = item["type"] as? String
|
||||
else { continue }
|
||||
|
||||
if itemType == "directory" {
|
||||
try await listFiles(at: itemPath)
|
||||
} else if itemType == "file" {
|
||||
let fileSize = item["size"] as? Int ?? -1
|
||||
filesToDownload.append((path: itemPath, size: fileSize))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
guard (200..<300).contains(httpResponse.statusCode) else {
|
||||
throw HuggingFaceDownloadError.downloadFailed(
|
||||
path: destinationURL.lastPathComponent,
|
||||
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
|
||||
try await listFiles(at: subdirectory)
|
||||
logger.info("Found \(filesToDownload.count) files in \(subdirectory)")
|
||||
|
||||
for (index, file) in filesToDownload.enumerated() {
|
||||
let destPath = repoDirectory.appendingPathComponent(file.path)
|
||||
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
continue
|
||||
}
|
||||
|
||||
try FileManager.default.createDirectory(
|
||||
at: destPath.deletingLastPathComponent(),
|
||||
withIntermediateDirectories: true
|
||||
)
|
||||
|
||||
if file.size == 0 {
|
||||
FileManager.default.createFile(atPath: destPath.path, contents: Data())
|
||||
continue
|
||||
}
|
||||
|
||||
let encodedPath =
|
||||
file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
|
||||
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedPath)
|
||||
let request = authorizedRequest(url: fileURL)
|
||||
|
||||
let (tempURL, response) = try await sharedSession.download(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
throw HuggingFaceDownloadError.invalidResponse
|
||||
}
|
||||
|
||||
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
message: "Rate limited while downloading \(file.path)")
|
||||
}
|
||||
|
||||
guard (200..<300).contains(httpResponse.statusCode) else {
|
||||
throw HuggingFaceDownloadError.downloadFailed(
|
||||
path: file.path,
|
||||
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
|
||||
)
|
||||
}
|
||||
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
try? FileManager.default.removeItem(at: destPath)
|
||||
}
|
||||
try FileManager.default.moveItem(at: tempURL, to: destPath)
|
||||
|
||||
if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 {
|
||||
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) \(subdirectory) files")
|
||||
}
|
||||
}
|
||||
|
||||
if FileManager.default.fileExists(atPath: destinationURL.path) {
|
||||
try? FileManager.default.removeItem(at: destinationURL)
|
||||
}
|
||||
try FileManager.default.moveItem(at: tempURL, to: destinationURL)
|
||||
return httpResponse
|
||||
logger.info("Downloaded \(subdirectory) from \(repo.folderName)")
|
||||
}
|
||||
|
||||
/// Fetch a single file from HuggingFace with retry
|
||||
|
||||
@@ -240,44 +240,44 @@ public enum ModelNames {
|
||||
/// Sortformer streaming diarization model names
|
||||
public enum Sortformer {
|
||||
public enum Variant: CaseIterable, Sendable {
|
||||
case fastestV2
|
||||
case fastestV2_1
|
||||
case nvidiaLowLatencyV2
|
||||
case nvidiaLowLatencyV2_1
|
||||
case nvidiaHighLatencyV2
|
||||
case nvidiaHighLatencyV2_1
|
||||
case fastV2
|
||||
case fastV2_1
|
||||
case balancedV2
|
||||
case balancedV2_1
|
||||
case highContextV2
|
||||
case highContextV2_1
|
||||
|
||||
public var name: String {
|
||||
switch self {
|
||||
case .fastestV2:
|
||||
case .fastV2:
|
||||
return "Sortformer_v2"
|
||||
case .fastestV2_1:
|
||||
case .fastV2_1:
|
||||
return "Sortformer_v2.1"
|
||||
case .nvidiaLowLatencyV2:
|
||||
case .balancedV2:
|
||||
return "SortformerNvidiaLow_v2"
|
||||
case .nvidiaLowLatencyV2_1:
|
||||
case .balancedV2_1:
|
||||
return "SortformerNvidiaLow_v2.1"
|
||||
case .nvidiaHighLatencyV2:
|
||||
case .highContextV2:
|
||||
return "SortformerNvidiaHigh_v2"
|
||||
case .nvidiaHighLatencyV2_1:
|
||||
case .highContextV2_1:
|
||||
return "SortformerNvidiaHigh_v2.1"
|
||||
}
|
||||
}
|
||||
|
||||
public var defaultConfiguration: SortformerConfig {
|
||||
switch self {
|
||||
case .fastestV2:
|
||||
return .fastestV2
|
||||
case .fastestV2_1:
|
||||
return .fastestV2_1
|
||||
case .nvidiaLowLatencyV2:
|
||||
return .nvidiaLowLatencyV2
|
||||
case .nvidiaLowLatencyV2_1:
|
||||
return .nvidiaLowLatencyV2_1
|
||||
case .nvidiaHighLatencyV2:
|
||||
return .nvidiaHighLatencyV2
|
||||
case .nvidiaHighLatencyV2_1:
|
||||
return .nvidiaHighLatencyV2_1
|
||||
case .fastV2:
|
||||
return .fastV2
|
||||
case .fastV2_1:
|
||||
return .fastV2_1
|
||||
case .balancedV2:
|
||||
return .balancedV2
|
||||
case .balancedV2_1:
|
||||
return .balancedV2_1
|
||||
case .highContextV2:
|
||||
return .highContextV2
|
||||
case .highContextV2_1:
|
||||
return .highContextV2_1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +291,7 @@ public enum ModelNames {
|
||||
}
|
||||
|
||||
/// Lowest latency for streaming
|
||||
public static let defaultVariant: Variant = .fastestV2_1
|
||||
public static let defaultVariant: Variant = .fastV2_1
|
||||
|
||||
/// Bundle name for a specific variant
|
||||
public static func bundle(for variant: Variant) -> String {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import Accelerate
|
||||
import CoreML
|
||||
import Darwin
|
||||
import Foundation
|
||||
import Metal
|
||||
|
||||
/// Shared ANE optimization utilities for all ML pipelines
|
||||
public enum ANEMemoryUtils {
|
||||
@@ -120,13 +118,16 @@ public enum ANEMemoryUtils {
|
||||
shape: [NSNumber],
|
||||
strides: [NSNumber]? = nil
|
||||
) throws -> MLMultiArray {
|
||||
// Validate bounds
|
||||
// Validate bounds using stride-aware backing size (accounts for ANE padding)
|
||||
let elementSize = getElementSize(for: array.dataType)
|
||||
let totalElements = shape.map { $0.intValue }.reduce(1, *)
|
||||
let bytesNeeded = totalElements * elementSize
|
||||
let viewStrides = strides ?? calculateOptimalStrides(for: shape)
|
||||
let viewBackingElements =
|
||||
shape.isEmpty ? 0 : viewStrides[0].intValue * shape[0].intValue
|
||||
let sourceBackingElements =
|
||||
array.shape.isEmpty ? 0 : array.strides[0].intValue * array.shape[0].intValue
|
||||
let byteOffset = offset * elementSize
|
||||
|
||||
guard byteOffset + bytesNeeded <= array.count * elementSize else {
|
||||
guard byteOffset + viewBackingElements * elementSize <= sourceBackingElements * elementSize else {
|
||||
throw ANEMemoryError.invalidShape
|
||||
}
|
||||
|
||||
@@ -144,34 +145,47 @@ public enum ANEMemoryUtils {
|
||||
|
||||
/// Stride-aware copy between two MLMultiArrays that may have different stride layouts.
|
||||
///
|
||||
/// Copies all logical elements from `source` to `destination` (which must have the same shape).
|
||||
/// The innermost dimension is copied as a contiguous block (stride-1), while outer dimensions
|
||||
/// Copies all logical elements from `source` to `destination` (which must have the same shape
|
||||
/// and data type). The innermost dimension must have stride 1 in both arrays. Outer dimensions
|
||||
/// are iterated respecting each array's strides.
|
||||
public static func strideAwareCopy(from source: MLMultiArray, to destination: MLMultiArray) {
|
||||
let shape = source.shape.map { $0.intValue }
|
||||
let srcStrides = source.strides.map { $0.intValue }
|
||||
let dstStrides = destination.strides.map { $0.intValue }
|
||||
|
||||
let srcPtr = source.dataPointer.assumingMemoryBound(to: Float.self)
|
||||
let dstPtr = destination.dataPointer.assumingMemoryBound(to: Float.self)
|
||||
|
||||
let ndim = shape.count
|
||||
guard ndim > 0 else { return }
|
||||
|
||||
// Validate shapes and types match
|
||||
precondition(
|
||||
source.shape == destination.shape,
|
||||
"strideAwareCopy: shape mismatch \(source.shape) vs \(destination.shape)"
|
||||
)
|
||||
precondition(
|
||||
source.dataType == destination.dataType,
|
||||
"strideAwareCopy: dataType mismatch"
|
||||
)
|
||||
precondition(
|
||||
srcStrides[ndim - 1] == 1 && dstStrides[ndim - 1] == 1,
|
||||
"strideAwareCopy: innermost stride must be 1"
|
||||
)
|
||||
|
||||
let elementSize = getElementSize(for: source.dataType)
|
||||
let srcPtr = source.dataPointer
|
||||
let dstPtr = destination.dataPointer
|
||||
|
||||
// If strides match, a single memcpy suffices (fast path).
|
||||
if srcStrides == dstStrides {
|
||||
// Total backing storage = first-dim stride * first-dim size
|
||||
let totalBacking = srcStrides[0] * shape[0]
|
||||
memcpy(dstPtr, srcPtr, totalBacking * MemoryLayout<Float>.size)
|
||||
let totalBytes = srcStrides[0] * shape[0] * elementSize
|
||||
memcpy(dstPtr, srcPtr, totalBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Innermost dimension length
|
||||
let innerLen = shape[ndim - 1]
|
||||
// Innermost dimension byte count
|
||||
let innerBytes = shape[ndim - 1] * elementSize
|
||||
|
||||
if ndim == 1 {
|
||||
// 1-D: just copy innerLen elements (both have stride 1 for innermost)
|
||||
memcpy(dstPtr, srcPtr, innerLen * MemoryLayout<Float>.size)
|
||||
memcpy(dstPtr, srcPtr, innerBytes)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -182,16 +196,16 @@ public enum ANEMemoryUtils {
|
||||
var indices = [Int](repeating: 0, count: ndim - 1)
|
||||
|
||||
for _ in 0..<outerCount {
|
||||
// Compute flat offset for source and destination
|
||||
var srcOffset = 0
|
||||
var dstOffset = 0
|
||||
// Compute flat byte offset for source and destination
|
||||
var srcByteOffset = 0
|
||||
var dstByteOffset = 0
|
||||
for d in 0..<(ndim - 1) {
|
||||
srcOffset += indices[d] * srcStrides[d]
|
||||
dstOffset += indices[d] * dstStrides[d]
|
||||
srcByteOffset += indices[d] * srcStrides[d] * elementSize
|
||||
dstByteOffset += indices[d] * dstStrides[d] * elementSize
|
||||
}
|
||||
|
||||
// Copy innermost dimension as contiguous block
|
||||
memcpy(dstPtr + dstOffset, srcPtr + srcOffset, innerLen * MemoryLayout<Float>.size)
|
||||
memcpy(dstPtr + dstByteOffset, srcPtr + srcByteOffset, innerBytes)
|
||||
|
||||
// Increment multi-index (odometer style)
|
||||
var carry = ndim - 2
|
||||
@@ -206,33 +220,4 @@ public enum ANEMemoryUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prefetch memory pages for ANE processing
|
||||
public static func prefetchForANE(_ array: MLMultiArray) {
|
||||
let dataPointer = array.dataPointer
|
||||
let elementSize = getElementSize(for: array.dataType)
|
||||
let totalBytes = array.count * elementSize
|
||||
|
||||
// Touch first and last cache lines to trigger ANE DMA prefetch
|
||||
if totalBytes > 0 {
|
||||
_ = dataPointer.load(as: UInt8.self)
|
||||
if totalBytes > 1 {
|
||||
_ = dataPointer.advanced(by: totalBytes - 1).load(as: UInt8.self)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension for MLMultiArray to add ANE optimization methods
|
||||
extension MLMultiArray {
|
||||
|
||||
/// Check if this array is ANE-aligned
|
||||
public var isANEAligned: Bool {
|
||||
let address = Int(bitPattern: self.dataPointer)
|
||||
return address % ANEMemoryUtils.aneAlignment == 0
|
||||
}
|
||||
|
||||
/// Prefetch this array for ANE processing
|
||||
public func prefetchForANE() {
|
||||
ANEMemoryUtils.prefetchForANE(self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,23 +193,25 @@ final public class AudioConverter {
|
||||
throw AudioConverterError.failedToCreateSourceFormat
|
||||
}
|
||||
|
||||
// Use AVAudioConverter for channel mixing and format conversion
|
||||
// Use AVAudioConverter for channel mixing (same sample rate, no resampling needed)
|
||||
guard let converter = AVAudioConverter(from: format, to: monoFormat) else {
|
||||
throw AudioConverterError.failedToCreateConverter
|
||||
}
|
||||
configure(converter: converter)
|
||||
|
||||
guard let outputBuffer = AVAudioPCMBuffer(pcmFormat: monoFormat, frameCapacity: buffer.frameCapacity) else {
|
||||
throw AudioConverterError.failedToCreateBuffer
|
||||
}
|
||||
|
||||
nonisolated(unsafe) var provided = false
|
||||
nonisolated(unsafe) let capturedBuffer = buffer
|
||||
let provided = OSAllocatedUnfairLock(initialState: false)
|
||||
let inputBlock: AVAudioConverterInputBlock = { _, status in
|
||||
if !provided {
|
||||
provided = true
|
||||
let wasProvided = provided.withLock { state -> Bool in
|
||||
if state { return true }
|
||||
state = true
|
||||
return false
|
||||
}
|
||||
if !wasProvided {
|
||||
status.pointee = .haveData
|
||||
return capturedBuffer
|
||||
return buffer
|
||||
} else {
|
||||
status.pointee = .endOfStream
|
||||
return nil
|
||||
@@ -309,8 +311,12 @@ final public class AudioConverter {
|
||||
// but Swift 6 rejects mutation of captured vars in this callback.
|
||||
let provided = OSAllocatedUnfairLock(initialState: false)
|
||||
let inputBlock: AVAudioConverterInputBlock = { _, status in
|
||||
if !provided.withLock({ $0 }) {
|
||||
provided.withLock { $0 = true }
|
||||
let wasProvided = provided.withLock { state -> Bool in
|
||||
if state { return true }
|
||||
state = true
|
||||
return false
|
||||
}
|
||||
if !wasProvided {
|
||||
status.pointee = .haveData
|
||||
return buffer
|
||||
} else {
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ import Foundation
|
||||
/// - center: True with pad_mode='constant' (zero padding)
|
||||
/// - normalize: "NA" (no normalization)
|
||||
/// - dither: 0.0 (disabled for determinism)
|
||||
public final class NeMoMelSpectrogram {
|
||||
public final class AudioMelSpectrogram {
|
||||
public enum PaddingMode: Sendable {
|
||||
case center
|
||||
case prePadded
|
||||
@@ -66,94 +66,11 @@ public enum PocketTtsResourceDownloader {
|
||||
|
||||
/// Download the Mimi encoder model files from HuggingFace.
|
||||
private static func downloadMimiEncoder(to repoDir: URL) async throws {
|
||||
let modelName = ModelNames.PocketTTS.mimiEncoderFile
|
||||
let modelDir = repoDir.appendingPathComponent(modelName)
|
||||
|
||||
try FileManager.default.createDirectory(at: modelDir, withIntermediateDirectories: true)
|
||||
|
||||
// List files in the mimi_encoder.mlmodelc directory
|
||||
let apiPath = "tree/main/\(modelName)"
|
||||
let dirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, apiPath)
|
||||
|
||||
let (dirData, _) = try await DownloadUtils.fetchWithAuth(from: dirURL)
|
||||
|
||||
guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else {
|
||||
throw PocketTTSError.downloadFailed("Failed to list Mimi encoder files")
|
||||
}
|
||||
|
||||
// Collect all files recursively
|
||||
var filesToDownload: [(path: String, size: Int)] = []
|
||||
|
||||
func collectFiles(from items: [[String: Any]], basePath: String) async throws {
|
||||
for item in items {
|
||||
guard let itemPath = item["path"] as? String,
|
||||
let itemType = item["type"] as? String
|
||||
else { continue }
|
||||
|
||||
if itemType == "directory" {
|
||||
let subDirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, "tree/main/\(itemPath)")
|
||||
let (subDirData, _) = try await DownloadUtils.fetchWithAuth(from: subDirURL)
|
||||
if let subItems = try JSONSerialization.jsonObject(with: subDirData) as? [[String: Any]] {
|
||||
try await collectFiles(from: subItems, basePath: itemPath)
|
||||
}
|
||||
} else if itemType == "file" {
|
||||
let fileSize = item["size"] as? Int ?? -1
|
||||
filesToDownload.append((path: itemPath, size: fileSize))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try await collectFiles(from: items, basePath: modelName)
|
||||
logger.info("Found \(filesToDownload.count) files in Mimi encoder")
|
||||
|
||||
// Download each file
|
||||
for (index, file) in filesToDownload.enumerated() {
|
||||
// Local path relative to modelName
|
||||
let relativePath =
|
||||
file.path.hasPrefix("\(modelName)/")
|
||||
? String(file.path.dropFirst(modelName.count + 1))
|
||||
: file.path
|
||||
let destPath = modelDir.appendingPathComponent(relativePath)
|
||||
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
continue
|
||||
}
|
||||
|
||||
try FileManager.default.createDirectory(
|
||||
at: destPath.deletingLastPathComponent(),
|
||||
withIntermediateDirectories: true
|
||||
)
|
||||
|
||||
// Handle empty files
|
||||
if file.size == 0 {
|
||||
FileManager.default.createFile(atPath: destPath.path, contents: Data())
|
||||
continue
|
||||
}
|
||||
|
||||
let encodedPath = file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
|
||||
let fileURL = try ModelRegistry.resolveModel(Repo.pocketTts.remotePath, encodedPath)
|
||||
|
||||
let (tempURL, response) = try await DownloadUtils.sharedSession.download(
|
||||
for: URLRequest(url: fileURL, timeoutInterval: 1800)
|
||||
)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
throw PocketTTSError.downloadFailed("Failed to download \(file.path)")
|
||||
}
|
||||
|
||||
if FileManager.default.fileExists(atPath: destPath.path) {
|
||||
try? FileManager.default.removeItem(at: destPath)
|
||||
}
|
||||
try FileManager.default.moveItem(at: tempURL, to: destPath)
|
||||
|
||||
if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 {
|
||||
logger.info("Downloaded \(index + 1)/\(filesToDownload.count) Mimi encoder files")
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Mimi encoder download complete")
|
||||
try await DownloadUtils.downloadSubdirectory(
|
||||
.pocketTts,
|
||||
subdirectory: ModelNames.PocketTTS.mimiEncoderFile,
|
||||
to: repoDir
|
||||
)
|
||||
}
|
||||
|
||||
/// Ensure constants (binary blobs + tokenizer) are available.
|
||||
|
||||
@@ -180,9 +180,7 @@ enum StreamDiarizationBenchmark {
|
||||
printUsage()
|
||||
return
|
||||
default:
|
||||
if !arguments[i].starts(with: "--") {
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
#if os(macOS)
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
/// Shared utilities for diarization benchmark commands (LS-EEND and Sortformer).
|
||||
enum DiarizationBenchmarkUtils {
|
||||
|
||||
/// Dataset corpora supported by diarization benchmarks.
|
||||
enum Dataset: String {
|
||||
case ami = "ami"
|
||||
case voxconverse = "voxconverse"
|
||||
case callhome = "callhome"
|
||||
}
|
||||
|
||||
/// Per-meeting benchmark result shared across diarization benchmark commands.
|
||||
struct BenchmarkResult {
|
||||
let meetingName: String
|
||||
let der: Float
|
||||
let missRate: Float
|
||||
let falseAlarmRate: Float
|
||||
let speakerErrorRate: Float
|
||||
let rtfx: Float
|
||||
let processingTime: Double
|
||||
let totalFrames: Int
|
||||
let detectedSpeakers: Int
|
||||
let groundTruthSpeakers: Int
|
||||
let modelLoadTime: Double
|
||||
let audioLoadTime: Double?
|
||||
}
|
||||
|
||||
// MARK: - File Paths
|
||||
|
||||
static func getAMIFiles(maxFiles: Int?) -> [String] {
|
||||
let allMeetings = [
|
||||
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
|
||||
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
|
||||
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
|
||||
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
|
||||
]
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for meeting in allMeetings {
|
||||
let path = getAudioPath(for: meeting, dataset: .ami)
|
||||
if FileManager.default.fileExists(atPath: path) {
|
||||
availableMeetings.append(meeting)
|
||||
}
|
||||
}
|
||||
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
|
||||
).path
|
||||
case .voxconverse:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
|
||||
).path
|
||||
case .callhome:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
|
||||
).path
|
||||
}
|
||||
}
|
||||
|
||||
static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/ami_official/rttm/\(meeting).rttm"
|
||||
)
|
||||
case .voxconverse:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm"
|
||||
)
|
||||
case .callhome:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let voxDir = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
|
||||
)
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: voxDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
let rttmPath = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
|
||||
)
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
availableMeetings.sort()
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: callhomeDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
availableMeetings.sort()
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
/// Returns files for the given dataset, filtering by availability.
|
||||
static func getFiles(for dataset: Dataset, maxFiles: Int?) -> [String] {
|
||||
switch dataset {
|
||||
case .ami:
|
||||
return getAMIFiles(maxFiles: maxFiles)
|
||||
case .voxconverse:
|
||||
return getVoxConverseFiles(maxFiles: maxFiles)
|
||||
case .callhome:
|
||||
return getCALLHOMEFiles(maxFiles: maxFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Summary & Output
|
||||
|
||||
/// Prints a formatted benchmark summary table.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - results: Benchmark results to summarize.
|
||||
/// - title: Header title (e.g. "LS-EEND BENCHMARK SUMMARY").
|
||||
/// - derTargets: DER percentage thresholds to check, ordered from strictest to most lenient
|
||||
/// (e.g. `[15, 25]` prints "DER < 15%" if met, else "DER < 25%", else "DER > 25%").
|
||||
static func printFinalSummary(
|
||||
results: [BenchmarkResult],
|
||||
title: String,
|
||||
derTargets: [Float]
|
||||
) {
|
||||
guard !results.isEmpty else { return }
|
||||
|
||||
print("\n" + String(repeating: "=", count: 80))
|
||||
print(title)
|
||||
print(String(repeating: "=", count: 80))
|
||||
|
||||
print("Results Sorted by DER:")
|
||||
print(String(repeating: "-", count: 70))
|
||||
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
for result in results.sorted(by: { $0.der < $1.der }) {
|
||||
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
|
||||
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
|
||||
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
|
||||
print(
|
||||
String(
|
||||
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
|
||||
meetingCol,
|
||||
result.der,
|
||||
result.missRate,
|
||||
result.falseAlarmRate,
|
||||
result.speakerErrorRate,
|
||||
speakerCol,
|
||||
result.rtfx))
|
||||
}
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
let count = Float(results.count)
|
||||
let avgDER = results.map { $0.der }.reduce(0, +) / count
|
||||
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
|
||||
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
|
||||
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
|
||||
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
|
||||
|
||||
print(
|
||||
String(
|
||||
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
|
||||
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
|
||||
print(String(repeating: "=", count: 70))
|
||||
|
||||
print("\nTarget Check:")
|
||||
var matched = false
|
||||
for target in derTargets.sorted() {
|
||||
if avgDER < target {
|
||||
print(" DER < \(String(format: "%.0f", target))% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched, let highest = derTargets.max() {
|
||||
print(
|
||||
" DER > \(String(format: "%.0f", highest))% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
}
|
||||
|
||||
if avgRTFx > 1 {
|
||||
print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
} else {
|
||||
print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
}
|
||||
}
|
||||
|
||||
static func saveJSONResults(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { resultToDict($0) }
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
print("JSON results saved to: \(path)")
|
||||
} catch {
|
||||
print("Failed to save JSON: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Progress Save/Load
|
||||
|
||||
static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
|
||||
var dict: [String: Any] = [
|
||||
"meeting": result.meetingName,
|
||||
"der": result.der,
|
||||
"missRate": result.missRate,
|
||||
"falseAlarmRate": result.falseAlarmRate,
|
||||
"speakerErrorRate": result.speakerErrorRate,
|
||||
"rtfx": result.rtfx,
|
||||
"processingTime": result.processingTime,
|
||||
"totalFrames": result.totalFrames,
|
||||
"detectedSpeakers": result.detectedSpeakers,
|
||||
"groundTruthSpeakers": result.groundTruthSpeakers,
|
||||
"modelLoadTime": result.modelLoadTime,
|
||||
]
|
||||
if let audioLoadTime = result.audioLoadTime {
|
||||
dict["audioLoadTime"] = audioLoadTime
|
||||
}
|
||||
return dict
|
||||
}
|
||||
|
||||
static func saveProgress(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { resultToDict($0) }
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
} catch {
|
||||
print("Failed to save progress: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
static func loadProgress(from path: String) -> [BenchmarkResult]? {
|
||||
guard FileManager.default.fileExists(atPath: path) else { return nil }
|
||||
|
||||
do {
|
||||
let data = try Data(contentsOf: URL(fileURLWithPath: path))
|
||||
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonArray.compactMap { dict -> BenchmarkResult? in
|
||||
guard let meeting = dict["meeting"] as? String,
|
||||
let der = (dict["der"] as? NSNumber)?.floatValue,
|
||||
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
|
||||
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
|
||||
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
|
||||
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
|
||||
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
|
||||
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
|
||||
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
|
||||
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
|
||||
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
|
||||
|
||||
return BenchmarkResult(
|
||||
meetingName: meeting,
|
||||
der: der,
|
||||
missRate: missRate,
|
||||
falseAlarmRate: falseAlarmRate,
|
||||
speakerErrorRate: speakerErrorRate,
|
||||
rtfx: rtfx,
|
||||
processingTime: processingTime,
|
||||
totalFrames: totalFrames,
|
||||
detectedSpeakers: detectedSpeakers,
|
||||
groundTruthSpeakers: groundTruthSpeakers,
|
||||
modelLoadTime: modelLoadTime,
|
||||
audioLoadTime: audioLoadTime
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
print("Failed to load progress: \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -1,5 +1,4 @@
|
||||
#if os(macOS)
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
@@ -7,26 +6,8 @@ import Foundation
|
||||
enum LSEENDBenchmark {
|
||||
private static let logger = AppLogger(category: "LSEENDBench")
|
||||
|
||||
enum Dataset: String {
|
||||
case ami = "ami"
|
||||
case voxconverse = "voxconverse"
|
||||
case callhome = "callhome"
|
||||
}
|
||||
|
||||
struct BenchmarkResult {
|
||||
let meetingName: String
|
||||
let der: Float
|
||||
let missRate: Float
|
||||
let falseAlarmRate: Float
|
||||
let speakerErrorRate: Float
|
||||
let rtfx: Float
|
||||
let processingTime: Double
|
||||
let totalFrames: Int
|
||||
let detectedSpeakers: Int
|
||||
let groundTruthSpeakers: Int
|
||||
let modelLoadTime: Double
|
||||
let audioLoadTime: Double
|
||||
}
|
||||
typealias Dataset = DiarizationBenchmarkUtils.Dataset
|
||||
typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult
|
||||
|
||||
static func printUsage() {
|
||||
print(
|
||||
@@ -197,9 +178,7 @@ enum LSEENDBenchmark {
|
||||
printUsage()
|
||||
return
|
||||
default:
|
||||
if !arguments[i].starts(with: "--") {
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
@@ -228,14 +207,7 @@ enum LSEENDBenchmark {
|
||||
if let meeting = singleFile {
|
||||
filesToProcess = [meeting]
|
||||
} else {
|
||||
switch dataset {
|
||||
case .ami:
|
||||
filesToProcess = getAMIFiles(maxFiles: maxFiles)
|
||||
case .voxconverse:
|
||||
filesToProcess = getVoxConverseFiles(maxFiles: maxFiles)
|
||||
case .callhome:
|
||||
filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles)
|
||||
}
|
||||
filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles)
|
||||
}
|
||||
|
||||
if filesToProcess.isEmpty {
|
||||
@@ -252,7 +224,7 @@ enum LSEENDBenchmark {
|
||||
var completedResults: [BenchmarkResult] = []
|
||||
var completedMeetings: Set<String> = []
|
||||
if resumeFromProgress {
|
||||
if let loaded = loadProgress(from: progressFile) {
|
||||
if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) {
|
||||
completedResults = loaded
|
||||
completedMeetings = Set(loaded.map { $0.meetingName })
|
||||
print("Resuming: loaded \(completedResults.count) previous results")
|
||||
@@ -339,7 +311,7 @@ enum LSEENDBenchmark {
|
||||
print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth")
|
||||
|
||||
// Save progress after each file
|
||||
saveProgress(results: allResults, to: progressFile)
|
||||
DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile)
|
||||
print("Progress saved (\(allResults.count) files complete)")
|
||||
}
|
||||
fflush(stdout)
|
||||
@@ -349,11 +321,15 @@ enum LSEENDBenchmark {
|
||||
}
|
||||
|
||||
// Print final summary
|
||||
printFinalSummary(results: allResults)
|
||||
DiarizationBenchmarkUtils.printFinalSummary(
|
||||
results: allResults,
|
||||
title: "LS-EEND BENCHMARK SUMMARY",
|
||||
derTargets: [15, 25]
|
||||
)
|
||||
|
||||
// Save results
|
||||
if let outputPath = outputFile {
|
||||
saveJSONResults(results: allResults, to: outputPath)
|
||||
DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,7 +345,7 @@ enum LSEENDBenchmark {
|
||||
numSpeakers: Int,
|
||||
verbose: Bool
|
||||
) async -> BenchmarkResult? {
|
||||
let audioPath = getAudioPath(for: meetingName, dataset: dataset)
|
||||
let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset)
|
||||
guard FileManager.default.fileExists(atPath: audioPath) else {
|
||||
print("Audio file not found: \(audioPath)")
|
||||
fflush(stdout)
|
||||
@@ -378,10 +354,7 @@ enum LSEENDBenchmark {
|
||||
|
||||
do {
|
||||
// Load and process audio
|
||||
let startLoadTime = Date()
|
||||
let audioURL = URL(fileURLWithPath: audioPath)
|
||||
let audioLoadTime = Date().timeIntervalSince(startLoadTime)
|
||||
|
||||
let startTime = Date()
|
||||
let timeline = try diarizer.processComplete(audioFileURL: audioURL)
|
||||
let processingTime = Date().timeIntervalSince(startTime)
|
||||
@@ -400,7 +373,7 @@ enum LSEENDBenchmark {
|
||||
let rttmEntries: [LSEENDRTTMEntry]
|
||||
let rttmSpeakers: [String]
|
||||
|
||||
let rttmURL = getRTTMURL(for: meetingName, dataset: dataset)
|
||||
let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset)
|
||||
if let rttmURL = rttmURL, FileManager.default.fileExists(atPath: rttmURL.path) {
|
||||
let parsed = try LSEENDEvaluation.parseRTTM(url: rttmURL)
|
||||
rttmEntries = parsed.entries
|
||||
@@ -413,7 +386,7 @@ enum LSEENDBenchmark {
|
||||
duration: duration
|
||||
)
|
||||
guard !groundTruth.isEmpty else {
|
||||
print("⚠️ No ground truth found for \(meetingName)")
|
||||
print("No ground truth found for \(meetingName)")
|
||||
return nil
|
||||
}
|
||||
// Convert TimedSpeakerSegment to LSEENDRTTMEntry
|
||||
@@ -507,7 +480,7 @@ enum LSEENDBenchmark {
|
||||
detectedSpeakers: detectedSpeakerIndices.count,
|
||||
groundTruthSpeakers: rttmSpeakers.count,
|
||||
modelLoadTime: modelLoadTime,
|
||||
audioLoadTime: audioLoadTime
|
||||
audioLoadTime: nil
|
||||
)
|
||||
|
||||
} catch {
|
||||
@@ -516,279 +489,5 @@ enum LSEENDBenchmark {
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - File Paths
|
||||
|
||||
private static func getAMIFiles(maxFiles: Int?) -> [String] {
|
||||
let allMeetings = [
|
||||
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
|
||||
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
|
||||
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
|
||||
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
|
||||
]
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for meeting in allMeetings {
|
||||
let path = getAudioPath(for: meeting, dataset: .ami)
|
||||
if FileManager.default.fileExists(atPath: path) {
|
||||
availableMeetings.append(meeting)
|
||||
}
|
||||
}
|
||||
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
private static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
|
||||
).path
|
||||
case .voxconverse:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
|
||||
).path
|
||||
case .callhome:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
|
||||
).path
|
||||
}
|
||||
}
|
||||
|
||||
private static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
// Try local RTTM first, then fall back to dataset directory
|
||||
let localPath = "Streaming-Sortformer-Conversion/\(meeting).rttm"
|
||||
if FileManager.default.fileExists(atPath: localPath) {
|
||||
return URL(fileURLWithPath: localPath)
|
||||
}
|
||||
let datasetPath = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/ami_official/rttm/\(meeting).rttm"
|
||||
)
|
||||
return datasetPath
|
||||
case .voxconverse:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm"
|
||||
)
|
||||
case .callhome:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let voxDir = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
|
||||
)
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: voxDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
let rttmPath = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
|
||||
)
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
availableMeetings.sort()
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: callhomeDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
availableMeetings.sort()
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
// MARK: - Summary & Output
|
||||
|
||||
private static func printFinalSummary(results: [BenchmarkResult]) {
|
||||
guard !results.isEmpty else { return }
|
||||
|
||||
print("\n" + String(repeating: "=", count: 80))
|
||||
print("LS-EEND BENCHMARK SUMMARY")
|
||||
print(String(repeating: "=", count: 80))
|
||||
|
||||
print("Results Sorted by DER:")
|
||||
print(String(repeating: "-", count: 70))
|
||||
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
for result in results.sorted(by: { $0.der < $1.der }) {
|
||||
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
|
||||
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
|
||||
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
|
||||
print(
|
||||
String(
|
||||
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
|
||||
meetingCol,
|
||||
result.der,
|
||||
result.missRate,
|
||||
result.falseAlarmRate,
|
||||
result.speakerErrorRate,
|
||||
speakerCol,
|
||||
result.rtfx))
|
||||
}
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
let count = Float(results.count)
|
||||
let avgDER = results.map { $0.der }.reduce(0, +) / count
|
||||
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
|
||||
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
|
||||
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
|
||||
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
|
||||
|
||||
print(
|
||||
String(
|
||||
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
|
||||
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
|
||||
print(String(repeating: "=", count: 70))
|
||||
|
||||
print("\nTarget Check:")
|
||||
if avgDER < 15 {
|
||||
print(" DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
} else if avgDER < 25 {
|
||||
print(" DER < 25% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
} else {
|
||||
print(" DER > 25% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
}
|
||||
|
||||
if avgRTFx > 1 {
|
||||
print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
} else {
|
||||
print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
}
|
||||
}
|
||||
|
||||
private static func saveJSONResults(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { resultToDict($0) }
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
print("JSON results saved to: \(path)")
|
||||
} catch {
|
||||
print("Failed to save JSON: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Progress Save/Load
|
||||
|
||||
private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
|
||||
return [
|
||||
"meeting": result.meetingName,
|
||||
"der": result.der,
|
||||
"missRate": result.missRate,
|
||||
"falseAlarmRate": result.falseAlarmRate,
|
||||
"speakerErrorRate": result.speakerErrorRate,
|
||||
"rtfx": result.rtfx,
|
||||
"processingTime": result.processingTime,
|
||||
"totalFrames": result.totalFrames,
|
||||
"detectedSpeakers": result.detectedSpeakers,
|
||||
"groundTruthSpeakers": result.groundTruthSpeakers,
|
||||
"modelLoadTime": result.modelLoadTime,
|
||||
"audioLoadTime": result.audioLoadTime,
|
||||
]
|
||||
}
|
||||
|
||||
private static func saveProgress(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { resultToDict($0) }
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
} catch {
|
||||
print("Failed to save progress: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
private static func loadProgress(from path: String) -> [BenchmarkResult]? {
|
||||
guard FileManager.default.fileExists(atPath: path) else { return nil }
|
||||
|
||||
do {
|
||||
let data = try Data(contentsOf: URL(fileURLWithPath: path))
|
||||
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonArray.compactMap { dict -> BenchmarkResult? in
|
||||
guard let meeting = dict["meeting"] as? String,
|
||||
let der = (dict["der"] as? NSNumber)?.floatValue,
|
||||
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
|
||||
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
|
||||
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
|
||||
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
|
||||
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
|
||||
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
|
||||
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
|
||||
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
|
||||
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue,
|
||||
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return BenchmarkResult(
|
||||
meetingName: meeting,
|
||||
der: der,
|
||||
missRate: missRate,
|
||||
falseAlarmRate: falseAlarmRate,
|
||||
speakerErrorRate: speakerErrorRate,
|
||||
rtfx: rtfx,
|
||||
processingTime: processingTime,
|
||||
totalFrames: totalFrames,
|
||||
detectedSpeakers: detectedSpeakers,
|
||||
groundTruthSpeakers: groundTruthSpeakers,
|
||||
modelLoadTime: modelLoadTime,
|
||||
audioLoadTime: audioLoadTime
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
print("Failed to load progress: \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#if os(macOS)
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
@@ -20,8 +19,6 @@ enum LSEENDCommand {
|
||||
var outputFile: String?
|
||||
var variant: LSEENDVariant = .dihard3
|
||||
var threshold: Float = 0.5
|
||||
var medianWidth: Int = 1
|
||||
var collarSeconds: Double = 0.25
|
||||
|
||||
// Post-processing parameters
|
||||
var onset: Float?
|
||||
@@ -62,16 +59,6 @@ enum LSEENDCommand {
|
||||
threshold = v
|
||||
i += 1
|
||||
}
|
||||
case "--median-width":
|
||||
if i + 1 < arguments.count, let v = Int(arguments[i + 1]) {
|
||||
medianWidth = v
|
||||
i += 1
|
||||
}
|
||||
case "--collar":
|
||||
if i + 1 < arguments.count, let v = Double(arguments[i + 1]) {
|
||||
collarSeconds = v
|
||||
i += 1
|
||||
}
|
||||
case "--onset":
|
||||
if i + 1 < arguments.count, let v = Float(arguments[i + 1]) {
|
||||
onset = v
|
||||
@@ -244,7 +231,7 @@ enum LSEENDCommand {
|
||||
}
|
||||
|
||||
private static func printUsage() {
|
||||
logger.info(
|
||||
print(
|
||||
"""
|
||||
|
||||
LS-EEND Command Usage:
|
||||
@@ -253,8 +240,6 @@ enum LSEENDCommand {
|
||||
Options:
|
||||
--variant <name> Model variant: ami, callhome, dihard2, dihard3 (default: dihard3)
|
||||
--threshold <value> Speaker activity threshold (default: 0.5)
|
||||
--median-width <value> Median filter width for post-processing (default: 1)
|
||||
--collar <value> Collar duration in seconds for evaluation (default: 0.25)
|
||||
--onset <value> Onset threshold for speech detection (default: 0.5)
|
||||
--offset <value> Offset threshold for speech detection (default: 0.5)
|
||||
--pad-onset <value> Padding before speech segments in seconds
|
||||
@@ -273,8 +258,7 @@ enum LSEENDCommand {
|
||||
|
||||
# Save results to file
|
||||
fluidaudio lseend audio.wav --output results.json
|
||||
"""
|
||||
)
|
||||
""")
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -85,7 +85,7 @@ public enum LSEENDEvaluation {
|
||||
var speakers: [String] = []
|
||||
for line in text.split(whereSeparator: \.isNewline) {
|
||||
let parts = line.split(separator: " ")
|
||||
guard parts.count >= 8 else { continue }
|
||||
guard parts.count >= 8, parts[0] == "SPEAKER" else { continue }
|
||||
let speaker = String(parts[7])
|
||||
if !speakers.contains(speaker) {
|
||||
speakers.append(speaker)
|
||||
@@ -444,6 +444,7 @@ public enum LSEENDEvaluation {
|
||||
}
|
||||
|
||||
private static func solveAssignmentRowsToColumns(cost: [Float], rows: Int, columns: Int) -> [Int] {
|
||||
precondition(columns <= 20, "Assignment solver is O(2^columns); columns=\(columns) is too large")
|
||||
let stateCount = 1 << columns
|
||||
var dp = [Float](repeating: .greatestFiniteMagnitude, count: stateCount)
|
||||
var parent = [Int](repeating: -1, count: stateCount)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#if os(macOS)
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
@@ -7,26 +6,8 @@ import Foundation
|
||||
enum SortformerBenchmark {
|
||||
private static let logger = AppLogger(category: "SortformerBench")
|
||||
|
||||
enum Dataset: String {
|
||||
case ami = "ami"
|
||||
case voxconverse = "voxconverse"
|
||||
case callhome = "callhome"
|
||||
}
|
||||
|
||||
struct BenchmarkResult {
|
||||
let meetingName: String
|
||||
let der: Float
|
||||
let missRate: Float
|
||||
let falseAlarmRate: Float
|
||||
let speakerErrorRate: Float
|
||||
let rtfx: Float
|
||||
let processingTime: Double
|
||||
let totalFrames: Int
|
||||
let detectedSpeakers: Int
|
||||
let groundTruthSpeakers: Int
|
||||
let modelLoadTime: Double
|
||||
let audioLoadTime: Double
|
||||
}
|
||||
typealias Dataset = DiarizationBenchmarkUtils.Dataset
|
||||
typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult
|
||||
|
||||
static func printUsage() {
|
||||
print(
|
||||
@@ -42,7 +23,6 @@ enum SortformerBenchmark {
|
||||
--single-file <name> Process a specific meeting (e.g., ES2004a)
|
||||
--max-files <n> Maximum number of files to process
|
||||
--threshold <value> Speaker activity threshold (default: 0.5)
|
||||
--preprocessor <path> Path to SortformerPreprocessor.mlpackage
|
||||
--model <path> Path to Sortformer.mlpackage
|
||||
--nvidia-low-latency Use NVIDIA 1.04s latency config (20.57% DER target)
|
||||
--nvidia-high-latency Use NVIDIA 30.4s latency config (20.57% DER target)
|
||||
@@ -102,7 +82,7 @@ enum SortformerBenchmark {
|
||||
if let d = Dataset(rawValue: arguments[i + 1].lowercased()) {
|
||||
dataset = d
|
||||
} else {
|
||||
print("⚠️ Unknown dataset: \(arguments[i + 1]). Using ami.")
|
||||
print("Unknown dataset: \(arguments[i + 1]). Using ami.")
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
@@ -158,9 +138,7 @@ enum SortformerBenchmark {
|
||||
printUsage()
|
||||
return
|
||||
default:
|
||||
if !arguments[i].starts(with: "--") {
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
logger.warning("Unknown argument: \(arguments[i])")
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
@@ -170,7 +148,7 @@ enum SortformerBenchmark {
|
||||
useHuggingFace = true
|
||||
}
|
||||
|
||||
print("🚀 Starting Sortformer Benchmark")
|
||||
print("Starting Sortformer Benchmark")
|
||||
fflush(stdout)
|
||||
print(" Dataset: \(dataset.rawValue)")
|
||||
print(" Threshold: \(threshold)")
|
||||
@@ -209,7 +187,7 @@ enum SortformerBenchmark {
|
||||
|
||||
// Download dataset if needed
|
||||
if autoDownload && dataset == .ami {
|
||||
print("📥 Downloading AMI dataset if needed...")
|
||||
print("Downloading AMI dataset if needed...")
|
||||
await DatasetDownloader.downloadAMIDataset(
|
||||
variant: .sdm,
|
||||
force: false,
|
||||
@@ -223,23 +201,16 @@ enum SortformerBenchmark {
|
||||
if let meeting = singleFile {
|
||||
filesToProcess = [meeting]
|
||||
} else {
|
||||
switch dataset {
|
||||
case .ami:
|
||||
filesToProcess = getAMIFiles(maxFiles: maxFiles)
|
||||
case .voxconverse:
|
||||
filesToProcess = getVoxConverseFiles(maxFiles: maxFiles)
|
||||
case .callhome:
|
||||
filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles)
|
||||
}
|
||||
filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles)
|
||||
}
|
||||
|
||||
if filesToProcess.isEmpty {
|
||||
print("❌ No files found to process")
|
||||
print("No files found to process")
|
||||
fflush(stdout)
|
||||
return
|
||||
}
|
||||
|
||||
print("📂 Processing \(filesToProcess.count) file(s)")
|
||||
print("Processing \(filesToProcess.count) file(s)")
|
||||
print(" Progress file: \(progressFile)")
|
||||
fflush(stdout)
|
||||
|
||||
@@ -247,29 +218,29 @@ enum SortformerBenchmark {
|
||||
var completedResults: [BenchmarkResult] = []
|
||||
var completedMeetings: Set<String> = []
|
||||
if resumeFromProgress {
|
||||
if let loaded = loadProgress(from: progressFile) {
|
||||
if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) {
|
||||
completedResults = loaded
|
||||
completedMeetings = Set(loaded.map { $0.meetingName })
|
||||
print("📥 Resuming: loaded \(completedResults.count) previous results")
|
||||
print("Resuming: loaded \(completedResults.count) previous results")
|
||||
for result in completedResults {
|
||||
print(" ✅ \(result.meetingName): \(String(format: "%.1f", result.der))% DER")
|
||||
print(" \(result.meetingName): \(String(format: "%.1f", result.der))% DER")
|
||||
}
|
||||
} else {
|
||||
print("📥 No previous progress found, starting fresh")
|
||||
print("No previous progress found, starting fresh")
|
||||
}
|
||||
}
|
||||
print("")
|
||||
fflush(stdout)
|
||||
|
||||
// Initialize Sortformer
|
||||
print("🔧 Loading Sortformer models...")
|
||||
print("Loading Sortformer models...")
|
||||
fflush(stdout)
|
||||
let modelLoadStart = Date()
|
||||
var config: SortformerConfig
|
||||
if useNvidiaHighLatency {
|
||||
config = SortformerConfig.nvidiaHighLatencyV2_1
|
||||
config = SortformerConfig.highContextV2_1
|
||||
} else if useNvidiaLowLatency {
|
||||
config = SortformerConfig.nvidiaLowLatencyV2_1
|
||||
config = SortformerConfig.balancedV2_1
|
||||
} else {
|
||||
config = SortformerConfig.default
|
||||
}
|
||||
@@ -287,12 +258,12 @@ enum SortformerBenchmark {
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
print("❌ Failed to initialize Sortformer: \(error)")
|
||||
print("Failed to initialize Sortformer: \(error)")
|
||||
return
|
||||
}
|
||||
|
||||
let modelLoadTime = Date().timeIntervalSince(modelLoadStart)
|
||||
print("✅ Models loaded in \(String(format: "%.2f", modelLoadTime))s\n")
|
||||
print("Models loaded in \(String(format: "%.2f", modelLoadTime))s\n")
|
||||
fflush(stdout)
|
||||
|
||||
// Process each file
|
||||
@@ -324,14 +295,14 @@ enum SortformerBenchmark {
|
||||
allResults.append(result)
|
||||
|
||||
// Print summary
|
||||
print("📊 Results for \(meetingName):")
|
||||
print("Results for \(meetingName):")
|
||||
print(" DER: \(String(format: "%.1f", result.der))%")
|
||||
print(" RTFx: \(String(format: "%.1f", result.rtfx))x")
|
||||
print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth")
|
||||
|
||||
// Save progress after each file
|
||||
saveProgress(results: allResults, to: progressFile)
|
||||
print("💾 Progress saved (\(allResults.count) files complete)")
|
||||
DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile)
|
||||
print("Progress saved (\(allResults.count) files complete)")
|
||||
}
|
||||
fflush(stdout)
|
||||
|
||||
@@ -340,11 +311,15 @@ enum SortformerBenchmark {
|
||||
}
|
||||
|
||||
// Print final summary
|
||||
printFinalSummary(results: allResults)
|
||||
DiarizationBenchmarkUtils.printFinalSummary(
|
||||
results: allResults,
|
||||
title: "SORTFORMER BENCHMARK SUMMARY",
|
||||
derTargets: [15, 20]
|
||||
)
|
||||
|
||||
// Save results
|
||||
if let outputPath = outputFile {
|
||||
saveJSONResults(results: allResults, to: outputPath)
|
||||
DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,9 +332,9 @@ enum SortformerBenchmark {
|
||||
verbose: Bool
|
||||
) async -> BenchmarkResult? {
|
||||
|
||||
let audioPath = getAudioPath(for: meetingName, dataset: dataset)
|
||||
let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset)
|
||||
guard FileManager.default.fileExists(atPath: audioPath) else {
|
||||
print("❌ Audio file not found: \(audioPath)")
|
||||
print("Audio file not found: \(audioPath)")
|
||||
fflush(stdout)
|
||||
return nil
|
||||
}
|
||||
@@ -441,7 +416,7 @@ enum SortformerBenchmark {
|
||||
}
|
||||
|
||||
guard !groundTruth.isEmpty else {
|
||||
print("⚠️ No ground truth found for \(meetingName)")
|
||||
print("No ground truth found for \(meetingName)")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -453,7 +428,7 @@ enum SortformerBenchmark {
|
||||
let simpleMetrics = calculateSimpleDER(
|
||||
predictions: filteredPredictions,
|
||||
numFrames: result.numFinalizedFrames,
|
||||
numSpeakers: 4,
|
||||
numSpeakers: result.config.numSpeakers,
|
||||
groundTruth: groundTruth,
|
||||
threshold: threshold,
|
||||
frameShift: 0.08 // 80ms frames like NeMo
|
||||
@@ -490,267 +465,7 @@ enum SortformerBenchmark {
|
||||
)
|
||||
|
||||
} catch {
|
||||
print("❌ Error processing \(meetingName): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private static func getAMIFiles(maxFiles: Int?) -> [String] {
|
||||
// Official AMI SDM test set (16 meetings) - matches NeMo evaluation
|
||||
let allMeetings = [
|
||||
"EN2002a", "EN2002b", "EN2002c", "EN2002d",
|
||||
"ES2004a", "ES2004b", "ES2004c", "ES2004d",
|
||||
"IS1009a", "IS1009b", "IS1009c", "IS1009d",
|
||||
"TS3003a", "TS3003b", "TS3003c", "TS3003d",
|
||||
]
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for meeting in allMeetings {
|
||||
let path = getAudioPath(for: meeting, dataset: .ami)
|
||||
if FileManager.default.fileExists(atPath: path) {
|
||||
availableMeetings.append(meeting)
|
||||
}
|
||||
}
|
||||
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
private static func getAudioPath(for meeting: String, dataset: Dataset) -> String {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav"
|
||||
).path
|
||||
case .voxconverse:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav"
|
||||
).path
|
||||
case .callhome:
|
||||
return homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/\(meeting).wav"
|
||||
).path
|
||||
}
|
||||
}
|
||||
|
||||
private static func getVoxConverseFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let voxDir = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/voxconverse_test_wav"
|
||||
)
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: voxDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
// Check that RTTM file exists
|
||||
let rttmPath = homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm"
|
||||
)
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort alphabetically for reproducibility
|
||||
availableMeetings.sort()
|
||||
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] {
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng")
|
||||
|
||||
guard
|
||||
let files = try? FileManager.default.contentsOfDirectory(
|
||||
at: callhomeDir,
|
||||
includingPropertiesForKeys: nil
|
||||
)
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
var availableMeetings: [String] = []
|
||||
for file in files where file.pathExtension == "wav" {
|
||||
let name = file.deletingPathExtension().lastPathComponent
|
||||
// Check that RTTM file exists
|
||||
let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm")
|
||||
if FileManager.default.fileExists(atPath: rttmPath.path) {
|
||||
availableMeetings.append(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort alphabetically for reproducibility
|
||||
availableMeetings.sort()
|
||||
|
||||
if let max = maxFiles {
|
||||
return Array(availableMeetings.prefix(max))
|
||||
}
|
||||
|
||||
return availableMeetings
|
||||
}
|
||||
|
||||
private static func printFinalSummary(results: [BenchmarkResult]) {
|
||||
guard !results.isEmpty else { return }
|
||||
|
||||
print("\n" + String(repeating: "=", count: 80))
|
||||
print("SORTFORMER BENCHMARK SUMMARY")
|
||||
print(String(repeating: "=", count: 80))
|
||||
|
||||
print("📋 Results Sorted by DER:")
|
||||
print(String(repeating: "-", count: 70))
|
||||
print("Meeting DER % Miss % FA % SE % Speakers RTFx")
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
for result in results.sorted(by: { $0.der < $1.der }) {
|
||||
let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)"
|
||||
let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0)
|
||||
let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0)
|
||||
print(
|
||||
String(
|
||||
format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f",
|
||||
meetingCol,
|
||||
result.der,
|
||||
result.missRate,
|
||||
result.falseAlarmRate,
|
||||
result.speakerErrorRate,
|
||||
speakerCol,
|
||||
result.rtfx))
|
||||
}
|
||||
print(String(repeating: "-", count: 70))
|
||||
|
||||
let count = Float(results.count)
|
||||
let avgDER = results.map { $0.der }.reduce(0, +) / count
|
||||
let avgMiss = results.map { $0.missRate }.reduce(0, +) / count
|
||||
let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count
|
||||
let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count
|
||||
let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count
|
||||
|
||||
print(
|
||||
String(
|
||||
format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f",
|
||||
avgDER, avgMiss, avgFA, avgSE, avgRTFx))
|
||||
print(String(repeating: "=", count: 70))
|
||||
|
||||
print("\n✅ Target Check:")
|
||||
if avgDER < 15 {
|
||||
print(" ✅ DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
} else if avgDER < 20 {
|
||||
print(" 🟡 DER < 20% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
} else {
|
||||
print(" ❌ DER > 20% (achieved: \(String(format: "%.1f", avgDER))%)")
|
||||
}
|
||||
|
||||
if avgRTFx > 1 {
|
||||
print(" ✅ RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
} else {
|
||||
print(" ❌ RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)")
|
||||
}
|
||||
}
|
||||
|
||||
private static func saveJSONResults(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { result in
|
||||
resultToDict(result)
|
||||
}
|
||||
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
print("💾 JSON results saved to: \(path)")
|
||||
} catch {
|
||||
print("❌ Failed to save JSON: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Progress Save/Load
|
||||
|
||||
private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] {
|
||||
return [
|
||||
"meeting": result.meetingName,
|
||||
"der": result.der,
|
||||
"missRate": result.missRate,
|
||||
"falseAlarmRate": result.falseAlarmRate,
|
||||
"speakerErrorRate": result.speakerErrorRate,
|
||||
"rtfx": result.rtfx,
|
||||
"processingTime": result.processingTime,
|
||||
"totalFrames": result.totalFrames,
|
||||
"detectedSpeakers": result.detectedSpeakers,
|
||||
"groundTruthSpeakers": result.groundTruthSpeakers,
|
||||
"modelLoadTime": result.modelLoadTime,
|
||||
"audioLoadTime": result.audioLoadTime,
|
||||
]
|
||||
}
|
||||
|
||||
private static func saveProgress(results: [BenchmarkResult], to path: String) {
|
||||
let jsonData = results.map { resultToDict($0) }
|
||||
do {
|
||||
let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted)
|
||||
try data.write(to: URL(fileURLWithPath: path))
|
||||
} catch {
|
||||
print("⚠️ Failed to save progress: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
private static func loadProgress(from path: String) -> [BenchmarkResult]? {
|
||||
guard FileManager.default.fileExists(atPath: path) else { return nil }
|
||||
|
||||
do {
|
||||
let data = try Data(contentsOf: URL(fileURLWithPath: path))
|
||||
guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonArray.compactMap { dict -> BenchmarkResult? in
|
||||
guard let meeting = dict["meeting"] as? String,
|
||||
let der = (dict["der"] as? NSNumber)?.floatValue,
|
||||
let missRate = (dict["missRate"] as? NSNumber)?.floatValue,
|
||||
let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue,
|
||||
let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue,
|
||||
let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue,
|
||||
let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue,
|
||||
let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue,
|
||||
let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue,
|
||||
let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue,
|
||||
let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue,
|
||||
let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return BenchmarkResult(
|
||||
meetingName: meeting,
|
||||
der: der,
|
||||
missRate: missRate,
|
||||
falseAlarmRate: falseAlarmRate,
|
||||
speakerErrorRate: speakerErrorRate,
|
||||
rtfx: rtfx,
|
||||
processingTime: processingTime,
|
||||
totalFrames: totalFrames,
|
||||
detectedSpeakers: detectedSpeakers,
|
||||
groundTruthSpeakers: groundTruthSpeakers,
|
||||
modelLoadTime: modelLoadTime,
|
||||
audioLoadTime: audioLoadTime
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
print("⚠️ Failed to load progress: \(error)")
|
||||
print("Error processing \(meetingName): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -760,23 +475,11 @@ enum SortformerBenchmark {
|
||||
/// Load ground truth from RTTM file like Python does
|
||||
/// Format: SPEAKER <meeting_id> 1 <start_time> <duration> <NA> <NA> <speaker_id> <NA> <NA>
|
||||
private static func loadRTTMGroundTruth(for meetingName: String, dataset: Dataset) -> [TimedSpeakerSegment] {
|
||||
// Determine RTTM path based on dataset
|
||||
let rttmPath: String
|
||||
let homeDir = FileManager.default.homeDirectoryForCurrentUser
|
||||
switch dataset {
|
||||
case .ami:
|
||||
rttmPath = "Streaming-Sortformer-Conversion/\(meetingName).rttm"
|
||||
case .voxconverse:
|
||||
rttmPath =
|
||||
homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/voxconverse/rttm_repo/test/\(meetingName).rttm"
|
||||
).path
|
||||
case .callhome:
|
||||
rttmPath =
|
||||
homeDir.appendingPathComponent(
|
||||
"FluidAudioDatasets/callhome_eng/rttm/\(meetingName).rttm"
|
||||
).path
|
||||
guard let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset) else {
|
||||
print(" [RTTM] No RTTM URL for \(meetingName)")
|
||||
return []
|
||||
}
|
||||
let rttmPath = rttmURL.path
|
||||
|
||||
guard FileManager.default.fileExists(atPath: rttmPath) else {
|
||||
print(" [RTTM] File not found: \(rttmPath)")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#if os(macOS)
|
||||
import AVFoundation
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
@@ -131,9 +130,10 @@ enum SortformerCommand {
|
||||
print(
|
||||
"[DEBUG] First 10 audio samples: \((0..<min(10, audioSamples.count)).map { String(format: "%.6f", audioSamples[$0]) }.joined(separator: ", "))"
|
||||
)
|
||||
let debugPath = NSTemporaryDirectory() + "swift_audio_16k.bin"
|
||||
let audioData = audioSamples.withUnsafeBytes { Data($0) }
|
||||
try? audioData.write(to: URL(fileURLWithPath: "swift_audio_16k.bin"))
|
||||
print("[DEBUG] Saved \(audioSamples.count) samples to swift_audio_16k.bin")
|
||||
try? audioData.write(to: URL(fileURLWithPath: debugPath))
|
||||
print("[DEBUG] Saved \(audioSamples.count) samples to \(debugPath)")
|
||||
}
|
||||
|
||||
// Process with progress
|
||||
@@ -178,12 +178,13 @@ enum SortformerCommand {
|
||||
|
||||
// Print speaker probabilities summary
|
||||
print("\n--- Speaker Activity Summary ---")
|
||||
let numSpeakers = 4
|
||||
let numSpeakers = result.config.numSpeakers
|
||||
var speakerActivity = [Float](repeating: 0, count: numSpeakers)
|
||||
let predictions = result.finalizedPredictions
|
||||
for frame in 0..<result.numFinalizedFrames {
|
||||
for spk in 0..<numSpeakers {
|
||||
let prob = result.finalizedPredictions[frame * numSpeakers + spk]
|
||||
if prob > 0.5 {
|
||||
let idx = frame * numSpeakers + spk
|
||||
if idx < predictions.count, predictions[idx] > 0.5 {
|
||||
speakerActivity[spk] += result.config.frameDurationSeconds
|
||||
}
|
||||
}
|
||||
@@ -233,7 +234,7 @@ enum SortformerCommand {
|
||||
}
|
||||
|
||||
private static func printUsage() {
|
||||
logger.info(
|
||||
print(
|
||||
"""
|
||||
|
||||
Sortformer Command Usage:
|
||||
@@ -259,8 +260,7 @@ enum SortformerCommand {
|
||||
|
||||
# Save results to file
|
||||
fluidaudio sortformer audio.wav --output results.json
|
||||
"""
|
||||
)
|
||||
""")
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
+3
-3
@@ -3,13 +3,13 @@ import XCTest
|
||||
|
||||
@testable import FluidAudio
|
||||
|
||||
final class NeMoMelSpectrogramTests: XCTestCase {
|
||||
final class AudioMelSpectrogramTests: XCTestCase {
|
||||
|
||||
private var mel: NeMoMelSpectrogram!
|
||||
private var mel: AudioMelSpectrogram!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
mel = NeMoMelSpectrogram()
|
||||
mel = AudioMelSpectrogram()
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
@@ -0,0 +1,136 @@
|
||||
import AVFoundation
|
||||
import Foundation
|
||||
import XCTest
|
||||
|
||||
@testable import FluidAudio
|
||||
|
||||
/// Shared fixture infrastructure for diarization tests.
|
||||
///
|
||||
/// Generates a deterministic multi-segment waveform with silence gaps, writes it to a
|
||||
/// temporary WAV file, and caches it for reuse across tests within the same process.
|
||||
enum DiarizationTestFixtures {
|
||||
static let fixtureSampleRate = 16_000
|
||||
|
||||
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
|
||||
|
||||
/// Returns a cached URL to the fixture WAV file, creating it on first access.
|
||||
static func fixtureAudioFileURL() throws -> URL {
|
||||
if let cached = cachedFixtureAudioURL,
|
||||
FileManager.default.fileExists(atPath: cached.path)
|
||||
{
|
||||
return cached
|
||||
}
|
||||
|
||||
let url = FileManager.default.temporaryDirectory
|
||||
.appendingPathComponent("diarization-fixture-\(UUID().uuidString)")
|
||||
.appendingPathExtension("wav")
|
||||
try writeFixtureAudio(to: url)
|
||||
cachedFixtureAudioURL = url
|
||||
return url
|
||||
}
|
||||
|
||||
/// Loads fixture audio resampled to the given sample rate, optionally limited to a duration.
|
||||
static func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] {
|
||||
let converter = AudioConverter(sampleRate: Double(sampleRate))
|
||||
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
|
||||
guard let limitSeconds else {
|
||||
return audio
|
||||
}
|
||||
let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate)))
|
||||
return Array(audio.prefix(sampleCount))
|
||||
}
|
||||
|
||||
/// Loads a slice of fixture audio at the given sample rate.
|
||||
static func fixtureAudio(
|
||||
sampleRate: Int, startSeconds: Double, durationSeconds: Double
|
||||
) throws -> [Float] {
|
||||
let converter = AudioConverter(sampleRate: Double(sampleRate))
|
||||
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
|
||||
let startSample = min(audio.count, Int(startSeconds * Double(sampleRate)))
|
||||
let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate)))
|
||||
return Array(audio[startSample..<endSample])
|
||||
}
|
||||
|
||||
/// Splits samples into chunks with rotating sizes.
|
||||
static func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
|
||||
var chunks: [[Float]] = []
|
||||
var start = 0
|
||||
var index = 0
|
||||
while start < samples.count {
|
||||
let size = sizes[index % sizes.count]
|
||||
let stop = min(samples.count, start + size)
|
||||
chunks.append(Array(samples[start..<stop]))
|
||||
start = stop
|
||||
index += 1
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private static func writeFixtureAudio(to url: URL) throws {
|
||||
let sampleRate = Double(fixtureSampleRate)
|
||||
let samples = makeFixtureSamples(sampleRate: sampleRate)
|
||||
let format = AVAudioFormat(
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
sampleRate: sampleRate,
|
||||
channels: 1,
|
||||
interleaved: false
|
||||
)!
|
||||
guard
|
||||
let buffer = AVAudioPCMBuffer(
|
||||
pcmFormat: format,
|
||||
frameCapacity: AVAudioFrameCount(samples.count)
|
||||
)
|
||||
else {
|
||||
XCTFail("Failed to allocate fixture audio buffer")
|
||||
return
|
||||
}
|
||||
|
||||
buffer.frameLength = AVAudioFrameCount(samples.count)
|
||||
samples.withUnsafeBufferPointer { source in
|
||||
guard let destination = buffer.floatChannelData?[0] else { return }
|
||||
destination.update(from: source.baseAddress!, count: samples.count)
|
||||
}
|
||||
|
||||
let file = try AVAudioFile(
|
||||
forWriting: url,
|
||||
settings: format.settings,
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
interleaved: false
|
||||
)
|
||||
try file.write(from: buffer)
|
||||
}
|
||||
|
||||
private static func makeFixtureSamples(sampleRate: Double) -> [Float] {
|
||||
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
|
||||
(1.0, 0.20, 220),
|
||||
(0.35, 0.00, 0),
|
||||
(1.1, 0.32, 330),
|
||||
(0.25, 0.00, 0),
|
||||
(1.0, 0.28, 180),
|
||||
(0.40, 0.00, 0),
|
||||
(1.3, 0.36, 260),
|
||||
(0.30, 0.00, 0),
|
||||
(1.1, 0.24, 410),
|
||||
]
|
||||
|
||||
var output: [Float] = []
|
||||
for (duration, amplitude, frequency) in segments {
|
||||
let frameCount = Int(duration * sampleRate)
|
||||
guard amplitude > 0, frequency > 0 else {
|
||||
output.append(contentsOf: repeatElement(0, count: frameCount))
|
||||
continue
|
||||
}
|
||||
|
||||
for frame in 0..<frameCount {
|
||||
let time = Double(frame) / sampleRate
|
||||
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
|
||||
let carrier = sin(2.0 * Double.pi * frequency * time)
|
||||
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
|
||||
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
import AVFoundation
|
||||
import CoreML
|
||||
import Foundation
|
||||
import XCTest
|
||||
@@ -11,8 +10,6 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
let meanAbs: Double
|
||||
}
|
||||
|
||||
private static let fixtureSampleRate = 16_000
|
||||
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
|
||||
nonisolated(unsafe) private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:]
|
||||
|
||||
func testVariantRegistryResolvesAllExportedArtifacts() async throws {
|
||||
@@ -39,7 +36,8 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
func testOfflineInferenceProducesConsistentShapesAcrossVariants() async throws {
|
||||
for variant in LSEENDVariant.allCases {
|
||||
let engine = try await makeEngine(variant: variant)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 2.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, limitSeconds: 2.0)
|
||||
let result = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
|
||||
|
||||
try assertResultInvariants(
|
||||
@@ -53,8 +51,8 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testAudioFileInferenceMatchesInferenceOnResampledFixtureSamples() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let fileResult = try engine.infer(audioFileURL: try fixtureAudioFileURL())
|
||||
let resampled = try fixtureAudio(sampleRate: engine.targetSampleRate)
|
||||
let fileResult = try engine.infer(audioFileURL: try DiarizationTestFixtures.fixtureAudioFileURL())
|
||||
let resampled = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate)
|
||||
let sampleResult = try engine.infer(samples: resampled, sampleRate: engine.targetSampleRate)
|
||||
|
||||
assertMatrixClose(fileResult.logits, sampleResult.logits, maxAbs: 1e-6, meanAbs: 1e-7)
|
||||
@@ -65,7 +63,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testStreamingSessionMatchesOfflineInferenceOnRealFixtureAudio() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let offline = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
|
||||
let session = try engine.createSession(inputSampleRate: engine.targetSampleRate)
|
||||
|
||||
@@ -115,7 +113,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testStreamingSimulationMatchesOfflineInferenceAndReportsMonotonicProgress() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let fixtureURL = try fixtureAudioFileURL()
|
||||
let fixtureURL = try DiarizationTestFixtures.fixtureAudioFileURL()
|
||||
let offline = try engine.infer(audioFileURL: fixtureURL)
|
||||
let simulation = try engine.simulateStreaming(audioFileURL: fixtureURL, chunkSeconds: 0.37)
|
||||
|
||||
@@ -142,7 +140,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testDiarizerProcessCompleteMatchesEngineInference() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let expected = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
@@ -159,13 +157,13 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testDiarizerStreamingFinalizeMatchesProcessComplete() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 4.0)
|
||||
let expected = try engine.infer(samples: samples, sampleRate: engine.targetSampleRate)
|
||||
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
|
||||
for chunk in chunk(samples, sizes: [701, 977, 1153]) {
|
||||
for chunk in DiarizationTestFixtures.chunk(samples, sizes: [701, 977, 1153]) {
|
||||
let _ = try diarizer.process(samples: chunk)
|
||||
}
|
||||
let _ = try diarizer.finalizeSession()
|
||||
@@ -195,7 +193,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testEnrollSpeakerResetsVisibleTimelineAndAllowsStreaming() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
|
||||
let enrollmentCount = min(samples.count / 2, engine.targetSampleRate * 2)
|
||||
let enrollment = Array(samples.prefix(enrollmentCount))
|
||||
let live = Array(samples.dropFirst(enrollmentCount))
|
||||
@@ -212,7 +210,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
|
||||
|
||||
var firstUpdate: DiarizerTimelineUpdate?
|
||||
for chunk in chunk(live, sizes: [977, 1231, 1607]) {
|
||||
for chunk in DiarizationTestFixtures.chunk(live, sizes: [977, 1231, 1607]) {
|
||||
if let update = try diarizer.process(samples: chunk) {
|
||||
firstUpdate = update
|
||||
break
|
||||
@@ -231,7 +229,7 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
|
||||
func testProcessCompleteKeepsPrimedSessionOnlyWhenRequested() async throws {
|
||||
let engine = try await makeEngine(variant: .dihard3)
|
||||
let samples = try fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
|
||||
let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: engine.targetSampleRate, limitSeconds: 6.0)
|
||||
let enrollmentSampleCount = engine.targetSampleRate * 2
|
||||
let enrollment = Array(samples.prefix(enrollmentSampleCount))
|
||||
let complete = Array(samples.dropFirst(enrollmentSampleCount).prefix(enrollmentSampleCount))
|
||||
@@ -240,13 +238,13 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
diarizer.initialize(engine: engine)
|
||||
|
||||
_ = try diarizer.enrollSpeaker(withSamples: enrollment, named: "Alice")
|
||||
XCTAssertTrue(hasActiveSession(diarizer))
|
||||
XCTAssertTrue(diarizer.hasActiveSession)
|
||||
|
||||
_ = try diarizer.processComplete(complete, keepingEnrolledSpeakers: true)
|
||||
XCTAssertFalse(hasActiveSession(diarizer))
|
||||
XCTAssertFalse(diarizer.hasActiveSession)
|
||||
|
||||
_ = try diarizer.processComplete(complete, keepingEnrolledSpeakers: false)
|
||||
XCTAssertFalse(hasActiveSession(diarizer))
|
||||
XCTAssertFalse(diarizer.hasActiveSession)
|
||||
}
|
||||
|
||||
private func makeEngine(variant: LSEENDVariant) async throws -> LSEENDInferenceHelper {
|
||||
@@ -259,115 +257,10 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
return engine
|
||||
}
|
||||
|
||||
private func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] {
|
||||
let converter = AudioConverter(sampleRate: Double(sampleRate))
|
||||
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
|
||||
guard let limitSeconds else {
|
||||
return audio
|
||||
}
|
||||
let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate)))
|
||||
return Array(audio.prefix(sampleCount))
|
||||
}
|
||||
|
||||
private func fixtureAudioFileURL() throws -> URL {
|
||||
if let cached = Self.cachedFixtureAudioURL,
|
||||
FileManager.default.fileExists(atPath: cached.path)
|
||||
{
|
||||
return cached
|
||||
}
|
||||
|
||||
let url = FileManager.default.temporaryDirectory
|
||||
.appendingPathComponent("lseend-fixture-\(UUID().uuidString)")
|
||||
.appendingPathExtension("wav")
|
||||
try writeFixtureAudio(to: url)
|
||||
Self.cachedFixtureAudioURL = url
|
||||
return url
|
||||
}
|
||||
|
||||
private func writeFixtureAudio(to url: URL) throws {
|
||||
let sampleRate = Double(Self.fixtureSampleRate)
|
||||
let samples = makeFixtureSamples(sampleRate: sampleRate)
|
||||
let format = AVAudioFormat(
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
sampleRate: sampleRate,
|
||||
channels: 1,
|
||||
interleaved: false
|
||||
)!
|
||||
guard
|
||||
let buffer = AVAudioPCMBuffer(
|
||||
pcmFormat: format,
|
||||
frameCapacity: AVAudioFrameCount(samples.count)
|
||||
)
|
||||
else {
|
||||
XCTFail("Failed to allocate fixture audio buffer")
|
||||
return
|
||||
}
|
||||
|
||||
buffer.frameLength = AVAudioFrameCount(samples.count)
|
||||
samples.withUnsafeBufferPointer { source in
|
||||
guard let destination = buffer.floatChannelData?[0] else { return }
|
||||
destination.update(from: source.baseAddress!, count: samples.count)
|
||||
}
|
||||
|
||||
let file = try AVAudioFile(
|
||||
forWriting: url,
|
||||
settings: format.settings,
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
interleaved: false
|
||||
)
|
||||
try file.write(from: buffer)
|
||||
}
|
||||
|
||||
private func makeFixtureSamples(sampleRate: Double) -> [Float] {
|
||||
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
|
||||
(1.0, 0.20, 220),
|
||||
(0.35, 0.00, 0),
|
||||
(1.1, 0.32, 330),
|
||||
(0.25, 0.00, 0),
|
||||
(1.0, 0.28, 180),
|
||||
(0.40, 0.00, 0),
|
||||
(1.3, 0.36, 260),
|
||||
(0.30, 0.00, 0),
|
||||
(1.1, 0.24, 410),
|
||||
]
|
||||
|
||||
var output: [Float] = []
|
||||
for (duration, amplitude, frequency) in segments {
|
||||
let frameCount = Int(duration * sampleRate)
|
||||
guard amplitude > 0, frequency > 0 else {
|
||||
output.append(contentsOf: repeatElement(0, count: frameCount))
|
||||
continue
|
||||
}
|
||||
|
||||
for frame in 0..<frameCount {
|
||||
let time = Double(frame) / sampleRate
|
||||
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
|
||||
let carrier = sin(2.0 * Double.pi * frequency * time)
|
||||
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
|
||||
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
private func duration(of samples: [Float], sampleRate: Int) -> Double {
|
||||
Double(samples.count) / Double(sampleRate)
|
||||
}
|
||||
|
||||
private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
|
||||
var chunks: [[Float]] = []
|
||||
var start = 0
|
||||
var index = 0
|
||||
while start < samples.count {
|
||||
let size = sizes[index % sizes.count]
|
||||
let stop = min(samples.count, start + size)
|
||||
chunks.append(Array(samples[start..<stop]))
|
||||
start = stop
|
||||
index += 1
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
private func assertResultInvariants(
|
||||
_ result: LSEENDInferenceResult,
|
||||
engine: LSEENDInferenceHelper,
|
||||
@@ -433,15 +326,4 @@ final class LSEENDIntegrationTests: XCTestCase {
|
||||
)
|
||||
}
|
||||
|
||||
private func hasActiveSession(_ diarizer: LSEENDDiarizer) -> Bool {
|
||||
let mirror = Mirror(reflecting: diarizer)
|
||||
guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else {
|
||||
XCTFail("Expected LS-EEND diarizer to expose _session via reflection")
|
||||
return false
|
||||
}
|
||||
|
||||
let optionalMirror = Mirror(reflecting: sessionValue)
|
||||
XCTAssertEqual(optionalMirror.displayStyle, .optional)
|
||||
return optionalMirror.children.count == 1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
import XCTest
|
||||
|
||||
@testable import FluidAudio
|
||||
|
||||
final class LSEENDMatrixTests: XCTestCase {
|
||||
|
||||
// MARK: - Init (validated)
|
||||
|
||||
func testInitWithMatchingDimensions() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
|
||||
XCTAssertEqual(m.rows, 2)
|
||||
XCTAssertEqual(m.columns, 3)
|
||||
XCTAssertEqual(m.values, [1, 2, 3, 4, 5, 6])
|
||||
}
|
||||
|
||||
func testInitThrowsOnCountMismatch() {
|
||||
XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3])) { error in
|
||||
guard case LSEENDError.invalidMatrixShape = error else {
|
||||
return XCTFail("Expected invalidMatrixShape, got \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testInitThrowsOnNegativeRows() {
|
||||
XCTAssertThrowsError(try LSEENDMatrix(rows: -1, columns: 3, values: [])) { error in
|
||||
guard case LSEENDError.invalidMatrixShape = error else {
|
||||
return XCTFail("Expected invalidMatrixShape, got \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testInitThrowsOnNegativeColumns() {
|
||||
XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: -1, values: [])) { error in
|
||||
guard case LSEENDError.invalidMatrixShape = error else {
|
||||
return XCTFail("Expected invalidMatrixShape, got \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testInitWithZeroDimensions() throws {
|
||||
let m = try LSEENDMatrix(rows: 0, columns: 5, values: [])
|
||||
XCTAssertEqual(m.rows, 0)
|
||||
XCTAssertEqual(m.columns, 5)
|
||||
XCTAssertTrue(m.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - Factory Methods
|
||||
|
||||
func testZeros() {
|
||||
let m = LSEENDMatrix.zeros(rows: 3, columns: 2)
|
||||
XCTAssertEqual(m.rows, 3)
|
||||
XCTAssertEqual(m.columns, 2)
|
||||
XCTAssertEqual(m.values, [Float](repeating: 0, count: 6))
|
||||
}
|
||||
|
||||
func testEmpty() {
|
||||
let m = LSEENDMatrix.empty(columns: 4)
|
||||
XCTAssertEqual(m.rows, 0)
|
||||
XCTAssertEqual(m.columns, 4)
|
||||
XCTAssertTrue(m.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - isEmpty
|
||||
|
||||
func testIsEmptyZeroRows() {
|
||||
XCTAssertTrue(LSEENDMatrix.empty(columns: 3).isEmpty)
|
||||
}
|
||||
|
||||
func testIsEmptyZeroColumns() {
|
||||
let m = LSEENDMatrix(validatingRows: 3, columns: 0, values: [])
|
||||
XCTAssertTrue(m.isEmpty)
|
||||
}
|
||||
|
||||
func testIsEmptyFalseForPopulatedMatrix() throws {
|
||||
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [42])
|
||||
XCTAssertFalse(m.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - Subscript
|
||||
|
||||
func testSubscriptGet() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [10, 20, 30, 40, 50, 60])
|
||||
XCTAssertEqual(m[0, 0], 10)
|
||||
XCTAssertEqual(m[0, 2], 30)
|
||||
XCTAssertEqual(m[1, 0], 40)
|
||||
XCTAssertEqual(m[1, 2], 60)
|
||||
}
|
||||
|
||||
func testSubscriptSet() throws {
|
||||
var m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
m[1, 0] = 99
|
||||
XCTAssertEqual(m[1, 0], 99)
|
||||
XCTAssertEqual(m.values, [1, 2, 99, 4])
|
||||
}
|
||||
|
||||
// MARK: - row()
|
||||
|
||||
func testRow() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
XCTAssertEqual(Array(m.row(0)), [1, 2])
|
||||
XCTAssertEqual(Array(m.row(1)), [3, 4])
|
||||
XCTAssertEqual(Array(m.row(2)), [5, 6])
|
||||
}
|
||||
|
||||
// MARK: - prefixingColumns
|
||||
|
||||
func testPrefixingColumns() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 4, values: [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
let prefix = m.prefixingColumns(2)
|
||||
XCTAssertEqual(prefix.rows, 2)
|
||||
XCTAssertEqual(prefix.columns, 2)
|
||||
XCTAssertEqual(prefix.values, [1, 2, 5, 6])
|
||||
}
|
||||
|
||||
func testPrefixingColumnsEqualToWidth() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
|
||||
let same = m.prefixingColumns(3)
|
||||
XCTAssertEqual(same, m)
|
||||
}
|
||||
|
||||
func testPrefixingColumnsGreaterThanWidth() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
|
||||
let same = m.prefixingColumns(10)
|
||||
XCTAssertEqual(same, m)
|
||||
}
|
||||
|
||||
func testPrefixingColumnsZero() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
|
||||
let empty = m.prefixingColumns(0)
|
||||
XCTAssertTrue(empty.isEmpty)
|
||||
}
|
||||
|
||||
func testPrefixingColumnsOnEmptyMatrix() {
|
||||
let m = LSEENDMatrix.empty(columns: 4)
|
||||
let result = m.prefixingColumns(2)
|
||||
XCTAssertTrue(result.isEmpty)
|
||||
XCTAssertEqual(result.columns, 2)
|
||||
}
|
||||
|
||||
// MARK: - rowMajorRows
|
||||
|
||||
func testRowMajorRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6])
|
||||
let rows = m.rowMajorRows()
|
||||
XCTAssertEqual(rows, [[1, 2, 3], [4, 5, 6]])
|
||||
}
|
||||
|
||||
func testRowMajorRowsEmpty() {
|
||||
let m = LSEENDMatrix.empty(columns: 3)
|
||||
XCTAssertEqual(m.rowMajorRows(), [])
|
||||
}
|
||||
|
||||
// MARK: - appendingRows
|
||||
|
||||
func testAppendingRows() throws {
|
||||
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let b = try LSEENDMatrix(rows: 1, columns: 2, values: [5, 6])
|
||||
let result = a.appendingRows(b)
|
||||
XCTAssertEqual(result.rows, 3)
|
||||
XCTAssertEqual(result.columns, 2)
|
||||
XCTAssertEqual(result.values, [1, 2, 3, 4, 5, 6])
|
||||
}
|
||||
|
||||
func testAppendingRowsToEmpty() throws {
|
||||
let a = LSEENDMatrix.empty(columns: 2)
|
||||
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
XCTAssertEqual(a.appendingRows(b), b)
|
||||
}
|
||||
|
||||
func testAppendingEmptyRows() throws {
|
||||
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let b = LSEENDMatrix.empty(columns: 2)
|
||||
XCTAssertEqual(a.appendingRows(b), a)
|
||||
}
|
||||
|
||||
// MARK: - droppingFirstRows
|
||||
|
||||
func testDroppingFirstRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
let dropped = m.droppingFirstRows(2)
|
||||
XCTAssertEqual(dropped.rows, 2)
|
||||
XCTAssertEqual(dropped.columns, 2)
|
||||
XCTAssertEqual(dropped.values, [5, 6, 7, 8])
|
||||
}
|
||||
|
||||
func testDroppingAllRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let dropped = m.droppingFirstRows(3)
|
||||
XCTAssertEqual(dropped.rows, 0)
|
||||
XCTAssertTrue(dropped.isEmpty)
|
||||
}
|
||||
|
||||
func testDroppingMoreThanTotalRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let dropped = m.droppingFirstRows(100)
|
||||
XCTAssertEqual(dropped.rows, 0)
|
||||
XCTAssertTrue(dropped.isEmpty)
|
||||
}
|
||||
|
||||
func testDroppingZeroRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let same = m.droppingFirstRows(0)
|
||||
XCTAssertEqual(same, m)
|
||||
}
|
||||
|
||||
func testDroppingNegativeCount() throws {
|
||||
let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let same = m.droppingFirstRows(-5)
|
||||
XCTAssertEqual(same, m)
|
||||
}
|
||||
|
||||
// MARK: - slicingRows
|
||||
|
||||
func testSlicingRows() throws {
|
||||
let m = try LSEENDMatrix(rows: 5, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
let slice = m.slicingRows(start: 1, end: 4)
|
||||
XCTAssertEqual(slice.rows, 3)
|
||||
XCTAssertEqual(slice.values, [3, 4, 5, 6, 7, 8])
|
||||
}
|
||||
|
||||
func testSlicingRowsFullRange() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let slice = m.slicingRows(start: 0, end: 3)
|
||||
XCTAssertEqual(slice, m)
|
||||
}
|
||||
|
||||
func testSlicingRowsEmptyRange() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let slice = m.slicingRows(start: 2, end: 2)
|
||||
XCTAssertTrue(slice.isEmpty)
|
||||
XCTAssertEqual(slice.columns, 2)
|
||||
}
|
||||
|
||||
func testSlicingRowsClampsOutOfBounds() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let slice = m.slicingRows(start: -5, end: 100)
|
||||
XCTAssertEqual(slice, m)
|
||||
}
|
||||
|
||||
func testSlicingRowsInvertedRange() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let slice = m.slicingRows(start: 3, end: 1)
|
||||
XCTAssertTrue(slice.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - applyingSigmoid
|
||||
|
||||
func testSigmoidZero() throws {
|
||||
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [0])
|
||||
let s = m.applyingSigmoid()
|
||||
XCTAssertEqual(s[0, 0], 0.5, accuracy: 1e-6)
|
||||
}
|
||||
|
||||
func testSigmoidLargePositive() throws {
|
||||
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [20])
|
||||
let s = m.applyingSigmoid()
|
||||
XCTAssertEqual(s[0, 0], 1.0, accuracy: 1e-5)
|
||||
}
|
||||
|
||||
func testSigmoidLargeNegative() throws {
|
||||
let m = try LSEENDMatrix(rows: 1, columns: 1, values: [-20])
|
||||
let s = m.applyingSigmoid()
|
||||
XCTAssertEqual(s[0, 0], 0.0, accuracy: 1e-5)
|
||||
}
|
||||
|
||||
func testSigmoidPreservesShape() throws {
|
||||
let m = try LSEENDMatrix(rows: 3, columns: 4, values: [Float](repeating: 0, count: 12))
|
||||
let s = m.applyingSigmoid()
|
||||
XCTAssertEqual(s.rows, 3)
|
||||
XCTAssertEqual(s.columns, 4)
|
||||
XCTAssertEqual(s.values.count, 12)
|
||||
}
|
||||
|
||||
func testSigmoidDoesNotMutateOriginal() throws {
|
||||
let m = try LSEENDMatrix(rows: 1, columns: 2, values: [0, 0])
|
||||
_ = m.applyingSigmoid()
|
||||
XCTAssertEqual(m.values, [0, 0])
|
||||
}
|
||||
|
||||
func testSigmoidOnEmpty() {
|
||||
let m = LSEENDMatrix.empty(columns: 3)
|
||||
let s = m.applyingSigmoid()
|
||||
XCTAssertTrue(s.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - Equatable
|
||||
|
||||
func testEqualMatrices() throws {
|
||||
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
XCTAssertEqual(a, b)
|
||||
}
|
||||
|
||||
func testUnequalValues() throws {
|
||||
let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4])
|
||||
let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 5])
|
||||
XCTAssertNotEqual(a, b)
|
||||
}
|
||||
|
||||
// MARK: - Roundtrip: append then drop
|
||||
|
||||
func testAppendThenDropRecoversOriginal() throws {
|
||||
let original = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6])
|
||||
let extra = try LSEENDMatrix(rows: 2, columns: 2, values: [7, 8, 9, 10])
|
||||
let combined = original.appendingRows(extra)
|
||||
let recovered = combined.slicingRows(start: 0, end: 3)
|
||||
XCTAssertEqual(recovered, original)
|
||||
}
|
||||
|
||||
func testSliceThenAppendRecombines() throws {
|
||||
let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
let head = m.slicingRows(start: 0, end: 2)
|
||||
let tail = m.slicingRows(start: 2, end: 4)
|
||||
let recombined = head.appendingRows(tail)
|
||||
XCTAssertEqual(recombined, m)
|
||||
}
|
||||
|
||||
func testDropThenPrefixColumnsCommutes() throws {
|
||||
let m = try LSEENDMatrix(rows: 4, columns: 4, values: (0..<16).map { Float($0) })
|
||||
|
||||
let dropFirst = m.droppingFirstRows(2).prefixingColumns(2)
|
||||
let prefixFirst = m.prefixingColumns(2).droppingFirstRows(2)
|
||||
|
||||
XCTAssertEqual(dropFirst, prefixFirst)
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ final class SortformerTests: XCTestCase {
|
||||
// Create 5 seconds of deterministic random audio
|
||||
let sampleRate = 16000
|
||||
let audioCount = sampleRate * 5
|
||||
srand48(Int(Date().timeIntervalSince1970 * 1e6))
|
||||
srand48(42)
|
||||
let audio = (0..<audioCount).map { _ in Float(drand48() - 0.5) }
|
||||
|
||||
// 1. Get chunks from Batch Feature Provider
|
||||
@@ -36,7 +36,7 @@ final class SortformerTests: XCTestCase {
|
||||
// 3. Compare
|
||||
XCTAssertEqual(batchChunks.count, streamingChunks.count, "Chunk count mismatch")
|
||||
|
||||
for i in 0..<min(batchChunks.count, streamingChunks.count) {
|
||||
for i in 0..<batchChunks.count {
|
||||
let batch = batchChunks[i]
|
||||
let stream = streamingChunks[i]
|
||||
|
||||
@@ -140,18 +140,15 @@ final class SortformerTests: XCTestCase {
|
||||
func testBufferBounds() throws {
|
||||
var config = DiarizerTimelineConfig.sortformerDefault
|
||||
let numSpeakers = config.numSpeakers
|
||||
config.maxStoredFrames = 50
|
||||
let maxFrames = 50
|
||||
config.maxStoredFrames = maxFrames
|
||||
|
||||
// Create timeline with maxFrames limit
|
||||
let timeline = DiarizerTimeline(config: config)
|
||||
|
||||
// Feed 200 frames of predictions (way more than maxFrames)
|
||||
let totalFrames = 200
|
||||
for frameOffset in stride(from: 0, to: totalFrames, by: 10) {
|
||||
var chunkPreds: [Float] = []
|
||||
for _ in 0..<10 {
|
||||
chunkPreds.append(contentsOf: [Float](repeating: 0.5, count: numSpeakers))
|
||||
}
|
||||
for frameOffset in stride(from: 0, to: 200, by: 10) {
|
||||
let chunkPreds = [Float](repeating: 0.5, count: 10 * numSpeakers)
|
||||
|
||||
let chunk = DiarizerChunkResult(
|
||||
startFrame: frameOffset,
|
||||
@@ -165,16 +162,25 @@ final class SortformerTests: XCTestCase {
|
||||
|
||||
// Verify framePredictions is bounded to maxFrames
|
||||
XCTAssertLessThanOrEqual(
|
||||
timeline.finalizedPredictions.count, config.maxStoredFrames! * config.numSpeakers,
|
||||
timeline.finalizedPredictions.count, maxFrames * numSpeakers,
|
||||
"framePredictions should be bounded to maxFrames")
|
||||
|
||||
// Verify we still have some predictions (not all trimmed)
|
||||
XCTAssertGreaterThan(timeline.numFinalizedFrames, 0, "Should have some predictions")
|
||||
|
||||
// Verify probability() returns valid values for stored frames and NaN for trimmed frames
|
||||
let storedFrames = timeline.finalizedPredictions.count / numSpeakers
|
||||
let firstStoredFrame = timeline.numFinalizedFrames - storedFrames
|
||||
XCTAssertFalse(
|
||||
timeline.probability(speaker: 0, frame: firstStoredFrame).isNaN,
|
||||
"First stored frame should have a valid probability")
|
||||
XCTAssertTrue(
|
||||
timeline.probability(speaker: 0, frame: firstStoredFrame - 1).isNaN,
|
||||
"Frame before stored range should return NaN")
|
||||
}
|
||||
|
||||
func testSegmentExtraction() throws {
|
||||
let config = SortformerConfig.default
|
||||
let config = DiarizerTimelineConfig.sortformerDefault
|
||||
let numSpeakers = config.numSpeakers
|
||||
|
||||
// Create predictions with clear speaker pattern:
|
||||
@@ -198,12 +204,13 @@ final class SortformerTests: XCTestCase {
|
||||
|
||||
let timeline = try DiarizerTimeline(
|
||||
allPredictions: predictions,
|
||||
config: .sortformerDefault,
|
||||
config: config,
|
||||
isComplete: true
|
||||
)
|
||||
|
||||
// Check that we have segments
|
||||
XCTAssertGreaterThan(timeline.speakers.count, 0, "Should have extracted segments")
|
||||
// Check that segments were actually extracted (not just that the speakers dict exists)
|
||||
let totalSegments = timeline.speakers.values.reduce(0) { $0 + $1.finalizedSegmentCount }
|
||||
XCTAssertGreaterThan(totalSegments, 0, "Should have extracted at least one segment")
|
||||
|
||||
// Verify segment speakers are valid
|
||||
for speaker in timeline.speakers.values {
|
||||
@@ -216,10 +223,10 @@ final class SortformerTests: XCTestCase {
|
||||
}
|
||||
|
||||
func testReset() throws {
|
||||
let config = SortformerConfig.default
|
||||
let config = DiarizerTimelineConfig.sortformerDefault
|
||||
let numSpeakers = config.numSpeakers
|
||||
|
||||
let timeline = DiarizerTimeline(config: .sortformerDefault)
|
||||
let timeline = DiarizerTimeline(config: config)
|
||||
|
||||
// Add some data
|
||||
let chunk = DiarizerChunkResult(
|
||||
|
||||
@@ -58,9 +58,9 @@ final class SortformerTimelineTests: XCTestCase {
|
||||
XCTAssertEqual(timeline.numFinalizedFrames, 18, "3 chunks of 6 frames = 18")
|
||||
}
|
||||
|
||||
// MARK: - Segment Generation
|
||||
// MARK: - Prediction Storage
|
||||
|
||||
func testHighProbabilityUpdatesFramePredictions() throws {
|
||||
func testHighProbabilityStoresPredictions() throws {
|
||||
let timeline = DiarizerTimeline(config: .sortformerDefault)
|
||||
let numSpeakers = 4
|
||||
let frameCount = 12
|
||||
@@ -139,7 +139,6 @@ final class SortformerTimelineTests: XCTestCase {
|
||||
// MARK: - Probability Access
|
||||
|
||||
func testProbabilityAccess() throws {
|
||||
let numSpeakers = 4
|
||||
// [f0s0=0.1, f0s1=0.2, f0s2=0.3, f0s3=0.4, f1s0=0.5, ...]
|
||||
let predictions: [Float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
||||
let timeline = try DiarizerTimeline(
|
||||
@@ -155,7 +154,6 @@ final class SortformerTimelineTests: XCTestCase {
|
||||
timeline.probability(speaker: 0, frame: 999).isNaN,
|
||||
"Out of range should return NaN"
|
||||
)
|
||||
_ = numSpeakers // used above
|
||||
}
|
||||
|
||||
// MARK: - SortformerSegment
|
||||
@@ -172,9 +170,11 @@ final class SortformerTimelineTests: XCTestCase {
|
||||
let a = DiarizerSegment(speakerIndex: 0, startFrame: 0, endFrame: 10, frameDurationSeconds: 0.08)
|
||||
let b = DiarizerSegment(speakerIndex: 0, startFrame: 5, endFrame: 15, frameDurationSeconds: 0.08)
|
||||
let c = DiarizerSegment(speakerIndex: 0, startFrame: 11, endFrame: 20, frameDurationSeconds: 0.08)
|
||||
let d = DiarizerSegment(speakerIndex: 0, startFrame: 10, endFrame: 20, frameDurationSeconds: 0.08)
|
||||
|
||||
XCTAssertTrue(a.overlaps(with: b), "Overlapping segments")
|
||||
XCTAssertFalse(a.overlaps(with: c), "Non-overlapping segments")
|
||||
XCTAssertFalse(a.overlaps(with: c), "Non-overlapping segments (gap of 1)")
|
||||
XCTAssertTrue(a.overlaps(with: d), "Touching segments (endFrame == startFrame) count as overlapping")
|
||||
}
|
||||
|
||||
func testSegmentAbsorb() {
|
||||
|
||||
@@ -60,7 +60,7 @@ final class SortformerTypesTests: XCTestCase {
|
||||
|
||||
func testConfigIncompatibleWithDifferentShape() {
|
||||
let a = SortformerConfig.default
|
||||
let b = SortformerConfig.nvidiaHighLatencyV2_1
|
||||
let b = SortformerConfig.highContextV2_1
|
||||
XCTAssertFalse(a.isCompatible(with: b))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import AVFoundation
|
||||
import Foundation
|
||||
import XCTest
|
||||
|
||||
@@ -9,8 +8,6 @@ import XCTest
|
||||
/// - `SortformerDiarizer.enrollSpeaker(withAudio:named:)`
|
||||
/// - `LSEENDDiarizer.enrollSpeaker(withSamples:named:)`
|
||||
final class SpeakerEnrollmentTests: XCTestCase {
|
||||
private static let fixtureSampleRate = 16_000
|
||||
nonisolated(unsafe) private static var cachedFixtureAudioURL: URL?
|
||||
nonisolated(unsafe) private static var cachedLseendEngine: LSEENDInferenceHelper?
|
||||
|
||||
private func loadSortformerModelsForTest(config: SortformerConfig) async throws -> SortformerModels {
|
||||
@@ -163,11 +160,13 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let diarizer = SortformerDiarizer(config: config)
|
||||
let models = try await loadSortformerModelsForTest(config: config)
|
||||
diarizer.initialize(models: models)
|
||||
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
|
||||
let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
|
||||
|
||||
XCTAssertNotNil(speaker)
|
||||
try XCTSkipIf(
|
||||
speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
|
||||
XCTAssertEqual(speaker?.name, "Alice")
|
||||
XCTAssertEqual(diarizer.numFramesProcessed, 0)
|
||||
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
|
||||
@@ -184,14 +183,17 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let diarizer = SortformerDiarizer(config: config)
|
||||
let models = try await loadSortformerModelsForTest(config: config)
|
||||
diarizer.initialize(models: models)
|
||||
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
let liveAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0)
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
let liveAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0)
|
||||
|
||||
let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
|
||||
XCTAssertNotNil(speaker)
|
||||
try XCTSkipIf(
|
||||
speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
|
||||
|
||||
var update: DiarizerTimelineUpdate?
|
||||
for chunk in chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) {
|
||||
for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) {
|
||||
diarizer.addAudio(chunk)
|
||||
if let next = try diarizer.process() {
|
||||
update = next
|
||||
@@ -217,17 +219,21 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let diarizer = SortformerDiarizer(config: config)
|
||||
let models = try await loadSortformerModelsForTest(config: config)
|
||||
diarizer.initialize(models: models)
|
||||
let speakerAAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
let speakerBAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0)
|
||||
let speakerAAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
let speakerBAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0)
|
||||
|
||||
let speakerA = try diarizer.enrollSpeaker(withAudio: speakerAAudio, named: "Alice")
|
||||
XCTAssertNotNil(speakerA)
|
||||
try XCTSkipIf(
|
||||
speakerA == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
|
||||
|
||||
let stateAfterA = diarizer.state
|
||||
let cachedLengthAfterA = stateAfterA.spkcacheLength + stateAfterA.fifoLength
|
||||
|
||||
let speakerB = try diarizer.enrollSpeaker(withAudio: speakerBAudio, named: "Bob")
|
||||
XCTAssertNotNil(speakerB)
|
||||
try XCTSkipIf(
|
||||
speakerB == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
|
||||
|
||||
let stateAfterB = diarizer.state
|
||||
XCTAssertGreaterThanOrEqual(
|
||||
@@ -250,9 +256,12 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let diarizer = SortformerDiarizer(config: config)
|
||||
let models = try await loadSortformerModelsForTest(config: config)
|
||||
diarizer.initialize(models: models)
|
||||
let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0)
|
||||
|
||||
let firstSpeaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice")
|
||||
try XCTSkipIf(
|
||||
firstSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.")
|
||||
let secondSpeaker = try diarizer.enrollSpeaker(
|
||||
withAudio: enrollmentAudio,
|
||||
named: "Bob",
|
||||
@@ -287,7 +296,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let engine = try await loadLseendEngineForTest()
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
let enrollmentAudio = try fixtureAudio(
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
|
||||
let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
|
||||
@@ -298,7 +307,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
XCTAssertEqual(diarizer.numFramesProcessed, 0)
|
||||
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
|
||||
XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 })
|
||||
XCTAssertTrue(hasActiveLseendSession(diarizer))
|
||||
XCTAssertTrue(diarizer.hasActiveSession)
|
||||
}
|
||||
|
||||
func testLseendEnrollSpeakerFollowedByStreamingProcessingStartsAtFrameZero() async throws {
|
||||
@@ -307,14 +316,15 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let engine = try await loadLseendEngineForTest()
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
let enrollmentAudio = try fixtureAudio(
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
let liveAudio = try fixtureAudio(sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
|
||||
let liveAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
|
||||
|
||||
let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
|
||||
|
||||
var firstUpdate: DiarizerTimelineUpdate?
|
||||
for chunk in chunk(liveAudio, sizes: [977, 1231, 1607]) {
|
||||
for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [977, 1231, 1607]) {
|
||||
if let update = try diarizer.process(samples: chunk) {
|
||||
firstUpdate = update
|
||||
break
|
||||
@@ -338,9 +348,9 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let engine = try await loadLseendEngineForTest()
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
let speakerAAudio = try fixtureAudio(
|
||||
let speakerAAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
let speakerBAudio = try fixtureAudio(
|
||||
let speakerBAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0)
|
||||
|
||||
let speakerA = try diarizer.enrollSpeaker(withSamples: speakerAAudio, named: "Alice")
|
||||
@@ -348,7 +358,7 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
|
||||
XCTAssertEqual(diarizer.numFramesProcessed, 0)
|
||||
XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0)
|
||||
XCTAssertTrue(hasActiveLseendSession(diarizer))
|
||||
XCTAssertTrue(diarizer.hasActiveSession)
|
||||
let expectedNames = Set([speakerA?.name, speakerB?.name].compactMap { $0 })
|
||||
XCTAssertEqual(Set(namedSpeakerNames(in: diarizer.timeline)), expectedNames)
|
||||
}
|
||||
@@ -359,11 +369,12 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
let engine = try await loadLseendEngineForTest()
|
||||
let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly)
|
||||
diarizer.initialize(engine: engine)
|
||||
let enrollmentAudio = try fixtureAudio(
|
||||
let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio(
|
||||
sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0)
|
||||
|
||||
let firstSpeaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice")
|
||||
try XCTSkipIf(firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.")
|
||||
try XCTSkipIf(
|
||||
firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.")
|
||||
let secondSpeaker = try diarizer.enrollSpeaker(
|
||||
withSamples: enrollmentAudio,
|
||||
named: "Bob",
|
||||
@@ -375,95 +386,6 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
XCTAssertEqual(namedSpeakerNames(in: diarizer.timeline), ["Alice"])
|
||||
}
|
||||
|
||||
private func fixtureAudio(sampleRate: Int, startSeconds: Double = 0.0, durationSeconds: Double) throws -> [Float] {
|
||||
let converter = AudioConverter(sampleRate: Double(sampleRate))
|
||||
let audio = try converter.resampleAudioFile(try fixtureAudioFileURL())
|
||||
let startSample = min(audio.count, Int(startSeconds * Double(sampleRate)))
|
||||
let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate)))
|
||||
return Array(audio[startSample..<endSample])
|
||||
}
|
||||
|
||||
private func fixtureAudioFileURL() throws -> URL {
|
||||
if let cached = Self.cachedFixtureAudioURL,
|
||||
FileManager.default.fileExists(atPath: cached.path)
|
||||
{
|
||||
return cached
|
||||
}
|
||||
|
||||
let url = FileManager.default.temporaryDirectory
|
||||
.appendingPathComponent("speaker-enrollment-fixture-\(UUID().uuidString)")
|
||||
.appendingPathExtension("wav")
|
||||
try writeFixtureAudio(to: url)
|
||||
Self.cachedFixtureAudioURL = url
|
||||
return url
|
||||
}
|
||||
|
||||
private func writeFixtureAudio(to url: URL) throws {
|
||||
let sampleRate = Double(Self.fixtureSampleRate)
|
||||
let samples = makeFixtureSamples(sampleRate: sampleRate)
|
||||
let format = AVAudioFormat(
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
sampleRate: sampleRate,
|
||||
channels: 1,
|
||||
interleaved: false
|
||||
)!
|
||||
guard
|
||||
let buffer = AVAudioPCMBuffer(
|
||||
pcmFormat: format,
|
||||
frameCapacity: AVAudioFrameCount(samples.count)
|
||||
)
|
||||
else {
|
||||
XCTFail("Failed to allocate fixture audio buffer")
|
||||
return
|
||||
}
|
||||
|
||||
buffer.frameLength = AVAudioFrameCount(samples.count)
|
||||
samples.withUnsafeBufferPointer { source in
|
||||
guard let destination = buffer.floatChannelData?[0] else { return }
|
||||
destination.update(from: source.baseAddress!, count: samples.count)
|
||||
}
|
||||
|
||||
let file = try AVAudioFile(
|
||||
forWriting: url,
|
||||
settings: format.settings,
|
||||
commonFormat: .pcmFormatFloat32,
|
||||
interleaved: false
|
||||
)
|
||||
try file.write(from: buffer)
|
||||
}
|
||||
|
||||
private func makeFixtureSamples(sampleRate: Double) -> [Float] {
|
||||
let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [
|
||||
(1.0, 0.20, 220),
|
||||
(0.35, 0.00, 0),
|
||||
(1.1, 0.32, 330),
|
||||
(0.25, 0.00, 0),
|
||||
(1.0, 0.28, 180),
|
||||
(0.40, 0.00, 0),
|
||||
(1.3, 0.36, 260),
|
||||
(0.30, 0.00, 0),
|
||||
(1.1, 0.24, 410),
|
||||
]
|
||||
|
||||
var output: [Float] = []
|
||||
for (duration, amplitude, frequency) in segments {
|
||||
let frameCount = Int(duration * sampleRate)
|
||||
guard amplitude > 0, frequency > 0 else {
|
||||
output.append(contentsOf: repeatElement(0, count: frameCount))
|
||||
continue
|
||||
}
|
||||
|
||||
for frame in 0..<frameCount {
|
||||
let time = Double(frame) / sampleRate
|
||||
let envelope = Float(min(1.0, time * 12.0)) * Float(min(1.0, (duration - time) * 12.0))
|
||||
let carrier = sin(2.0 * Double.pi * frequency * time)
|
||||
let harmonic = 0.35 * sin(2.0 * Double.pi * frequency * 2.03 * time)
|
||||
output.append(Float((carrier + harmonic) * Double(amplitude * envelope)))
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
private func namedSpeakerIndices(in timeline: DiarizerTimeline) -> [Int] {
|
||||
timeline.speakers.values
|
||||
.filter { $0.name != nil }
|
||||
@@ -476,29 +398,4 @@ final class SpeakerEnrollmentTests: XCTestCase {
|
||||
.compactMap(\.name)
|
||||
.sorted()
|
||||
}
|
||||
|
||||
private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] {
|
||||
var chunks: [[Float]] = []
|
||||
var start = 0
|
||||
var index = 0
|
||||
while start < samples.count {
|
||||
let size = sizes[index % sizes.count]
|
||||
let stop = min(samples.count, start + size)
|
||||
chunks.append(Array(samples[start..<stop]))
|
||||
start = stop
|
||||
index += 1
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
private func hasActiveLseendSession(_ diarizer: LSEENDDiarizer) -> Bool {
|
||||
let mirror = Mirror(reflecting: diarizer)
|
||||
guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else {
|
||||
XCTFail("Expected LS-EEND diarizer to expose _session via reflection")
|
||||
return false
|
||||
}
|
||||
|
||||
let optionalMirror = Mirror(reflecting: sessionValue)
|
||||
return optionalMirror.displayStyle == .optional && optionalMirror.children.count == 1
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user