From 8aa0dfcdac5f64cd52ba52dd4150b8471021811d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 18 Mar 2026 12:51:34 -0400 Subject: [PATCH] fix: clean up diarization test infrastructure (#395) ## Summary - Extract shared fixture helpers into `DiarizationTestFixtures` enum, removing ~200 lines of duplicate code across `LSEENDIntegrationTests` and `SpeakerEnrollmentTests` - Replace fragile `Mirror`-based private state inspection with `internal` `hasActiveSession` property on `LSEENDDiarizerAPI` - Fix non-deterministic `srand48` seed in `SortformerTests` (use constant `42` instead of time-based seed) - Fix asymmetric skip guards in Sortformer enrollment tests (`XCTSkipIf` instead of `XCTAssertNotNil` for host-dependent segments) ## Test plan - [x] `swift build --build-tests` passes - [ ] `swift test --filter SortformerTests` passes - [ ] `swift test --filter LSEENDIntegrationTests` passes - [ ] `swift test --filter SpeakerEnrollmentTests` passes --- Open with Devin --- Documentation/API.md | 136 +------ Documentation/Diarization/LS-EEND.md | 2 +- Documentation/Diarization/Sortformer.md | 46 +-- README.md | 41 +- Scripts/nemo_ami_benchmark/README.md | 2 +- .../Streaming/StreamingEouAsrManager.swift | 6 +- .../Diarizer/DiarizerTimeline.swift | 66 ---- .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 13 +- .../LS-EEND/LSEENDModelInference.swift | 13 +- .../Diarizer/LS-EEND/LSEENDPreprocessor.swift | 12 +- .../Offline/Core/OfflineDiarizerTypes.swift | 6 - .../SortformerDiarizerPipeline.swift | 31 +- .../Diarizer/Sortformer/SortformerTypes.swift | 216 ++-------- Sources/FluidAudio/DownloadUtils.swift | 167 +++++--- Sources/FluidAudio/ModelNames.swift | 50 +-- .../FluidAudio/Shared/ANEMemoryUtils.swift | 91 ++--- .../FluidAudio/Shared/AudioConverter.swift | 24 +- ...rogram.swift => AudioMelSpectrogram.swift} | 2 +- .../Assets/PocketTtsResourceDownloader.swift | 93 +---- .../Commands/DiarizationBenchmark.swift | 4 +- .../Commands/DiarizationBenchmarkUtils.swift | 332 ++++++++++++++++ .../Commands/LSEENDBenchmark.swift | 333 +--------------- .../Commands/LSEENDCommand.swift | 20 +- .../Commands/LSEENDEvaluation.swift | 3 +- .../Commands/SortformerBenchmark.swift | 369 ++---------------- .../Commands/SortformerCommand.swift | 18 +- ...s.swift => AudioMelSpectrogramTests.swift} | 6 +- .../Diarizer/DiarizationTestFixtures.swift | 136 +++++++ .../LS-EEND/LSEENDIntegrationTests.swift | 148 +------ .../Diarizer/LS-EEND/LSEENDMatrixTests.swift | 326 ++++++++++++++++ .../Diarizer/Sortformer/SortformerTests.swift | 39 +- .../Sortformer/SortformerTimelineTests.swift | 10 +- .../Sortformer/SortformerTypesTests.swift | 2 +- .../Diarizer/SpeakerEnrollmentTests.swift | 173 ++------ 34 files changed, 1251 insertions(+), 1685 deletions(-) rename Sources/FluidAudio/Shared/{NeMoMelSpectrogram.swift => AudioMelSpectrogram.swift} (99%) create mode 100644 Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift rename Tests/FluidAudioTests/ASR/Streaming/{NeMoMelSpectrogramTests.swift => AudioMelSpectrogramTests.swift} (96%) create mode 100644 Tests/FluidAudioTests/Diarizer/DiarizationTestFixtures.swift create mode 100644 Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift diff --git a/Documentation/API.md b/Documentation/API.md index bc5f95a1..bdc7ed37 100644 --- a/Documentation/API.md +++ b/Documentation/API.md @@ -75,161 +75,51 @@ Use `OfflineDiarizerManager` when you need offline DER parity or want to run the ### Diarizer Protocol -`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, which provides a unified streaming and offline API. +`SortformerDiarizer` and `LSEENDDiarizer` both conform to the `Diarizer` protocol, providing a unified streaming and offline API. -**Protocol Properties:** -- `isAvailable: Bool` — Whether the model is loaded and ready -- `numFramesProcessed: Int` — Confirmed frames processed so far -- `targetSampleRate: Int?` — Model's expected audio sample rate in Hz -- `modelFrameHz: Double?` — Output frame rate in Hz (frames per second) -- `numSpeakers: Int?` — Number of real speaker output tracks -- `timeline: DiarizerTimeline` — Accumulated diarization results +**Streaming:** `addAudio(_:sourceSampleRate:)` → `process()` → read `timeline`. Convenience `process(samples:sourceSampleRate:)` combines both steps. Returns `DiarizerTimelineUpdate?` (`nil` when not enough audio has accumulated). -**Streaming:** -- `addAudio(_ samples: C, sourceSampleRate: Double?) throws` — Buffer audio for processing; pass a non-nil `sourceSampleRate` to resample on the fly -- `process() throws -> DiarizerTimelineUpdate?` — Run inference on buffered audio; returns `nil` if not enough audio has accumulated -- `process(samples: C, sourceSampleRate: Double?) throws -> DiarizerTimelineUpdate?` — Convenience combining `addAudio` + `process` in one call +**Offline:** `processComplete(_:sourceSampleRate:...)` or `processComplete(audioFileURL:...)` to process a full recording in one call. -**Offline:** -- `processComplete(_ samples: C, sourceSampleRate:, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Process a complete audio buffer in one call -- `processComplete(audioFileURL: URL, keepingEnrolledSpeakers:, finalizeOnCompletion:, progressCallback:) throws -> DiarizerTimeline` — Read, resample, and process an audio file end-to-end +**Speaker Enrollment:** `enrollSpeaker(withAudio:sourceSampleRate:named:...)` feeds known-speaker audio before streaming to label a slot. -**Speaker Enrollment:** -- `enrollSpeaker(withAudio samples: C, sourceSampleRate:, named:, overwritingAssignedSpeakerName:) throws -> DiarizerSpeaker?` — Feed audio of a known speaker before streaming begins; warms model state and labels that speaker's slot for subsequent `process()` calls - -**Lifecycle:** -- `reset()` — Clear all streaming state (session, buffers, timeline) while keeping the model loaded -- `cleanup()` — Release all resources including the loaded model +**Lifecycle:** `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything. --- -### DiarizerTimeline +### DiarizerTimeline & DiarizerSpeaker -Holds accumulated streaming predictions and derived speaker segments. Returned by `Diarizer.timeline` and `processComplete(...)`. +`DiarizerTimeline` accumulates per-frame speaker probabilities and derives `DiarizerSpeaker` segments. Each speaker has `finalizedSegments` (confirmed) and `tentativeSegments` (may be revised). Segments expose `startTime`, `endTime`, `duration`, and `isFinalized`. -**Key Properties:** -- `config: DiarizerTimelineConfig` — Post-processing configuration used to build segments -- `speakers: [Int: DiarizerSpeaker]` — Speaker slots keyed by output track index -- `finalizedPredictions: [Float]` — Flat `[frames × numSpeakers]` array of finalized per-frame probabilities -- `tentativePredictions: [Float]` — Same layout; frames still within the right-context window that may be revised -- `numFinalizedFrames: Int` — Count of finalized frames -- `numTentativeFrames: Int` — Count of tentative frames -- `finalizedDuration: Float` — Duration in seconds of finalized audio -- `hasSegments: Bool` — Whether any speaker has at least one segment - -**Mutation:** -- `addChunk(_ chunk: DiarizerChunkResult) throws -> DiarizerTimelineUpdate` — Append new predictions and rebuild segments; called internally by the diarizer -- `rebuild(finalizedPredictions:tentativePredictions:keepingSpeakers:isComplete:) throws` — Replace all predictions from scratch (used by offline processing) -- `reset(keepingSpeakers:)` / `reset(keepingSpeakersWhere:)` — Clear segments and optionally preserve named speakers or speaker metadata -- `finalize()` — Promote all tentative segments to finalized - -**`DiarizerTimelineConfig`** — Shared configuration used by both diarizers: -| Parameter | Default | Description | -|---|---|---| -| `numSpeakers` | model-specific | Number of speaker output tracks | -| `frameDurationSeconds` | model-specific | Duration of one output frame | -| `onsetThreshold` | 0.5 | Probability threshold to begin a speech segment | -| `offsetThreshold` | 0.5 | Probability threshold to end a speech segment | -| `onsetPadFrames` | 0 | Frames prepended to each segment onset | -| `offsetPadFrames` | 0 | Frames appended to each segment offset | -| `minFramesOn` | 0 | Minimum segment length; shorter segments are dropped | -| `minFramesOff` | 0 | Minimum gap; shorter silences are closed | -| `maxStoredFrames` | nil | Rolling window cap on finalized frames (nil = unlimited) | - ---- - -### DiarizerSpeaker - -Represents a single speaker track within a `DiarizerTimeline`. - -**Key Properties:** -- `id: UUID` — Stable identity across resets -- `index: Int` — Slot index in the diarizer output (0-based) -- `name: String?` — Optional display name (set via enrollment or manually) -- `finalizedSegments: [DiarizerSegment]` — Confirmed speech segments -- `tentativeSegments: [DiarizerSegment]` — Speculative segments within the right-context window -- `hasSegments: Bool` — Whether any finalized or tentative segments exist -- `numSpeechFrames: Int` — Total frames spanned by all segments (finalized + tentative) -- `speechDuration: Float` — Total speech duration in seconds - -**`DiarizerSegment`** — A single time-range for one speaker: -- `startFrame / endFrame: Int` — Frame indices (convert using `frameDurationSeconds`) -- `startTime / endTime: Float` — Seconds -- `duration: Float` — Segment length in seconds -- `isFinalized: Bool` — Whether the segment has been confirmed +**`DiarizerTimelineConfig`** controls post-processing (onset/offset thresholds default to 0.5, min segment/gap duration, optional rolling window cap). Both diarizers accept this at init. --- ### SortformerDiarizer -End-to-end streaming diarization using NVIDIA's Sortformer model. Tracks **4 fixed speaker slots**. +Streaming diarization using NVIDIA's Sortformer. 4 fixed speaker slots, 16 kHz input, 80 ms frame duration. -- **Sample rate:** 16 kHz -- **Frame duration:** 80 ms (12.5 Hz output) -- **Streaming latency:** ~0.64 s (`default` config) or ~1.04 s (`nvidiaLowLatency` configs) -- **Accuracy:** 31.7% DER on AMI SDM (`nvidiaHighLatencyV2_1`; other configs untested) - -**Initialization:** ```swift -// Preferred: download and compile model automatically let diarizer = SortformerDiarizer(config: .default, timelineConfig: .sortformerDefault) try await diarizer.initialize(mainModelPath: modelURL) - -// Or with pre-loaded models -diarizer.initialize(models: sortformerModels) ``` -**`SortformerConfig` Presets:** - -| Preset | Latency | Notes | -|---|---|---| -| `.default` / `.fastestV2_1` | 1.04 s | Fastest inference speed and update rate. | -| `.fastestV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights. | -| `.nvidiaLowLatencyV2_1` | 1.04 s | Uses more context than `fastest`; Improvement is minimal | -| `.nvidiaLowLatencyV2_0` | 1.04 s | Uses NVIDIA's Sortformer v2 weights | -| `.nvidiaHighLatencyV2_1` | 30.4 s | 31.7% DER on AMI SDM | -| `.nvidiaHighLatencyV2_0` | 30.4 s | Uses NVIDIA's Sortformer v2 weights | - -All streaming methods are defined by the `Diarizer` protocol above. Additionally: -- `state: SortformerStreamingState` — Live speaker cache and FIFO queue state (for diagnostics) -- `config: SortformerConfig` — The configuration this instance was created with +**Config presets:** `.default` / `.fastV2_1` (1.04 s latency), `.balancedV2_1` (1.04 s, 20.6% DER on AMI SDM), `.highContextV2_1` (30.4 s latency). v2 variants also available. --- ### LSEENDDiarizer -End-to-end streaming diarization using LS-EEND (Linear Streaming End-to-End Neural Diarization). Supports a **variable number of speaker slots** depending on the model variant. +Streaming diarization using LS-EEND. Variable speaker slots, 8 kHz input, 100 ms frame duration, 20.7% DER on AMI SDM. -- **Sample rate:** 8 kHz -- **Frame duration:** 100 ms (10 Hz output) -- **Accuracy:** 20.7% DER on AMI SDM (AMI variant) -- **Variants:** `LSEENDVariant` (`LSEENDModelDescriptor.LSEENDVariant`) - -**Initialization:** ```swift -// Auto-download from HuggingFace let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) try await diarizer.initialize(variant: .dihard3) - -// Or with an explicit descriptor -let descriptor = try await LSEENDModelDescriptor.loadFromHuggingFace(variant: .dihard3) -try diarizer.initialize(descriptor: descriptor) ``` -**LS-EEND–Specific Properties:** -- `computeUnits: MLComputeUnits` — CoreML compute target (`.cpuOnly` is typically fastest) -- `streamingLatencySeconds: Double?` — Minimum audio required before first output frame -- `decodeMaxSpeakers: Int?` — Total output slots including internal boundary tracks -- `timelineConfig: DiarizerTimelineConfig` — Active post-processing configuration +**Variants:** ami, callhome, dihard2, dihard3 (via `LSEENDModelDescriptor.loadFromHuggingFace(variant:)`). -**Additional Method:** -- `finalizeSession() throws -> DiarizerChunkResult?` — Flush pending audio and finalize the timeline; call at end of a stream before reading the final timeline - -**`LSEENDModelDescriptor`:** -- `LSEENDModelDescriptor.loadFromHuggingFace(variant:cacheDirectory:computeUnits:) async throws -> LSEENDModelDescriptor` — Download and cache all model files; returns a descriptor ready for `initialize(descriptor:)` -- `init(variant:modelURL:metadataURL:)` — Construct from local paths if already cached - -All streaming and offline methods are defined by the `Diarizer` protocol above. +Call `finalizeSession()` at end-of-stream to flush pending audio before reading the final timeline. ## Voice Activity Detection diff --git a/Documentation/Diarization/LS-EEND.md b/Documentation/Diarization/LS-EEND.md index 9dc3a181..051cfaec 100644 --- a/Documentation/Diarization/LS-EEND.md +++ b/Documentation/Diarization/LS-EEND.md @@ -312,7 +312,7 @@ let result: LSEENDInferenceResult = try engine.infer(audioFileURL: url) let session = try engine.createSession(inputSampleRate: engine.targetSampleRate) // Or with a caller-owned mel spectrogram (for thread isolation) -let mel = NeMoMelSpectrogram(...) +let mel = AudioMelSpectrogram(...) let session = try engine.createSession(inputSampleRate: engine.targetSampleRate, melSpectrogram: mel) ``` diff --git a/Documentation/Diarization/Sortformer.md b/Documentation/Diarization/Sortformer.md index d8eaf83f..dc460001 100644 --- a/Documentation/Diarization/Sortformer.md +++ b/Documentation/Diarization/Sortformer.md @@ -38,7 +38,7 @@ Audio (16kHz) → Mel Spectrogram → CoreML Model → Speaker Probabilities The pipeline consists of: -1. **Mel Spectrogram** (`NeMoMelSpectrogram`): Converts raw audio to 128-bin mel features +1. **Mel Spectrogram** (`AudioMelSpectrogram`): Converts raw audio to 128-bin mel features 2. **CoreML Model** (`DiarizerInference`): Combined encoder + attention + head 3. **Streaming State** (`SortformerStreamingState`): Maintains speaker cache and FIFO queue 4. **Post-processing** (`SortformerTimeline`): Converts probabilities to speaker segments @@ -79,8 +79,8 @@ FIFO Queue Role: | Config | `fifoLen` | Effect | |--------|-----------|--------| | Default | 40 | Smaller memory, faster compression cycles | -| NVIDIA Low | 188 | Larger context before compression | -| NVIDIA High | 40 | Same as default | +| Balanced | 188 | Larger context before compression | +| High Context | 40 | Same as default | When `fifoLen + newChunkFrames > fifoLen` capacity, frames are popped from FIFO and either: 1. Added to speaker cache (if speaker was active) @@ -105,8 +105,8 @@ Chunk with Context: | Config | `rightContext` | Look-ahead | Latency Impact | |--------|----------------|------------|----------------| | Default | 7 | 7 × 80ms = 560ms | Low latency | -| NVIDIA Low | 7 | 7 × 80ms = 560ms | Low latency | -| NVIDIA High | 40 | 40 × 80ms = 3.2s | High latency, better quality | +| Balanced | 7 | 7 × 80ms = 560ms | Low latency | +| High Context | 40 | 40 × 80ms = 3.2s | High latency, better quality | **Why Right Context Matters:** @@ -142,7 +142,7 @@ Default config: = 13 × 8 × 0.01 = 1.04 seconds -NVIDIA High Latency config: +High Context config: = (340 + 40) × 8 × 160 / 16000 = 380 × 8 × 0.01 = 30.4 seconds @@ -191,11 +191,11 @@ Defines streaming parameters that must match the CoreML model's static shapes: // Default (~1.04s latency, lowest latency) SortformerConfig.default -// NVIDIA High Latency (30.4s latency, best quality) -SortformerConfig.nvidiaHighLatency +// Balanced (1.04s latency, best quality on AMI SDM) +SortformerConfig.balancedV2_1 -// NVIDIA Low Latency (1.04s latency) -SortformerConfig.nvidiaLowLatency +// High Context (30.4s latency, most context) +SortformerConfig.highContextV2_1 ``` ### Pipeline.swift @@ -375,9 +375,11 @@ public struct SortformerSegment { | Config | Chunk Size | Latency | Quality | |--------|------------|---------|---------| -| `default` | 6 frames | ~1.04s | Good | -| `nvidiaLowLatency` | 6 frames | ~1.04s | Better | -| `nvidiaHighLatency` | 340 frames | ~30.4s | Best | +| `default` / `fastV2_1` | 6 frames | ~1.04s | Good | +| `balancedV2_1` | 6 frames | ~1.04s | Best (20.6% DER on AMI SDM) | +| `highContextV2_1` | 340 frames | ~30.4s | Good (31.7% DER on AMI SDM) | + +> **Note:** v2.1 variants may degrade when many speakers are talking simultaneously. v2 variants (`fastV2`, `balancedV2`, `highContextV2`) are available as alternatives. Latency is determined by: - `chunkLen * subsamplingFactor * melStride / sampleRate` @@ -396,10 +398,10 @@ This preserves the most informative historical context while bounding memory usa ## Post-Processing -`SortformerPostProcessingConfig` controls segment extraction: +`DiarizerTimelineConfig` controls segment extraction: ```swift -let config = SortformerPostProcessingConfig( +let config = DiarizerTimelineConfig( onsetThreshold: 0.5, // Probability to start speech offsetThreshold: 0.5, // Probability to end speech minDurationOn: 0.25, // Min speech segment (seconds) @@ -414,8 +416,8 @@ Three CoreML models are available on HuggingFace: | Variant | File | Config | |---------|------|--------| | Default | `Sortformer.mlmodelc` | `SortformerConfig.default` | -| NVIDIA Low | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.nvidiaLowLatency` | -| NVIDIA High | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.nvidiaHighLatency` | +| Balanced | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.balancedV2_1` | +| High Context | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.highContextV2_1` | **Important:** Each model has baked-in static shapes. You must use the matching configuration. @@ -447,8 +449,8 @@ audioEngine.installTap { buffer in ### Batch Processing ```swift -let diarizer = SortformerDiarizer(config: .nvidiaHighLatency) -let models = try await SortformerModels.loadFromHuggingFace(config: .nvidiaHighLatency) +let diarizer = SortformerDiarizer(config: .highContextV2_1) +let models = try await SortformerModels.loadFromHuggingFace(config: .highContextV2_1) diarizer.initialize(models: models) let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_000) @@ -457,9 +459,9 @@ let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_0 let fileTimeline = try diarizer.processComplete(audioFileURL: audioURL) // Get segments per speaker -for (speakerIndex, segments) in timeline.segments.enumerated() { - for segment in segments { - print("Speaker \(speakerIndex): \(segment.startTime)s - \(segment.endTime)s") +for (index, speaker) in timeline.speakers { + for segment in speaker.finalizedSegments { + print("Speaker \(index): \(segment.startTime)s - \(segment.endTime)s") } } ``` diff --git a/README.md b/README.md index a3c98f88..0634c7d8 100644 --- a/README.md +++ b/README.md @@ -335,25 +335,7 @@ swift run fluidaudiocli diarization-benchmark --mode offline --auto-download \ ### LS-EEND (LongForm Streaming End-to-End Neural Diarization) -State-of-the-art end-to-end streaming diarization with fully local CoreML inference. This is the best default option for online and streaming diarization when you want low-latency speaker activity updates from a single model without a separate clustering pipeline. Multiple exported variants are available for different benchmark domains, and the diarizer supports both streaming and complete-buffer processing. - -Why use LS-EEND: -- Robust to noise and high speaker overlap -- Supports up to 10 speakers in a session -- Generally achieves better benchmark results than Sortformer on CALLHOME, AMI, and DIHARD III -- Frame-by-frame inference with 100ms frames, and 900ms of tentative frames. -- Much more lightweight than Sortformer, and can be faster on CPU than Sortformer on ANE -- Does not suffer from the missed-quiet-speech issues that Sortformer can show -- End-to-end model without separate segmentation, embedding extraction, and clustering stages -- Also works well for complete-buffer inference when you want the same model in offline mode - -Tradeoffs: -- Can pick up background speakers more readily than Sortformer -- Speaker identity stability is usually somewhat weaker than Sortformer -- Speaker indices are expected to follow chronological arrival order, but Sortformer tends to maintain that ordering more reliably -- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline - -See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for model loading and integration details. +End-to-end streaming diarization with CoreML inference. Default choice for online diarization — single model, no clustering pipeline, up to 10 speakers, 100ms frame updates with 900ms tentative preview. Supports both streaming and complete-buffer processing. See [Documentation/Diarization/GettingStarted.md](Documentation/Diarization/GettingStarted.md) for details. ```swift import FluidAudio @@ -373,26 +355,9 @@ Task { ### Sortformer (End-to-End Neural Diarization) -End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). This is the secondary streaming/online diarizer behind LS-EEND. It offers nearly the same live-update workflow and API flexibility, but trades LS-EEND's stronger benchmark results and higher speaker capacity for better identity stability. No separate VAD, segmentation, or clustering needed. Limited to 4 speakers and does not remember speakers across recordings. Licensed under NVIDIA Open Model License (no restrictions). +End-to-end neural diarization using [NVIDIA's Sortformer](https://arxiv.org/abs/2409.06656). Secondary streaming diarizer — trades LS-EEND's higher speaker capacity and benchmark results for better speaker identity stability. Limited to 4 speakers. No separate VAD, segmentation, or clustering needed. Licensed under NVIDIA Open Model License. -Why use Sortformer: -- Simple streaming pipeline with low latency -- Supports live-update flexibility like LS-EEND -- Handles overlapping speech without an external clustering stage -- Better speaker identity stability than LS-EEND -- Ignores speech from background conversations -- Good choice when the 4-speaker limit is acceptable - -Tradeoffs: -- Limited to 4 unique speakers -- Updates are performed less frequently than for LS-EEND -- Can get overwhelmed when there are too many background speakers -- Can miss speech when the audio is too quiet -- Supports speaker pre-enrollment, but it may be less reliable than the WeSpeaker pipeline - -Like LS-EEND, Sortformer supports ultra-low-latency inference with 0.5s updates, 0.5s preview frames, and a fixed 1.0s latency before finalized predictions. Both models emit their results into a `DiarizerTimeline`, so you do not need to manage incremental diarization state externally. - -See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage, comparison with Pyannote, streaming config, and architecture details. +Both LS-EEND and Sortformer emit results into a `DiarizerTimeline` with ultra-low-latency updates. See [Documentation/Diarization/Sortformer.md](Documentation/Diarization/Sortformer.md) for usage and comparison. ### Streaming/Online Speaker Diarization (Pyannote) diff --git a/Scripts/nemo_ami_benchmark/README.md b/Scripts/nemo_ami_benchmark/README.md index 06ba73f3..c0ee01fb 100644 --- a/Scripts/nemo_ami_benchmark/README.md +++ b/Scripts/nemo_ami_benchmark/README.md @@ -82,7 +82,7 @@ HF_TOKEN="your_token" python nemo_ami_benchmark.py --output results.json ### High-Latency Streaming Config -These settings match the Swift `SortformerConfig.nvidiaHighLatency`: +These settings match the Swift `SortformerConfig.highContextV2_1`: | Parameter | Value | Description | |-----------|-------|-------------| diff --git a/Sources/FluidAudio/ASR/Streaming/StreamingEouAsrManager.swift b/Sources/FluidAudio/ASR/Streaming/StreamingEouAsrManager.swift index c7c2be6f..bc14d027 100644 --- a/Sources/FluidAudio/ASR/Streaming/StreamingEouAsrManager.swift +++ b/Sources/FluidAudio/ASR/Streaming/StreamingEouAsrManager.swift @@ -172,7 +172,7 @@ public actor StreamingEouAsrManager { private var rnntDecoder: RnntDecoder? private let audioConverter = AudioConverter() private var tokenizer: Tokenizer? - private let melProcessor = NeMoMelSpectrogram() // Native Swift mel spectrogram + private let melProcessor = AudioMelSpectrogram() // Native Swift mel spectrogram // Configuration - now based on chunkSize public let chunkSize: StreamingChunkSize @@ -247,7 +247,7 @@ public actor StreamingEouAsrManager { public func loadModels(modelDir: URL) async throws { logger.info("Loading CoreML models from \(modelDir.path)...") - // No longer loading preprocessor - using native Swift NeMoMelSpectrogram instead + // No longer loading preprocessor - using native Swift AudioMelSpectrogram instead self.streamingEncoder = try await MLModel.load( contentsOf: modelDir.appendingPathComponent("streaming_encoder.mlmodelc"), configuration: self.configuration ) @@ -450,7 +450,7 @@ public actor StreamingEouAsrManager { let mel = try MLMultiArray(shape: [1, 128, NSNumber(value: numFrames)], dataType: .float32) let melPtr = mel.dataPointer.bindMemory(to: Float.self, capacity: mel.count) - // NeMoMelSpectrogram returns [nMels, T] row-major (mel bin, then time) + // AudioMelSpectrogram returns [nMels, T] row-major (mel bin, then time) // CoreML expects [1, 128, T] which is the same layout melPtr.update(from: melFlat, count: melFlat.count) diff --git a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift index cc468b78..523c968a 100644 --- a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift +++ b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift @@ -648,16 +648,6 @@ public final class DiarizerTimeline { queue.sync { _tentativePredictions } } - /// Total number of frames (finalized + tentative) - @available( - *, deprecated, - message: "`numFrames` now includes tentative frames. Use 'numFinalizedFrames' for only finalized frames.", - renamed: "numFinalizedFrames" - ) - public var numFrames: Int { - queue.sync { _numFinalizedFrames + _tentativePredictions.count / config.numSpeakers } - } - /// Total number of finalized frames public var numFinalizedFrames: Int { queue.sync { _numFinalizedFrames } @@ -677,31 +667,11 @@ public final class DiarizerTimeline { speakers.values.contains(where: \.hasSegments) } - /// Duration of all predictions in seconds - @available( - *, deprecated, - message: "`duration` now includes tentative frames. Use 'finalizedDuration' for only finalized frames.", - renamed: "finalizedDuration" - ) - public var duration: Float { - Float(numFrames) * config.frameDurationSeconds - } - /// Duration of finalized predictions in seconds public var finalizedDuration: Float { Float(numFinalizedFrames) * config.frameDurationSeconds } - /// Duration of tentative predictions in seconds - @available( - *, deprecated, - message: "tentativeDuration now excludes finalized frames. Use 'duration' for the full timeline duration.", - renamed: "duration" - ) - public var tentativeDuration: Float { - Float(numTentativeFrames) * config.frameDurationSeconds - } - private var _finalizedPredictions: [Float] = [] private var _tentativePredictions: [Float] = [] private var _speakers: [Int: DiarizerSpeaker] = [:] @@ -1106,42 +1076,6 @@ public final class DiarizerTimeline { } } -extension DiarizerTimeline { - @available(*, deprecated, renamed: "finalizedPredictions") - public var framePredictions: [Float] { finalizedPredictions } - - @available( - *, deprecated, - message: "Use Timeline.speakers[index].confirmedSegments to access a speaker's confirmed segments." - ) - public var segments: [[DiarizerSegment]] { - queue.sync { - var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers) - for (index, speaker) in _speakers { - result[index] = speaker.finalizedSegments - } - return result - } - } - - @available( - *, deprecated, - message: "Use Timeline.speakers[index].tentativeSegments to access a speaker's tentative segments." - ) - public var tentativeSegments: [[DiarizerSegment]] { - queue.sync { - var result: [[DiarizerSegment]] = Array(repeating: [], count: config.numSpeakers) - for (index, speaker) in _speakers { - result[index] = speaker.tentativeSegments - } - return result - } - } - - @available(*, deprecated, renamed: "numTentativeFrames") - public var numTentative: Int { numTentativeFrames } -} - // MARK: - Timeline Update public struct DiarizerTimelineUpdate: Sendable { diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index a054a66e..0a9aaec9 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -80,11 +80,18 @@ public final class LSEENDDiarizer: Diarizer { return _engine?.decodeMaxSpeakers } + /// Whether a streaming session is currently active. + var hasActiveSession: Bool { + lock.lock() + defer { lock.unlock() } + return _session != nil + } + // MARK: - Private State private var _engine: LSEENDInferenceHelper? private var _session: LSEENDStreamingSession? - private var _melSpectrogram: NeMoMelSpectrogram? + private var _melSpectrogram: AudioMelSpectrogram? private var _timeline: DiarizerTimeline private var _numFramesProcessed: Int = 0 private var _timelineConfig: DiarizerTimelineConfig @@ -737,8 +744,8 @@ public final class LSEENDDiarizer: Diarizer { } /// Create a new mel spectrogram instance owned by this diarizer. - private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> NeMoMelSpectrogram { - NeMoMelSpectrogram( + private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> AudioMelSpectrogram { + AudioMelSpectrogram( sampleRate: featureConfig.sampleRate, nMels: featureConfig.nMels, nFFT: featureConfig.nFFT, diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift index 8d653776..69afd6a3 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift @@ -54,7 +54,7 @@ private final class LSEENDInferenceSharedResources { let modelFrameHz: Double let streamingLatencySeconds: Double let decodeMaxSpeakers: Int - let melSpectrogram: NeMoMelSpectrogram + let melSpectrogram: AudioMelSpectrogram let offlineFeatureExtractor: LSEENDOfflineFeatureExtractor // Preallocated ANE-aligned input arrays reused across predictStep calls @@ -77,7 +77,7 @@ private final class LSEENDInferenceSharedResources { modelFrameHz = metadata.frameHz streamingLatencySeconds = metadata.streamingLatencySeconds decodeMaxSpeakers = metadata.maxNspks - melSpectrogram = NeMoMelSpectrogram( + melSpectrogram = AudioMelSpectrogram( sampleRate: featureConfig.sampleRate, nMels: featureConfig.nMels, nFFT: featureConfig.nFFT, @@ -150,7 +150,7 @@ public final class LSEENDInferenceHelper { /// Maximum number of speaker slots in the model output (including boundary tracks). public var decodeMaxSpeakers: Int { sharedResources.decodeMaxSpeakers } - fileprivate var melSpectrogram: NeMoMelSpectrogram { sharedResources.melSpectrogram } + fileprivate var melSpectrogram: AudioMelSpectrogram { sharedResources.melSpectrogram } private var offlineFeatureExtractor: LSEENDOfflineFeatureExtractor { sharedResources.offlineFeatureExtractor } private let lock = NSLock() @@ -191,8 +191,9 @@ public final class LSEENDInferenceHelper { /// - inputSampleRate: Must match ``targetSampleRate``. /// - melSpectrogram: A mel spectrogram instance owned by the caller. /// - Returns: A session that accepts audio via ``LSEENDStreamingSession/pushAudio(_:)``. - public func createSession(inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram) throws -> LSEENDStreamingSession - { + public func createSession( + inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram + ) throws -> LSEENDStreamingSession { try LSEENDStreamingSession(engine: self, inputSampleRate: inputSampleRate, melSpectrogram: melSpectrogram) } @@ -472,7 +473,7 @@ public final class LSEENDStreamingSession { fileprivate var emittedFrames = 0 fileprivate init( - engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: NeMoMelSpectrogram? = nil + engine: LSEENDInferenceHelper, inputSampleRate: Int, melSpectrogram: AudioMelSpectrogram? = nil ) throws { guard inputSampleRate == engine.targetSampleRate else { throw LSEENDError.unsupportedAudio( diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift index 8ed8fa82..24c1ce09 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDPreprocessor.swift @@ -45,8 +45,8 @@ public struct LSEENDFeatureConfig: Sendable, Hashable { } } -private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpectrogram { - NeMoMelSpectrogram( +private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> AudioMelSpectrogram { + AudioMelSpectrogram( sampleRate: config.sampleRate, nMels: config.nMels, nFFT: config.nFFT, @@ -70,14 +70,14 @@ private func createMelSpectrogram(for config: LSEENDFeatureConfig) -> NeMoMelSpe /// For incremental processing, use ``LSEENDStreamingFeatureExtractor`` instead. public final class LSEENDOfflineFeatureExtractor { private let config: LSEENDFeatureConfig - private let spectrogram: NeMoMelSpectrogram + private let spectrogram: AudioMelSpectrogram /// Creates an offline feature extractor. /// /// - Parameters: /// - metadata: Model metadata from which feature parameters are derived. /// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`. - public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) { + public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) { let featureConfig = LSEENDFeatureConfig(metadata: metadata) config = featureConfig self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig) @@ -182,7 +182,7 @@ public final class LSEENDOfflineFeatureExtractor { /// - Important: This class is **not** thread-safe. All calls must be serialized externally. public final class LSEENDStreamingFeatureExtractor { private let config: LSEENDFeatureConfig - private let spectrogram: NeMoMelSpectrogram + private let spectrogram: AudioMelSpectrogram private var audioBuffer: [Float] = [] private var audioStartSample = 0 @@ -201,7 +201,7 @@ public final class LSEENDStreamingFeatureExtractor { /// - Parameters: /// - metadata: Model metadata from which feature parameters are derived. /// - spectrogram: Optional pre-configured mel spectrogram; one is created if `nil`. - public init(metadata: LSEENDModelMetadata, spectrogram: NeMoMelSpectrogram? = nil) { + public init(metadata: LSEENDModelMetadata, spectrogram: AudioMelSpectrogram? = nil) { let featureConfig = LSEENDFeatureConfig(metadata: metadata) config = featureConfig self.spectrogram = spectrogram ?? createMelSpectrogram(for: featureConfig) diff --git a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift index 5c36308f..0a2f211d 100644 --- a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift @@ -409,12 +409,6 @@ public struct OfflineDiarizerConfig: Sendable { set { embedding.excludeOverlap = newValue } } - @available(*, deprecated, renamed: "embeddingExcludeOverlap") - public var shouldExcludeOverlaps: Bool { - get { embeddingExcludeOverlap } - set { embeddingExcludeOverlap = newValue } - } - public var minSegmentDuration: Double { get { embedding.minSegmentDurationSeconds } set { embedding.minSegmentDurationSeconds = newValue } diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 21f362ab..2c499304 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -65,7 +65,7 @@ public final class SortformerDiarizer: Diarizer { private var _models: SortformerModels? // Native mel spectrogram (used when useNativePreprocessing is enabled) - private let melSpectrogram = NeMoMelSpectrogram() + private let melSpectrogram = AudioMelSpectrogram() // Audio buffering private var audioBuffer: [Float] = [] @@ -80,17 +80,6 @@ public final class SortformerDiarizer: Diarizer { // MARK: - Initialization - @available( - *, deprecated, renamed: "init(config:timelineConfig:)", - message: "Use `timelineConfig` instead of `postProcessingConfig`." - ) - public convenience init( - config: SortformerConfig = .default, - postProcessingConfig: DiarizerTimelineConfig = .default(numSpeakers: 4, frameDurationSeconds: 0.08) - ) { - self.init(config: config, timelineConfig: postProcessingConfig) - } - public init( config: SortformerConfig = .default, timelineConfig: DiarizerTimelineConfig = .sortformerDefault @@ -103,12 +92,6 @@ public final class SortformerDiarizer: Diarizer { self._timeline = DiarizerTimeline(config: timelineConfig) } - /// Backward-compatible initializer accepting the legacy SortformerPostProcessingConfig. - @available(*, deprecated, message: "Use init(config:timelineConfig:) with DiarizerTimelineConfig instead") - public convenience init(config: SortformerConfig = .default, postProcessingConfig: SortformerPostProcessingConfig) { - self.init(config: config, timelineConfig: postProcessingConfig.toDiarizerConfig()) - } - /// Initialize with CoreML models (combined pipeline mode). /// /// - Parameters: @@ -399,18 +382,6 @@ public final class SortformerDiarizer: Diarizer { /// /// Convenience method that combines `addAudio()` and `process()`. /// - /// - Parameters: - /// - samples: Audio samples (16kHz mono) - /// - sourceSampleRate: Source audio sample rate - /// - Returns: New chunk results if enough audio was processed - @available(*, deprecated, renamed: "process(samples:)") - public func processSamples( - _ samples: [Float], - sourceSampleRate: Double? = nil - ) throws -> DiarizerTimelineUpdate? { - return try process(samples: samples, sourceSampleRate: sourceSampleRate) - } - /// Process a chunk of audio in one call. /// /// Convenience method that combines `addAudio()` and `process()`. diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift index 01880377..c3451135 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift @@ -121,9 +121,10 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 31 ) - /// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2 weights - public static let `fastestV2` = SortformerConfig( - modelVariant: .fastestV2, + /// Fast config with Sortformer v2 weights (~1.04s latency, smallest context). + /// May handle high-speaker-count scenarios better than v2.1 (v2.1 can degrade when many speakers overlap). + public static let `fastV2` = SortformerConfig( + modelVariant: .fastV2, chunkLen: 6, chunkLeftContext: 1, chunkRightContext: 7, @@ -132,9 +133,10 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 31 ) - /// Configuration matching Gradient Descent's Streaming-Sortformer-Conversion models with Sortformer v2.1 weights - public static let `fastestV2_1` = SortformerConfig( - modelVariant: .fastestV2_1, + /// Fast config with Sortformer v2.1 weights (~1.04s latency, smallest context). + /// - Note: v2.1 may degrade when many speakers are talking simultaneously. + public static let `fastV2_1` = SortformerConfig( + modelVariant: .fastV2_1, chunkLen: 6, chunkLeftContext: 1, chunkRightContext: 7, @@ -143,39 +145,10 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 31 ) - /// Backwards compatible alias for NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights - @available(*, deprecated, renamed: "nvidiaHighLatencyV2_1") - public static let nvidiaHighLatency = nvidiaHighLatencyV2_1 - - /// NVIDIA's 30.4s latency configuration with Sortformer v2 weights - public static let nvidiaHighLatencyV2 = SortformerConfig( - modelVariant: .nvidiaHighLatencyV2, - chunkLen: 340, - chunkLeftContext: 1, - chunkRightContext: 40, - fifoLen: 40, - spkcacheLen: 188, - spkcacheUpdatePeriod: 300 - ) - - /// NVIDIA's 30.4s latency configuration with Sortformer v2.1 weights - public static let nvidiaHighLatencyV2_1 = SortformerConfig( - modelVariant: .nvidiaHighLatencyV2_1, - chunkLen: 340, - chunkLeftContext: 1, - chunkRightContext: 40, - fifoLen: 40, - spkcacheLen: 188, - spkcacheUpdatePeriod: 300 - ) - - /// Backwards compatible alias for NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights - @available(*, deprecated, renamed: "nvidiaLowLatencyV2_1") - public static let nvidiaLowLatency = nvidiaLowLatencyV2_1 - - /// NVIDIA's 1.04s latency configuration with Sortformer v2 weights - public static let nvidiaLowLatencyV2 = SortformerConfig( - modelVariant: .nvidiaLowLatencyV2, + /// Balanced config with Sortformer v2 weights (~1.04s latency, larger FIFO for better quality). + /// 20.57% DER on AMI SDM. May handle high-speaker-count scenarios better than v2.1. + public static let balancedV2 = SortformerConfig( + modelVariant: .balancedV2, chunkLen: 6, chunkLeftContext: 1, chunkRightContext: 7, @@ -184,9 +157,11 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 144 ) - /// NVIDIA's 1.04s latency configuration with Sortformer v2.1 weights (20.57% DER on AMI SDM) - public static let nvidiaLowLatencyV2_1 = SortformerConfig( - modelVariant: .nvidiaLowLatencyV2_1, + /// Balanced config with Sortformer v2.1 weights (~1.04s latency, larger FIFO for better quality). + /// 20.57% DER on AMI SDM. + /// - Note: v2.1 may degrade when many speakers are talking simultaneously. + public static let balancedV2_1 = SortformerConfig( + modelVariant: .balancedV2_1, chunkLen: 6, chunkLeftContext: 1, chunkRightContext: 7, @@ -195,9 +170,33 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 144 ) + /// High-context config with Sortformer v2 weights (~30.4s latency, most context window). + /// May handle high-speaker-count scenarios better than v2.1. + public static let highContextV2 = SortformerConfig( + modelVariant: .highContextV2, + chunkLen: 340, + chunkLeftContext: 1, + chunkRightContext: 40, + fifoLen: 40, + spkcacheLen: 188, + spkcacheUpdatePeriod: 300 + ) + + /// High-context config with Sortformer v2.1 weights (~30.4s latency, most context window). + /// - Note: v2.1 may degrade when many speakers are talking simultaneously. + public static let highContextV2_1 = SortformerConfig( + modelVariant: .highContextV2_1, + chunkLen: 340, + chunkLeftContext: 1, + chunkRightContext: 40, + fifoLen: 40, + spkcacheLen: 188, + spkcacheUpdatePeriod: 300 + ) + /// - Warning: If you don't use one of the default configurations, you must use a local model converted with that configuration. public init( - modelVariant: ModelVariant? = .fastestV2_1, + modelVariant: ModelVariant? = .fastV2_1, chunkLen: Int = 6, chunkLeftContext: Int = 1, chunkRightContext: Int = 7, @@ -239,131 +238,6 @@ public struct SortformerConfig: Sendable { } } -/// Configuration for post-processing Sortformer diarizer predictions -@available(*, deprecated, message: "Use DiarizerTimelineConfig instead", renamed: "DiarizerTimelineConfig") -public struct SortformerPostProcessingConfig { - /// Onset threshold for detecting the beginning and end of a speech - public var onsetThreshold: Float - - /// Offset threshold for detecting the end of a speech - public var offsetThreshold: Float - - /// Adding frames before each speech segment - public var onsetPadFrames: Int - - /// Adding frames after each speech segment - public var offsetPadFrames: Int - - /// Threshold for short speech segment deletion in frames - public var minFramesOn: Int - - /// Threshold for small non-speech deletion in frames - public var minFramesOff: Int - - /// Adding durations before each speech segment - public var onsetPadSeconds: Float { - get { Float(onsetPadFrames) * frameDurationSeconds } - set { onsetPadFrames = Int(round(newValue / frameDurationSeconds)) } - } - - /// Adding durations after each speech segment - public var offsetPadSeconds: Float { - get { Float(offsetPadFrames) * frameDurationSeconds } - set { offsetPadFrames = Int(round(newValue / frameDurationSeconds)) } - } - - /// Threshold for short speech segment deletion (seconds) - public var minDurationOn: Float { - get { Float(minFramesOn) * frameDurationSeconds } - set { minFramesOn = Int(round(newValue / frameDurationSeconds)) } - } - - /// Threshold for small non-speech deletion (seconds) - public var minDurationOff: Float { - get { Float(minFramesOff) * frameDurationSeconds } - set { minFramesOff = Int(round(newValue / frameDurationSeconds)) } - } - - /// Maximum number of predictions to retain - public var maxStoredFrames: Int? = nil - - /// Number of speakers - public let numSpeakers: Int = 4 - - /// Number of speakers - public let frameDurationSeconds: Float = 0.08 - - /// Default configurations - public static var `default`: SortformerPostProcessingConfig { - SortformerPostProcessingConfig( - onsetThreshold: 0.5, - offsetThreshold: 0.5, - onsetPadFrames: 0, - offsetPadFrames: 0, - minFramesOn: 0, - minFramesOff: 0 - ) - } - - public init( - onsetThreshold: Float = 0.5, - offsetThreshold: Float = 0.5, - onsetPadSeconds: Float = 0, - offsetPadSeconds: Float = 0, - minDurationOn: Float = 0, - minDurationOff: Float = 0, - maxStoredFrames: Int? = nil - ) { - self.onsetThreshold = onsetThreshold - self.offsetThreshold = offsetThreshold - self.onsetPadFrames = Int(round(onsetPadSeconds / frameDurationSeconds)) - self.offsetPadFrames = Int(round(offsetPadSeconds / frameDurationSeconds)) - self.minFramesOn = Int(round(minDurationOn / frameDurationSeconds)) - self.minFramesOff = Int(round(minDurationOff / frameDurationSeconds)) - self.maxStoredFrames = maxStoredFrames - } - - public init( - onsetThreshold: Float = 0.5, - offsetThreshold: Float = 0.5, - onsetPadFrames: Int = 0, - offsetPadFrames: Int = 0, - minFramesOn: Int = 0, - minFramesOff: Int = 0, - maxStoredFrames: Int? = nil - ) { - self.onsetThreshold = onsetThreshold - self.offsetThreshold = offsetThreshold - self.onsetPadFrames = onsetPadFrames - self.offsetPadFrames = offsetPadFrames - self.minFramesOn = minFramesOn - self.minFramesOff = minFramesOff - self.maxStoredFrames = maxStoredFrames - } - - /// Convert to the unified DiarizerPostProcessingConfig - public func toDiarizerConfig() -> DiarizerTimelineConfig { - DiarizerTimelineConfig( - numSpeakers: numSpeakers, - frameDurationSeconds: frameDurationSeconds, - onsetThreshold: onsetThreshold, - offsetThreshold: offsetThreshold, - onsetPadFrames: onsetPadFrames, - offsetPadFrames: offsetPadFrames, - minFramesOn: minFramesOn, - minFramesOff: minFramesOff, - maxStoredFrames: maxStoredFrames - ) - } -} - -/// Backward-compatible typealiases — SortformerTimeline is now DiarizerTimeline -@available(*, deprecated, renamed: "DiarizerTimeline") -public typealias SortformerTimeline = DiarizerTimeline - -@available(*, deprecated, renamed: "DiarizerSegment") -public typealias SortformerSegment = DiarizerSegment - // MARK: - Streaming State /// State maintained across streaming chunks for Sortformer diarization. @@ -455,7 +329,7 @@ public struct SortformerFeatureLoader: Sendable { self.startFeat = 0 self.endFeat = 0 - (self.featSeq, self.featLength, self.featSeqLength) = NeMoMelSpectrogram().computeFlatTransposed(audio: audio) + (self.featSeq, self.featLength, self.featSeqLength) = AudioMelSpectrogram().computeFlatTransposed(audio: audio) // numChunks accounts for right context requirement: need endFeat + rc <= featLength // Chunk n has endFeat = (n+1) * chunkLen, so valid when (n+1) * chunkLen + rc <= featLength // numChunks = floor((featLength - rc) / chunkLen) @@ -522,10 +396,6 @@ public struct StreamingUpdateResult: Sendable { } } -/// Backward-compatible typealias — SortformerChunkResult is now DiarizerChunkResult -@available(*, deprecated, renamed: "DiarizerChunkResult") -public typealias SortformerChunkResult = DiarizerChunkResult - // MARK: - Errors public enum SortformerError: Error, LocalizedError { diff --git a/Sources/FluidAudio/DownloadUtils.swift b/Sources/FluidAudio/DownloadUtils.swift index 691ad261..5191aae2 100644 --- a/Sources/FluidAudio/DownloadUtils.swift +++ b/Sources/FluidAudio/DownloadUtils.swift @@ -381,14 +381,15 @@ public class DownloadUtils { let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedFilePath) let request = authorizedRequest(url: fileURL) + let tempFileURL: URL let httpResponse: HTTPURLResponse + if let handler = progressHandler { let baseBytes = completedBytes let fileCount = filesToDownload.count let totalBytesSnapshot = totalBytes - httpResponse = try await downloadWithProgress( + (tempFileURL, httpResponse) = try await downloadWithProgress( request: request, - destinationURL: destPath, onProgress: { bytesWritten, _ in guard totalBytesSnapshot > 0 else { return } let current = baseBytes + bytesWritten @@ -406,41 +407,11 @@ public class DownloadUtils { guard let resp = response as? HTTPURLResponse else { throw HuggingFaceDownloadError.invalidResponse } + tempFileURL = url httpResponse = resp - if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 { - throw HuggingFaceDownloadError.rateLimited( - statusCode: httpResponse.statusCode, - message: "Rate limited while downloading \(file.path)") - } - - guard (200..<300).contains(httpResponse.statusCode) else { - throw HuggingFaceDownloadError.downloadFailed( - path: file.path, - underlying: NSError(domain: "HTTP", code: httpResponse.statusCode) - ) - } - - // Remove existing file if present (handles parallel download race conditions) - if FileManager.default.fileExists(atPath: destPath.path) { - try? FileManager.default.removeItem(at: destPath) - } - try FileManager.default.moveItem(at: url, to: destPath) - completedBytes += Int64(max(0, file.size)) - if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 { - logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files") - } - - // Report completed-file progress - progressHandler?( - DownloadProgress( - fractionCompleted: totalBytes > 0 - ? 0.5 * Double(completedBytes) / Double(totalBytes) - : 0.5 * Double(index + 1) / Double(filesToDownload.count), - phase: .downloading(completedFiles: index + 1, totalFiles: filesToDownload.count) - )) - continue } + // Validate response if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 { throw HuggingFaceDownloadError.rateLimited( statusCode: httpResponse.statusCode, @@ -454,13 +425,18 @@ public class DownloadUtils { ) } + // Move downloaded file to destination + if FileManager.default.fileExists(atPath: destPath.path) { + try? FileManager.default.removeItem(at: destPath) + } + try FileManager.default.moveItem(at: tempFileURL, to: destPath) + completedBytes += Int64(max(0, file.size)) if (index + 1) % 10 == 0 || index == filesToDownload.count - 1 { logger.info("Downloaded \(index + 1)/\(filesToDownload.count) files") } - // Report completed-file progress progressHandler?( DownloadProgress( fractionCompleted: totalBytes > 0 @@ -485,16 +461,17 @@ public class DownloadUtils { /// Download a single file using a delegate to get byte-level progress. /// + /// This is a pure transport helper — the caller is responsible for validating + /// the HTTP status and moving the temporary file to its final destination. + /// /// - Parameters: /// - request: The URLRequest to download. - /// - onProgress: Called with `(bytesWritten, totalBytesExpected)` as data arrives. - /// - Parameter destinationURL: Final destination path for the downloaded file. - /// - Returns: The HTTP response for the completed download. + /// - onProgress: Called with `(totalBytesWritten, totalBytesExpected)` as data arrives. + /// - Returns: The temporary file URL and HTTP response. private static func downloadWithProgress( request: URLRequest, - destinationURL: URL, onProgress: @escaping @Sendable (Int64, Int64) -> Void - ) async throws -> HTTPURLResponse { + ) async throws -> (URL, HTTPURLResponse) { let delegate = DownloadProgressDelegate(onProgress: onProgress) // Dedicated session with delegate — one per download to avoid cross-talk. let session = URLSession( @@ -510,25 +487,107 @@ public class DownloadUtils { throw HuggingFaceDownloadError.invalidResponse } - if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 { - throw HuggingFaceDownloadError.rateLimited( - statusCode: httpResponse.statusCode, - message: "HTTP \(httpResponse.statusCode)" - ) + return (tempURL, httpResponse) + } + + /// Download a specific subdirectory from a HuggingFace repository. + /// + /// Use this for optional model components that aren't part of the required model set + /// (e.g., the Mimi encoder for PocketTTS voice cloning). + /// + /// - Parameters: + /// - repo: The HuggingFace repository. + /// - subdirectory: Path within the repo to download (e.g. `"mimi_encoder.mlmodelc"`). + /// - repoDirectory: Local directory corresponding to the repo root. + /// Files are saved at `repoDirectory/`. + public static func downloadSubdirectory( + _ repo: Repo, + subdirectory: String, + to repoDirectory: URL + ) async throws { + var filesToDownload: [(path: String, size: Int)] = [] + + func listFiles(at path: String) async throws { + let dirURL = try ModelRegistry.apiModels(repo.remotePath, "tree/main/\(path)") + let (dirData, response) = try await fetchWithAuth(from: dirURL) + if let httpResponse = response as? HTTPURLResponse, + httpResponse.statusCode == 429 || httpResponse.statusCode == 503 + { + throw HuggingFaceDownloadError.rateLimited( + statusCode: httpResponse.statusCode, + message: "Rate limited while listing files in \(path)") + } + guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else { + return + } + for item in items { + guard let itemPath = item["path"] as? String, + let itemType = item["type"] as? String + else { continue } + + if itemType == "directory" { + try await listFiles(at: itemPath) + } else if itemType == "file" { + let fileSize = item["size"] as? Int ?? -1 + filesToDownload.append((path: itemPath, size: fileSize)) + } + } } - guard (200..<300).contains(httpResponse.statusCode) else { - throw HuggingFaceDownloadError.downloadFailed( - path: destinationURL.lastPathComponent, - underlying: NSError(domain: "HTTP", code: httpResponse.statusCode) + try await listFiles(at: subdirectory) + logger.info("Found \(filesToDownload.count) files in \(subdirectory)") + + for (index, file) in filesToDownload.enumerated() { + let destPath = repoDirectory.appendingPathComponent(file.path) + + if FileManager.default.fileExists(atPath: destPath.path) { + continue + } + + try FileManager.default.createDirectory( + at: destPath.deletingLastPathComponent(), + withIntermediateDirectories: true ) + + if file.size == 0 { + FileManager.default.createFile(atPath: destPath.path, contents: Data()) + continue + } + + let encodedPath = + file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path + let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedPath) + let request = authorizedRequest(url: fileURL) + + let (tempURL, response) = try await sharedSession.download(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw HuggingFaceDownloadError.invalidResponse + } + + if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 { + throw HuggingFaceDownloadError.rateLimited( + statusCode: httpResponse.statusCode, + message: "Rate limited while downloading \(file.path)") + } + + guard (200..<300).contains(httpResponse.statusCode) else { + throw HuggingFaceDownloadError.downloadFailed( + path: file.path, + underlying: NSError(domain: "HTTP", code: httpResponse.statusCode) + ) + } + + if FileManager.default.fileExists(atPath: destPath.path) { + try? FileManager.default.removeItem(at: destPath) + } + try FileManager.default.moveItem(at: tempURL, to: destPath) + + if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 { + logger.info("Downloaded \(index + 1)/\(filesToDownload.count) \(subdirectory) files") + } } - if FileManager.default.fileExists(atPath: destinationURL.path) { - try? FileManager.default.removeItem(at: destinationURL) - } - try FileManager.default.moveItem(at: tempURL, to: destinationURL) - return httpResponse + logger.info("Downloaded \(subdirectory) from \(repo.folderName)") } /// Fetch a single file from HuggingFace with retry diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index 579a10a9..4c989212 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -240,44 +240,44 @@ public enum ModelNames { /// Sortformer streaming diarization model names public enum Sortformer { public enum Variant: CaseIterable, Sendable { - case fastestV2 - case fastestV2_1 - case nvidiaLowLatencyV2 - case nvidiaLowLatencyV2_1 - case nvidiaHighLatencyV2 - case nvidiaHighLatencyV2_1 + case fastV2 + case fastV2_1 + case balancedV2 + case balancedV2_1 + case highContextV2 + case highContextV2_1 public var name: String { switch self { - case .fastestV2: + case .fastV2: return "Sortformer_v2" - case .fastestV2_1: + case .fastV2_1: return "Sortformer_v2.1" - case .nvidiaLowLatencyV2: + case .balancedV2: return "SortformerNvidiaLow_v2" - case .nvidiaLowLatencyV2_1: + case .balancedV2_1: return "SortformerNvidiaLow_v2.1" - case .nvidiaHighLatencyV2: + case .highContextV2: return "SortformerNvidiaHigh_v2" - case .nvidiaHighLatencyV2_1: + case .highContextV2_1: return "SortformerNvidiaHigh_v2.1" } } public var defaultConfiguration: SortformerConfig { switch self { - case .fastestV2: - return .fastestV2 - case .fastestV2_1: - return .fastestV2_1 - case .nvidiaLowLatencyV2: - return .nvidiaLowLatencyV2 - case .nvidiaLowLatencyV2_1: - return .nvidiaLowLatencyV2_1 - case .nvidiaHighLatencyV2: - return .nvidiaHighLatencyV2 - case .nvidiaHighLatencyV2_1: - return .nvidiaHighLatencyV2_1 + case .fastV2: + return .fastV2 + case .fastV2_1: + return .fastV2_1 + case .balancedV2: + return .balancedV2 + case .balancedV2_1: + return .balancedV2_1 + case .highContextV2: + return .highContextV2 + case .highContextV2_1: + return .highContextV2_1 } } @@ -291,7 +291,7 @@ public enum ModelNames { } /// Lowest latency for streaming - public static let defaultVariant: Variant = .fastestV2_1 + public static let defaultVariant: Variant = .fastV2_1 /// Bundle name for a specific variant public static func bundle(for variant: Variant) -> String { diff --git a/Sources/FluidAudio/Shared/ANEMemoryUtils.swift b/Sources/FluidAudio/Shared/ANEMemoryUtils.swift index 60b09c2c..6ba9e085 100644 --- a/Sources/FluidAudio/Shared/ANEMemoryUtils.swift +++ b/Sources/FluidAudio/Shared/ANEMemoryUtils.swift @@ -1,8 +1,6 @@ -import Accelerate import CoreML import Darwin import Foundation -import Metal /// Shared ANE optimization utilities for all ML pipelines public enum ANEMemoryUtils { @@ -120,13 +118,16 @@ public enum ANEMemoryUtils { shape: [NSNumber], strides: [NSNumber]? = nil ) throws -> MLMultiArray { - // Validate bounds + // Validate bounds using stride-aware backing size (accounts for ANE padding) let elementSize = getElementSize(for: array.dataType) - let totalElements = shape.map { $0.intValue }.reduce(1, *) - let bytesNeeded = totalElements * elementSize + let viewStrides = strides ?? calculateOptimalStrides(for: shape) + let viewBackingElements = + shape.isEmpty ? 0 : viewStrides[0].intValue * shape[0].intValue + let sourceBackingElements = + array.shape.isEmpty ? 0 : array.strides[0].intValue * array.shape[0].intValue let byteOffset = offset * elementSize - guard byteOffset + bytesNeeded <= array.count * elementSize else { + guard byteOffset + viewBackingElements * elementSize <= sourceBackingElements * elementSize else { throw ANEMemoryError.invalidShape } @@ -144,34 +145,47 @@ public enum ANEMemoryUtils { /// Stride-aware copy between two MLMultiArrays that may have different stride layouts. /// - /// Copies all logical elements from `source` to `destination` (which must have the same shape). - /// The innermost dimension is copied as a contiguous block (stride-1), while outer dimensions + /// Copies all logical elements from `source` to `destination` (which must have the same shape + /// and data type). The innermost dimension must have stride 1 in both arrays. Outer dimensions /// are iterated respecting each array's strides. public static func strideAwareCopy(from source: MLMultiArray, to destination: MLMultiArray) { let shape = source.shape.map { $0.intValue } let srcStrides = source.strides.map { $0.intValue } let dstStrides = destination.strides.map { $0.intValue } - let srcPtr = source.dataPointer.assumingMemoryBound(to: Float.self) - let dstPtr = destination.dataPointer.assumingMemoryBound(to: Float.self) - let ndim = shape.count guard ndim > 0 else { return } + // Validate shapes and types match + precondition( + source.shape == destination.shape, + "strideAwareCopy: shape mismatch \(source.shape) vs \(destination.shape)" + ) + precondition( + source.dataType == destination.dataType, + "strideAwareCopy: dataType mismatch" + ) + precondition( + srcStrides[ndim - 1] == 1 && dstStrides[ndim - 1] == 1, + "strideAwareCopy: innermost stride must be 1" + ) + + let elementSize = getElementSize(for: source.dataType) + let srcPtr = source.dataPointer + let dstPtr = destination.dataPointer + // If strides match, a single memcpy suffices (fast path). if srcStrides == dstStrides { - // Total backing storage = first-dim stride * first-dim size - let totalBacking = srcStrides[0] * shape[0] - memcpy(dstPtr, srcPtr, totalBacking * MemoryLayout.size) + let totalBytes = srcStrides[0] * shape[0] * elementSize + memcpy(dstPtr, srcPtr, totalBytes) return } - // Innermost dimension length - let innerLen = shape[ndim - 1] + // Innermost dimension byte count + let innerBytes = shape[ndim - 1] * elementSize if ndim == 1 { - // 1-D: just copy innerLen elements (both have stride 1 for innermost) - memcpy(dstPtr, srcPtr, innerLen * MemoryLayout.size) + memcpy(dstPtr, srcPtr, innerBytes) return } @@ -182,16 +196,16 @@ public enum ANEMemoryUtils { var indices = [Int](repeating: 0, count: ndim - 1) for _ in 0...size) + memcpy(dstPtr + dstByteOffset, srcPtr + srcByteOffset, innerBytes) // Increment multi-index (odometer style) var carry = ndim - 2 @@ -206,33 +220,4 @@ public enum ANEMemoryUtils { } } - /// Prefetch memory pages for ANE processing - public static func prefetchForANE(_ array: MLMultiArray) { - let dataPointer = array.dataPointer - let elementSize = getElementSize(for: array.dataType) - let totalBytes = array.count * elementSize - - // Touch first and last cache lines to trigger ANE DMA prefetch - if totalBytes > 0 { - _ = dataPointer.load(as: UInt8.self) - if totalBytes > 1 { - _ = dataPointer.advanced(by: totalBytes - 1).load(as: UInt8.self) - } - } - } -} - -/// Extension for MLMultiArray to add ANE optimization methods -extension MLMultiArray { - - /// Check if this array is ANE-aligned - public var isANEAligned: Bool { - let address = Int(bitPattern: self.dataPointer) - return address % ANEMemoryUtils.aneAlignment == 0 - } - - /// Prefetch this array for ANE processing - public func prefetchForANE() { - ANEMemoryUtils.prefetchForANE(self) - } } diff --git a/Sources/FluidAudio/Shared/AudioConverter.swift b/Sources/FluidAudio/Shared/AudioConverter.swift index a1be538f..c93a603b 100644 --- a/Sources/FluidAudio/Shared/AudioConverter.swift +++ b/Sources/FluidAudio/Shared/AudioConverter.swift @@ -193,23 +193,25 @@ final public class AudioConverter { throw AudioConverterError.failedToCreateSourceFormat } - // Use AVAudioConverter for channel mixing and format conversion + // Use AVAudioConverter for channel mixing (same sample rate, no resampling needed) guard let converter = AVAudioConverter(from: format, to: monoFormat) else { throw AudioConverterError.failedToCreateConverter } - configure(converter: converter) guard let outputBuffer = AVAudioPCMBuffer(pcmFormat: monoFormat, frameCapacity: buffer.frameCapacity) else { throw AudioConverterError.failedToCreateBuffer } - nonisolated(unsafe) var provided = false - nonisolated(unsafe) let capturedBuffer = buffer + let provided = OSAllocatedUnfairLock(initialState: false) let inputBlock: AVAudioConverterInputBlock = { _, status in - if !provided { - provided = true + let wasProvided = provided.withLock { state -> Bool in + if state { return true } + state = true + return false + } + if !wasProvided { status.pointee = .haveData - return capturedBuffer + return buffer } else { status.pointee = .endOfStream return nil @@ -309,8 +311,12 @@ final public class AudioConverter { // but Swift 6 rejects mutation of captured vars in this callback. let provided = OSAllocatedUnfairLock(initialState: false) let inputBlock: AVAudioConverterInputBlock = { _, status in - if !provided.withLock({ $0 }) { - provided.withLock { $0 = true } + let wasProvided = provided.withLock { state -> Bool in + if state { return true } + state = true + return false + } + if !wasProvided { status.pointee = .haveData return buffer } else { diff --git a/Sources/FluidAudio/Shared/NeMoMelSpectrogram.swift b/Sources/FluidAudio/Shared/AudioMelSpectrogram.swift similarity index 99% rename from Sources/FluidAudio/Shared/NeMoMelSpectrogram.swift rename to Sources/FluidAudio/Shared/AudioMelSpectrogram.swift index a77701cc..3c3d56f5 100644 --- a/Sources/FluidAudio/Shared/NeMoMelSpectrogram.swift +++ b/Sources/FluidAudio/Shared/AudioMelSpectrogram.swift @@ -15,7 +15,7 @@ import Foundation /// - center: True with pad_mode='constant' (zero padding) /// - normalize: "NA" (no normalization) /// - dither: 0.0 (disabled for determinism) -public final class NeMoMelSpectrogram { +public final class AudioMelSpectrogram { public enum PaddingMode: Sendable { case center case prePadded diff --git a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift index 72fc2ef2..db44911e 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift @@ -66,94 +66,11 @@ public enum PocketTtsResourceDownloader { /// Download the Mimi encoder model files from HuggingFace. private static func downloadMimiEncoder(to repoDir: URL) async throws { - let modelName = ModelNames.PocketTTS.mimiEncoderFile - let modelDir = repoDir.appendingPathComponent(modelName) - - try FileManager.default.createDirectory(at: modelDir, withIntermediateDirectories: true) - - // List files in the mimi_encoder.mlmodelc directory - let apiPath = "tree/main/\(modelName)" - let dirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, apiPath) - - let (dirData, _) = try await DownloadUtils.fetchWithAuth(from: dirURL) - - guard let items = try JSONSerialization.jsonObject(with: dirData) as? [[String: Any]] else { - throw PocketTTSError.downloadFailed("Failed to list Mimi encoder files") - } - - // Collect all files recursively - var filesToDownload: [(path: String, size: Int)] = [] - - func collectFiles(from items: [[String: Any]], basePath: String) async throws { - for item in items { - guard let itemPath = item["path"] as? String, - let itemType = item["type"] as? String - else { continue } - - if itemType == "directory" { - let subDirURL = try ModelRegistry.apiModels(Repo.pocketTts.remotePath, "tree/main/\(itemPath)") - let (subDirData, _) = try await DownloadUtils.fetchWithAuth(from: subDirURL) - if let subItems = try JSONSerialization.jsonObject(with: subDirData) as? [[String: Any]] { - try await collectFiles(from: subItems, basePath: itemPath) - } - } else if itemType == "file" { - let fileSize = item["size"] as? Int ?? -1 - filesToDownload.append((path: itemPath, size: fileSize)) - } - } - } - - try await collectFiles(from: items, basePath: modelName) - logger.info("Found \(filesToDownload.count) files in Mimi encoder") - - // Download each file - for (index, file) in filesToDownload.enumerated() { - // Local path relative to modelName - let relativePath = - file.path.hasPrefix("\(modelName)/") - ? String(file.path.dropFirst(modelName.count + 1)) - : file.path - let destPath = modelDir.appendingPathComponent(relativePath) - - if FileManager.default.fileExists(atPath: destPath.path) { - continue - } - - try FileManager.default.createDirectory( - at: destPath.deletingLastPathComponent(), - withIntermediateDirectories: true - ) - - // Handle empty files - if file.size == 0 { - FileManager.default.createFile(atPath: destPath.path, contents: Data()) - continue - } - - let encodedPath = file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path - let fileURL = try ModelRegistry.resolveModel(Repo.pocketTts.remotePath, encodedPath) - - let (tempURL, response) = try await DownloadUtils.sharedSession.download( - for: URLRequest(url: fileURL, timeoutInterval: 1800) - ) - - guard let httpResponse = response as? HTTPURLResponse, - (200..<300).contains(httpResponse.statusCode) - else { - throw PocketTTSError.downloadFailed("Failed to download \(file.path)") - } - - if FileManager.default.fileExists(atPath: destPath.path) { - try? FileManager.default.removeItem(at: destPath) - } - try FileManager.default.moveItem(at: tempURL, to: destPath) - - if (index + 1) % 5 == 0 || index == filesToDownload.count - 1 { - logger.info("Downloaded \(index + 1)/\(filesToDownload.count) Mimi encoder files") - } - } - - logger.info("Mimi encoder download complete") + try await DownloadUtils.downloadSubdirectory( + .pocketTts, + subdirectory: ModelNames.PocketTTS.mimiEncoderFile, + to: repoDir + ) } /// Ensure constants (binary blobs + tokenizer) are available. diff --git a/Sources/FluidAudioCLI/Commands/DiarizationBenchmark.swift b/Sources/FluidAudioCLI/Commands/DiarizationBenchmark.swift index b081b995..a28b7ab8 100644 --- a/Sources/FluidAudioCLI/Commands/DiarizationBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/DiarizationBenchmark.swift @@ -180,9 +180,7 @@ enum StreamDiarizationBenchmark { printUsage() return default: - if !arguments[i].starts(with: "--") { - logger.warning("Unknown argument: \(arguments[i])") - } + logger.warning("Unknown argument: \(arguments[i])") } i += 1 } diff --git a/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift b/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift new file mode 100644 index 00000000..c5fdcdae --- /dev/null +++ b/Sources/FluidAudioCLI/Commands/DiarizationBenchmarkUtils.swift @@ -0,0 +1,332 @@ +#if os(macOS) +import FluidAudio +import Foundation + +/// Shared utilities for diarization benchmark commands (LS-EEND and Sortformer). +enum DiarizationBenchmarkUtils { + + /// Dataset corpora supported by diarization benchmarks. + enum Dataset: String { + case ami = "ami" + case voxconverse = "voxconverse" + case callhome = "callhome" + } + + /// Per-meeting benchmark result shared across diarization benchmark commands. + struct BenchmarkResult { + let meetingName: String + let der: Float + let missRate: Float + let falseAlarmRate: Float + let speakerErrorRate: Float + let rtfx: Float + let processingTime: Double + let totalFrames: Int + let detectedSpeakers: Int + let groundTruthSpeakers: Int + let modelLoadTime: Double + let audioLoadTime: Double? + } + + // MARK: - File Paths + + static func getAMIFiles(maxFiles: Int?) -> [String] { + let allMeetings = [ + "EN2002a", "EN2002b", "EN2002c", "EN2002d", + "ES2004a", "ES2004b", "ES2004c", "ES2004d", + "IS1009a", "IS1009b", "IS1009c", "IS1009d", + "TS3003a", "TS3003b", "TS3003c", "TS3003d", + ] + + var availableMeetings: [String] = [] + for meeting in allMeetings { + let path = getAudioPath(for: meeting, dataset: .ami) + if FileManager.default.fileExists(atPath: path) { + availableMeetings.append(meeting) + } + } + + if let max = maxFiles { + return Array(availableMeetings.prefix(max)) + } + return availableMeetings + } + + static func getAudioPath(for meeting: String, dataset: Dataset) -> String { + let homeDir = FileManager.default.homeDirectoryForCurrentUser + switch dataset { + case .ami: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav" + ).path + case .voxconverse: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav" + ).path + case .callhome: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/callhome_eng/\(meeting).wav" + ).path + } + } + + static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? { + let homeDir = FileManager.default.homeDirectoryForCurrentUser + switch dataset { + case .ami: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/ami_official/rttm/\(meeting).rttm" + ) + case .voxconverse: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm" + ) + case .callhome: + return homeDir.appendingPathComponent( + "FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm" + ) + } + } + + static func getVoxConverseFiles(maxFiles: Int?) -> [String] { + let homeDir = FileManager.default.homeDirectoryForCurrentUser + let voxDir = homeDir.appendingPathComponent( + "FluidAudioDatasets/voxconverse/voxconverse_test_wav" + ) + + guard + let files = try? FileManager.default.contentsOfDirectory( + at: voxDir, + includingPropertiesForKeys: nil + ) + else { + return [] + } + + var availableMeetings: [String] = [] + for file in files where file.pathExtension == "wav" { + let name = file.deletingPathExtension().lastPathComponent + let rttmPath = homeDir.appendingPathComponent( + "FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm" + ) + if FileManager.default.fileExists(atPath: rttmPath.path) { + availableMeetings.append(name) + } + } + + availableMeetings.sort() + if let max = maxFiles { + return Array(availableMeetings.prefix(max)) + } + return availableMeetings + } + + static func getCALLHOMEFiles(maxFiles: Int?) -> [String] { + let homeDir = FileManager.default.homeDirectoryForCurrentUser + let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng") + + guard + let files = try? FileManager.default.contentsOfDirectory( + at: callhomeDir, + includingPropertiesForKeys: nil + ) + else { + return [] + } + + var availableMeetings: [String] = [] + for file in files where file.pathExtension == "wav" { + let name = file.deletingPathExtension().lastPathComponent + let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm") + if FileManager.default.fileExists(atPath: rttmPath.path) { + availableMeetings.append(name) + } + } + + availableMeetings.sort() + if let max = maxFiles { + return Array(availableMeetings.prefix(max)) + } + return availableMeetings + } + + /// Returns files for the given dataset, filtering by availability. + static func getFiles(for dataset: Dataset, maxFiles: Int?) -> [String] { + switch dataset { + case .ami: + return getAMIFiles(maxFiles: maxFiles) + case .voxconverse: + return getVoxConverseFiles(maxFiles: maxFiles) + case .callhome: + return getCALLHOMEFiles(maxFiles: maxFiles) + } + } + + // MARK: - Summary & Output + + /// Prints a formatted benchmark summary table. + /// + /// - Parameters: + /// - results: Benchmark results to summarize. + /// - title: Header title (e.g. "LS-EEND BENCHMARK SUMMARY"). + /// - derTargets: DER percentage thresholds to check, ordered from strictest to most lenient + /// (e.g. `[15, 25]` prints "DER < 15%" if met, else "DER < 25%", else "DER > 25%"). + static func printFinalSummary( + results: [BenchmarkResult], + title: String, + derTargets: [Float] + ) { + guard !results.isEmpty else { return } + + print("\n" + String(repeating: "=", count: 80)) + print(title) + print(String(repeating: "=", count: 80)) + + print("Results Sorted by DER:") + print(String(repeating: "-", count: 70)) + print("Meeting DER % Miss % FA % SE % Speakers RTFx") + print(String(repeating: "-", count: 70)) + + for result in results.sorted(by: { $0.der < $1.der }) { + let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)" + let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0) + let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0) + print( + String( + format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f", + meetingCol, + result.der, + result.missRate, + result.falseAlarmRate, + result.speakerErrorRate, + speakerCol, + result.rtfx)) + } + print(String(repeating: "-", count: 70)) + + let count = Float(results.count) + let avgDER = results.map { $0.der }.reduce(0, +) / count + let avgMiss = results.map { $0.missRate }.reduce(0, +) / count + let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count + let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count + let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count + + print( + String( + format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f", + avgDER, avgMiss, avgFA, avgSE, avgRTFx)) + print(String(repeating: "=", count: 70)) + + print("\nTarget Check:") + var matched = false + for target in derTargets.sorted() { + if avgDER < target { + print(" DER < \(String(format: "%.0f", target))% (achieved: \(String(format: "%.1f", avgDER))%)") + matched = true + break + } + } + if !matched, let highest = derTargets.max() { + print( + " DER > \(String(format: "%.0f", highest))% (achieved: \(String(format: "%.1f", avgDER))%)") + } + + if avgRTFx > 1 { + print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") + } else { + print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") + } + } + + static func saveJSONResults(results: [BenchmarkResult], to path: String) { + let jsonData = results.map { resultToDict($0) } + do { + let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) + try data.write(to: URL(fileURLWithPath: path)) + print("JSON results saved to: \(path)") + } catch { + print("Failed to save JSON: \(error)") + } + } + + // MARK: - Progress Save/Load + + static func resultToDict(_ result: BenchmarkResult) -> [String: Any] { + var dict: [String: Any] = [ + "meeting": result.meetingName, + "der": result.der, + "missRate": result.missRate, + "falseAlarmRate": result.falseAlarmRate, + "speakerErrorRate": result.speakerErrorRate, + "rtfx": result.rtfx, + "processingTime": result.processingTime, + "totalFrames": result.totalFrames, + "detectedSpeakers": result.detectedSpeakers, + "groundTruthSpeakers": result.groundTruthSpeakers, + "modelLoadTime": result.modelLoadTime, + ] + if let audioLoadTime = result.audioLoadTime { + dict["audioLoadTime"] = audioLoadTime + } + return dict + } + + static func saveProgress(results: [BenchmarkResult], to path: String) { + let jsonData = results.map { resultToDict($0) } + do { + let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) + try data.write(to: URL(fileURLWithPath: path)) + } catch { + print("Failed to save progress: \(error)") + } + } + + static func loadProgress(from path: String) -> [BenchmarkResult]? { + guard FileManager.default.fileExists(atPath: path) else { return nil } + + do { + let data = try Data(contentsOf: URL(fileURLWithPath: path)) + guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else { + return nil + } + + return jsonArray.compactMap { dict -> BenchmarkResult? in + guard let meeting = dict["meeting"] as? String, + let der = (dict["der"] as? NSNumber)?.floatValue, + let missRate = (dict["missRate"] as? NSNumber)?.floatValue, + let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue, + let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue, + let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue, + let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue, + let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue, + let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue, + let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue, + let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue + else { + return nil + } + + let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue + + return BenchmarkResult( + meetingName: meeting, + der: der, + missRate: missRate, + falseAlarmRate: falseAlarmRate, + speakerErrorRate: speakerErrorRate, + rtfx: rtfx, + processingTime: processingTime, + totalFrames: totalFrames, + detectedSpeakers: detectedSpeakers, + groundTruthSpeakers: groundTruthSpeakers, + modelLoadTime: modelLoadTime, + audioLoadTime: audioLoadTime + ) + } + } catch { + print("Failed to load progress: \(error)") + return nil + } + } +} +#endif diff --git a/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift b/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift index 9e3a5f13..a4984783 100644 --- a/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/LSEENDBenchmark.swift @@ -1,5 +1,4 @@ #if os(macOS) -import AVFoundation import FluidAudio import Foundation @@ -7,26 +6,8 @@ import Foundation enum LSEENDBenchmark { private static let logger = AppLogger(category: "LSEENDBench") - enum Dataset: String { - case ami = "ami" - case voxconverse = "voxconverse" - case callhome = "callhome" - } - - struct BenchmarkResult { - let meetingName: String - let der: Float - let missRate: Float - let falseAlarmRate: Float - let speakerErrorRate: Float - let rtfx: Float - let processingTime: Double - let totalFrames: Int - let detectedSpeakers: Int - let groundTruthSpeakers: Int - let modelLoadTime: Double - let audioLoadTime: Double - } + typealias Dataset = DiarizationBenchmarkUtils.Dataset + typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult static func printUsage() { print( @@ -197,9 +178,7 @@ enum LSEENDBenchmark { printUsage() return default: - if !arguments[i].starts(with: "--") { - logger.warning("Unknown argument: \(arguments[i])") - } + logger.warning("Unknown argument: \(arguments[i])") } i += 1 } @@ -228,14 +207,7 @@ enum LSEENDBenchmark { if let meeting = singleFile { filesToProcess = [meeting] } else { - switch dataset { - case .ami: - filesToProcess = getAMIFiles(maxFiles: maxFiles) - case .voxconverse: - filesToProcess = getVoxConverseFiles(maxFiles: maxFiles) - case .callhome: - filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles) - } + filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles) } if filesToProcess.isEmpty { @@ -252,7 +224,7 @@ enum LSEENDBenchmark { var completedResults: [BenchmarkResult] = [] var completedMeetings: Set = [] if resumeFromProgress { - if let loaded = loadProgress(from: progressFile) { + if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) { completedResults = loaded completedMeetings = Set(loaded.map { $0.meetingName }) print("Resuming: loaded \(completedResults.count) previous results") @@ -339,7 +311,7 @@ enum LSEENDBenchmark { print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth") // Save progress after each file - saveProgress(results: allResults, to: progressFile) + DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile) print("Progress saved (\(allResults.count) files complete)") } fflush(stdout) @@ -349,11 +321,15 @@ enum LSEENDBenchmark { } // Print final summary - printFinalSummary(results: allResults) + DiarizationBenchmarkUtils.printFinalSummary( + results: allResults, + title: "LS-EEND BENCHMARK SUMMARY", + derTargets: [15, 25] + ) // Save results if let outputPath = outputFile { - saveJSONResults(results: allResults, to: outputPath) + DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath) } } @@ -369,7 +345,7 @@ enum LSEENDBenchmark { numSpeakers: Int, verbose: Bool ) async -> BenchmarkResult? { - let audioPath = getAudioPath(for: meetingName, dataset: dataset) + let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset) guard FileManager.default.fileExists(atPath: audioPath) else { print("Audio file not found: \(audioPath)") fflush(stdout) @@ -378,10 +354,7 @@ enum LSEENDBenchmark { do { // Load and process audio - let startLoadTime = Date() let audioURL = URL(fileURLWithPath: audioPath) - let audioLoadTime = Date().timeIntervalSince(startLoadTime) - let startTime = Date() let timeline = try diarizer.processComplete(audioFileURL: audioURL) let processingTime = Date().timeIntervalSince(startTime) @@ -400,7 +373,7 @@ enum LSEENDBenchmark { let rttmEntries: [LSEENDRTTMEntry] let rttmSpeakers: [String] - let rttmURL = getRTTMURL(for: meetingName, dataset: dataset) + let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset) if let rttmURL = rttmURL, FileManager.default.fileExists(atPath: rttmURL.path) { let parsed = try LSEENDEvaluation.parseRTTM(url: rttmURL) rttmEntries = parsed.entries @@ -413,7 +386,7 @@ enum LSEENDBenchmark { duration: duration ) guard !groundTruth.isEmpty else { - print("⚠️ No ground truth found for \(meetingName)") + print("No ground truth found for \(meetingName)") return nil } // Convert TimedSpeakerSegment to LSEENDRTTMEntry @@ -507,7 +480,7 @@ enum LSEENDBenchmark { detectedSpeakers: detectedSpeakerIndices.count, groundTruthSpeakers: rttmSpeakers.count, modelLoadTime: modelLoadTime, - audioLoadTime: audioLoadTime + audioLoadTime: nil ) } catch { @@ -516,279 +489,5 @@ enum LSEENDBenchmark { } } - // MARK: - File Paths - - private static func getAMIFiles(maxFiles: Int?) -> [String] { - let allMeetings = [ - "EN2002a", "EN2002b", "EN2002c", "EN2002d", - "ES2004a", "ES2004b", "ES2004c", "ES2004d", - "IS1009a", "IS1009b", "IS1009c", "IS1009d", - "TS3003a", "TS3003b", "TS3003c", "TS3003d", - ] - - var availableMeetings: [String] = [] - for meeting in allMeetings { - let path = getAudioPath(for: meeting, dataset: .ami) - if FileManager.default.fileExists(atPath: path) { - availableMeetings.append(meeting) - } - } - - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - return availableMeetings - } - - private static func getAudioPath(for meeting: String, dataset: Dataset) -> String { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - switch dataset { - case .ami: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav" - ).path - case .voxconverse: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav" - ).path - case .callhome: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/callhome_eng/\(meeting).wav" - ).path - } - } - - private static func getRTTMURL(for meeting: String, dataset: Dataset) -> URL? { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - switch dataset { - case .ami: - // Try local RTTM first, then fall back to dataset directory - let localPath = "Streaming-Sortformer-Conversion/\(meeting).rttm" - if FileManager.default.fileExists(atPath: localPath) { - return URL(fileURLWithPath: localPath) - } - let datasetPath = homeDir.appendingPathComponent( - "FluidAudioDatasets/ami_official/rttm/\(meeting).rttm" - ) - return datasetPath - case .voxconverse: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/rttm_repo/test/\(meeting).rttm" - ) - case .callhome: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/callhome_eng/rttm/\(meeting).rttm" - ) - } - } - - private static func getVoxConverseFiles(maxFiles: Int?) -> [String] { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - let voxDir = homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/voxconverse_test_wav" - ) - - guard - let files = try? FileManager.default.contentsOfDirectory( - at: voxDir, - includingPropertiesForKeys: nil - ) - else { - return [] - } - - var availableMeetings: [String] = [] - for file in files where file.pathExtension == "wav" { - let name = file.deletingPathExtension().lastPathComponent - let rttmPath = homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm" - ) - if FileManager.default.fileExists(atPath: rttmPath.path) { - availableMeetings.append(name) - } - } - - availableMeetings.sort() - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - return availableMeetings - } - - private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng") - - guard - let files = try? FileManager.default.contentsOfDirectory( - at: callhomeDir, - includingPropertiesForKeys: nil - ) - else { - return [] - } - - var availableMeetings: [String] = [] - for file in files where file.pathExtension == "wav" { - let name = file.deletingPathExtension().lastPathComponent - let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm") - if FileManager.default.fileExists(atPath: rttmPath.path) { - availableMeetings.append(name) - } - } - - availableMeetings.sort() - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - return availableMeetings - } - - // MARK: - Summary & Output - - private static func printFinalSummary(results: [BenchmarkResult]) { - guard !results.isEmpty else { return } - - print("\n" + String(repeating: "=", count: 80)) - print("LS-EEND BENCHMARK SUMMARY") - print(String(repeating: "=", count: 80)) - - print("Results Sorted by DER:") - print(String(repeating: "-", count: 70)) - print("Meeting DER % Miss % FA % SE % Speakers RTFx") - print(String(repeating: "-", count: 70)) - - for result in results.sorted(by: { $0.der < $1.der }) { - let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)" - let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0) - let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0) - print( - String( - format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f", - meetingCol, - result.der, - result.missRate, - result.falseAlarmRate, - result.speakerErrorRate, - speakerCol, - result.rtfx)) - } - print(String(repeating: "-", count: 70)) - - let count = Float(results.count) - let avgDER = results.map { $0.der }.reduce(0, +) / count - let avgMiss = results.map { $0.missRate }.reduce(0, +) / count - let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count - let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count - let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count - - print( - String( - format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f", - avgDER, avgMiss, avgFA, avgSE, avgRTFx)) - print(String(repeating: "=", count: 70)) - - print("\nTarget Check:") - if avgDER < 15 { - print(" DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)") - } else if avgDER < 25 { - print(" DER < 25% (achieved: \(String(format: "%.1f", avgDER))%)") - } else { - print(" DER > 25% (achieved: \(String(format: "%.1f", avgDER))%)") - } - - if avgRTFx > 1 { - print(" RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") - } else { - print(" RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") - } - } - - private static func saveJSONResults(results: [BenchmarkResult], to path: String) { - let jsonData = results.map { resultToDict($0) } - do { - let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) - try data.write(to: URL(fileURLWithPath: path)) - print("JSON results saved to: \(path)") - } catch { - print("Failed to save JSON: \(error)") - } - } - - // MARK: - Progress Save/Load - - private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] { - return [ - "meeting": result.meetingName, - "der": result.der, - "missRate": result.missRate, - "falseAlarmRate": result.falseAlarmRate, - "speakerErrorRate": result.speakerErrorRate, - "rtfx": result.rtfx, - "processingTime": result.processingTime, - "totalFrames": result.totalFrames, - "detectedSpeakers": result.detectedSpeakers, - "groundTruthSpeakers": result.groundTruthSpeakers, - "modelLoadTime": result.modelLoadTime, - "audioLoadTime": result.audioLoadTime, - ] - } - - private static func saveProgress(results: [BenchmarkResult], to path: String) { - let jsonData = results.map { resultToDict($0) } - do { - let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) - try data.write(to: URL(fileURLWithPath: path)) - } catch { - print("Failed to save progress: \(error)") - } - } - - private static func loadProgress(from path: String) -> [BenchmarkResult]? { - guard FileManager.default.fileExists(atPath: path) else { return nil } - - do { - let data = try Data(contentsOf: URL(fileURLWithPath: path)) - guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else { - return nil - } - - return jsonArray.compactMap { dict -> BenchmarkResult? in - guard let meeting = dict["meeting"] as? String, - let der = (dict["der"] as? NSNumber)?.floatValue, - let missRate = (dict["missRate"] as? NSNumber)?.floatValue, - let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue, - let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue, - let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue, - let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue, - let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue, - let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue, - let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue, - let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue, - let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue - else { - return nil - } - - return BenchmarkResult( - meetingName: meeting, - der: der, - missRate: missRate, - falseAlarmRate: falseAlarmRate, - speakerErrorRate: speakerErrorRate, - rtfx: rtfx, - processingTime: processingTime, - totalFrames: totalFrames, - detectedSpeakers: detectedSpeakers, - groundTruthSpeakers: groundTruthSpeakers, - modelLoadTime: modelLoadTime, - audioLoadTime: audioLoadTime - ) - } - } catch { - print("Failed to load progress: \(error)") - return nil - } - } } #endif diff --git a/Sources/FluidAudioCLI/Commands/LSEENDCommand.swift b/Sources/FluidAudioCLI/Commands/LSEENDCommand.swift index 6edf094a..1de39c0a 100644 --- a/Sources/FluidAudioCLI/Commands/LSEENDCommand.swift +++ b/Sources/FluidAudioCLI/Commands/LSEENDCommand.swift @@ -1,5 +1,4 @@ #if os(macOS) -import AVFoundation import FluidAudio import Foundation @@ -20,8 +19,6 @@ enum LSEENDCommand { var outputFile: String? var variant: LSEENDVariant = .dihard3 var threshold: Float = 0.5 - var medianWidth: Int = 1 - var collarSeconds: Double = 0.25 // Post-processing parameters var onset: Float? @@ -62,16 +59,6 @@ enum LSEENDCommand { threshold = v i += 1 } - case "--median-width": - if i + 1 < arguments.count, let v = Int(arguments[i + 1]) { - medianWidth = v - i += 1 - } - case "--collar": - if i + 1 < arguments.count, let v = Double(arguments[i + 1]) { - collarSeconds = v - i += 1 - } case "--onset": if i + 1 < arguments.count, let v = Float(arguments[i + 1]) { onset = v @@ -244,7 +231,7 @@ enum LSEENDCommand { } private static func printUsage() { - logger.info( + print( """ LS-EEND Command Usage: @@ -253,8 +240,6 @@ enum LSEENDCommand { Options: --variant Model variant: ami, callhome, dihard2, dihard3 (default: dihard3) --threshold Speaker activity threshold (default: 0.5) - --median-width Median filter width for post-processing (default: 1) - --collar Collar duration in seconds for evaluation (default: 0.25) --onset Onset threshold for speech detection (default: 0.5) --offset Offset threshold for speech detection (default: 0.5) --pad-onset Padding before speech segments in seconds @@ -273,8 +258,7 @@ enum LSEENDCommand { # Save results to file fluidaudio lseend audio.wav --output results.json - """ - ) + """) } } #endif diff --git a/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift b/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift index 23388e24..7d107992 100644 --- a/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift +++ b/Sources/FluidAudioCLI/Commands/LSEENDEvaluation.swift @@ -85,7 +85,7 @@ public enum LSEENDEvaluation { var speakers: [String] = [] for line in text.split(whereSeparator: \.isNewline) { let parts = line.split(separator: " ") - guard parts.count >= 8 else { continue } + guard parts.count >= 8, parts[0] == "SPEAKER" else { continue } let speaker = String(parts[7]) if !speakers.contains(speaker) { speakers.append(speaker) @@ -444,6 +444,7 @@ public enum LSEENDEvaluation { } private static func solveAssignmentRowsToColumns(cost: [Float], rows: Int, columns: Int) -> [Int] { + precondition(columns <= 20, "Assignment solver is O(2^columns); columns=\(columns) is too large") let stateCount = 1 << columns var dp = [Float](repeating: .greatestFiniteMagnitude, count: stateCount) var parent = [Int](repeating: -1, count: stateCount) diff --git a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift index ee3b9037..50e59ba4 100644 --- a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift @@ -1,5 +1,4 @@ #if os(macOS) -import AVFoundation import FluidAudio import Foundation @@ -7,26 +6,8 @@ import Foundation enum SortformerBenchmark { private static let logger = AppLogger(category: "SortformerBench") - enum Dataset: String { - case ami = "ami" - case voxconverse = "voxconverse" - case callhome = "callhome" - } - - struct BenchmarkResult { - let meetingName: String - let der: Float - let missRate: Float - let falseAlarmRate: Float - let speakerErrorRate: Float - let rtfx: Float - let processingTime: Double - let totalFrames: Int - let detectedSpeakers: Int - let groundTruthSpeakers: Int - let modelLoadTime: Double - let audioLoadTime: Double - } + typealias Dataset = DiarizationBenchmarkUtils.Dataset + typealias BenchmarkResult = DiarizationBenchmarkUtils.BenchmarkResult static func printUsage() { print( @@ -42,7 +23,6 @@ enum SortformerBenchmark { --single-file Process a specific meeting (e.g., ES2004a) --max-files Maximum number of files to process --threshold Speaker activity threshold (default: 0.5) - --preprocessor Path to SortformerPreprocessor.mlpackage --model Path to Sortformer.mlpackage --nvidia-low-latency Use NVIDIA 1.04s latency config (20.57% DER target) --nvidia-high-latency Use NVIDIA 30.4s latency config (20.57% DER target) @@ -102,7 +82,7 @@ enum SortformerBenchmark { if let d = Dataset(rawValue: arguments[i + 1].lowercased()) { dataset = d } else { - print("⚠️ Unknown dataset: \(arguments[i + 1]). Using ami.") + print("Unknown dataset: \(arguments[i + 1]). Using ami.") } i += 1 } @@ -158,9 +138,7 @@ enum SortformerBenchmark { printUsage() return default: - if !arguments[i].starts(with: "--") { - logger.warning("Unknown argument: \(arguments[i])") - } + logger.warning("Unknown argument: \(arguments[i])") } i += 1 } @@ -170,7 +148,7 @@ enum SortformerBenchmark { useHuggingFace = true } - print("🚀 Starting Sortformer Benchmark") + print("Starting Sortformer Benchmark") fflush(stdout) print(" Dataset: \(dataset.rawValue)") print(" Threshold: \(threshold)") @@ -209,7 +187,7 @@ enum SortformerBenchmark { // Download dataset if needed if autoDownload && dataset == .ami { - print("📥 Downloading AMI dataset if needed...") + print("Downloading AMI dataset if needed...") await DatasetDownloader.downloadAMIDataset( variant: .sdm, force: false, @@ -223,23 +201,16 @@ enum SortformerBenchmark { if let meeting = singleFile { filesToProcess = [meeting] } else { - switch dataset { - case .ami: - filesToProcess = getAMIFiles(maxFiles: maxFiles) - case .voxconverse: - filesToProcess = getVoxConverseFiles(maxFiles: maxFiles) - case .callhome: - filesToProcess = getCALLHOMEFiles(maxFiles: maxFiles) - } + filesToProcess = DiarizationBenchmarkUtils.getFiles(for: dataset, maxFiles: maxFiles) } if filesToProcess.isEmpty { - print("❌ No files found to process") + print("No files found to process") fflush(stdout) return } - print("📂 Processing \(filesToProcess.count) file(s)") + print("Processing \(filesToProcess.count) file(s)") print(" Progress file: \(progressFile)") fflush(stdout) @@ -247,29 +218,29 @@ enum SortformerBenchmark { var completedResults: [BenchmarkResult] = [] var completedMeetings: Set = [] if resumeFromProgress { - if let loaded = loadProgress(from: progressFile) { + if let loaded = DiarizationBenchmarkUtils.loadProgress(from: progressFile) { completedResults = loaded completedMeetings = Set(loaded.map { $0.meetingName }) - print("📥 Resuming: loaded \(completedResults.count) previous results") + print("Resuming: loaded \(completedResults.count) previous results") for result in completedResults { - print(" ✅ \(result.meetingName): \(String(format: "%.1f", result.der))% DER") + print(" \(result.meetingName): \(String(format: "%.1f", result.der))% DER") } } else { - print("📥 No previous progress found, starting fresh") + print("No previous progress found, starting fresh") } } print("") fflush(stdout) // Initialize Sortformer - print("🔧 Loading Sortformer models...") + print("Loading Sortformer models...") fflush(stdout) let modelLoadStart = Date() var config: SortformerConfig if useNvidiaHighLatency { - config = SortformerConfig.nvidiaHighLatencyV2_1 + config = SortformerConfig.highContextV2_1 } else if useNvidiaLowLatency { - config = SortformerConfig.nvidiaLowLatencyV2_1 + config = SortformerConfig.balancedV2_1 } else { config = SortformerConfig.default } @@ -287,12 +258,12 @@ enum SortformerBenchmark { ) } } catch { - print("❌ Failed to initialize Sortformer: \(error)") + print("Failed to initialize Sortformer: \(error)") return } let modelLoadTime = Date().timeIntervalSince(modelLoadStart) - print("✅ Models loaded in \(String(format: "%.2f", modelLoadTime))s\n") + print("Models loaded in \(String(format: "%.2f", modelLoadTime))s\n") fflush(stdout) // Process each file @@ -324,14 +295,14 @@ enum SortformerBenchmark { allResults.append(result) // Print summary - print("📊 Results for \(meetingName):") + print("Results for \(meetingName):") print(" DER: \(String(format: "%.1f", result.der))%") print(" RTFx: \(String(format: "%.1f", result.rtfx))x") print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth") // Save progress after each file - saveProgress(results: allResults, to: progressFile) - print("💾 Progress saved (\(allResults.count) files complete)") + DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile) + print("Progress saved (\(allResults.count) files complete)") } fflush(stdout) @@ -340,11 +311,15 @@ enum SortformerBenchmark { } // Print final summary - printFinalSummary(results: allResults) + DiarizationBenchmarkUtils.printFinalSummary( + results: allResults, + title: "SORTFORMER BENCHMARK SUMMARY", + derTargets: [15, 20] + ) // Save results if let outputPath = outputFile { - saveJSONResults(results: allResults, to: outputPath) + DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath) } } @@ -357,9 +332,9 @@ enum SortformerBenchmark { verbose: Bool ) async -> BenchmarkResult? { - let audioPath = getAudioPath(for: meetingName, dataset: dataset) + let audioPath = DiarizationBenchmarkUtils.getAudioPath(for: meetingName, dataset: dataset) guard FileManager.default.fileExists(atPath: audioPath) else { - print("❌ Audio file not found: \(audioPath)") + print("Audio file not found: \(audioPath)") fflush(stdout) return nil } @@ -441,7 +416,7 @@ enum SortformerBenchmark { } guard !groundTruth.isEmpty else { - print("⚠️ No ground truth found for \(meetingName)") + print("No ground truth found for \(meetingName)") return nil } @@ -453,7 +428,7 @@ enum SortformerBenchmark { let simpleMetrics = calculateSimpleDER( predictions: filteredPredictions, numFrames: result.numFinalizedFrames, - numSpeakers: 4, + numSpeakers: result.config.numSpeakers, groundTruth: groundTruth, threshold: threshold, frameShift: 0.08 // 80ms frames like NeMo @@ -490,267 +465,7 @@ enum SortformerBenchmark { ) } catch { - print("❌ Error processing \(meetingName): \(error)") - return nil - } - } - - private static func getAMIFiles(maxFiles: Int?) -> [String] { - // Official AMI SDM test set (16 meetings) - matches NeMo evaluation - let allMeetings = [ - "EN2002a", "EN2002b", "EN2002c", "EN2002d", - "ES2004a", "ES2004b", "ES2004c", "ES2004d", - "IS1009a", "IS1009b", "IS1009c", "IS1009d", - "TS3003a", "TS3003b", "TS3003c", "TS3003d", - ] - - var availableMeetings: [String] = [] - for meeting in allMeetings { - let path = getAudioPath(for: meeting, dataset: .ami) - if FileManager.default.fileExists(atPath: path) { - availableMeetings.append(meeting) - } - } - - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - - return availableMeetings - } - - private static func getAudioPath(for meeting: String, dataset: Dataset) -> String { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - switch dataset { - case .ami: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/ami_official/sdm/\(meeting).Mix-Headset.wav" - ).path - case .voxconverse: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/voxconverse_test_wav/\(meeting).wav" - ).path - case .callhome: - return homeDir.appendingPathComponent( - "FluidAudioDatasets/callhome_eng/\(meeting).wav" - ).path - } - } - - private static func getVoxConverseFiles(maxFiles: Int?) -> [String] { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - let voxDir = homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/voxconverse_test_wav" - ) - - guard - let files = try? FileManager.default.contentsOfDirectory( - at: voxDir, - includingPropertiesForKeys: nil - ) - else { - return [] - } - - var availableMeetings: [String] = [] - for file in files where file.pathExtension == "wav" { - let name = file.deletingPathExtension().lastPathComponent - // Check that RTTM file exists - let rttmPath = homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/rttm_repo/test/\(name).rttm" - ) - if FileManager.default.fileExists(atPath: rttmPath.path) { - availableMeetings.append(name) - } - } - - // Sort alphabetically for reproducibility - availableMeetings.sort() - - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - - return availableMeetings - } - - private static func getCALLHOMEFiles(maxFiles: Int?) -> [String] { - let homeDir = FileManager.default.homeDirectoryForCurrentUser - let callhomeDir = homeDir.appendingPathComponent("FluidAudioDatasets/callhome_eng") - - guard - let files = try? FileManager.default.contentsOfDirectory( - at: callhomeDir, - includingPropertiesForKeys: nil - ) - else { - return [] - } - - var availableMeetings: [String] = [] - for file in files where file.pathExtension == "wav" { - let name = file.deletingPathExtension().lastPathComponent - // Check that RTTM file exists - let rttmPath = callhomeDir.appendingPathComponent("rttm/\(name).rttm") - if FileManager.default.fileExists(atPath: rttmPath.path) { - availableMeetings.append(name) - } - } - - // Sort alphabetically for reproducibility - availableMeetings.sort() - - if let max = maxFiles { - return Array(availableMeetings.prefix(max)) - } - - return availableMeetings - } - - private static func printFinalSummary(results: [BenchmarkResult]) { - guard !results.isEmpty else { return } - - print("\n" + String(repeating: "=", count: 80)) - print("SORTFORMER BENCHMARK SUMMARY") - print(String(repeating: "=", count: 80)) - - print("📋 Results Sorted by DER:") - print(String(repeating: "-", count: 70)) - print("Meeting DER % Miss % FA % SE % Speakers RTFx") - print(String(repeating: "-", count: 70)) - - for result in results.sorted(by: { $0.der < $1.der }) { - let speakerInfo = "\(result.detectedSpeakers)/\(result.groundTruthSpeakers)" - let meetingCol = result.meetingName.padding(toLength: 12, withPad: " ", startingAt: 0) - let speakerCol = speakerInfo.padding(toLength: 10, withPad: " ", startingAt: 0) - print( - String( - format: "%@ %8.1f %8.1f %8.1f %8.1f %@ %8.1f", - meetingCol, - result.der, - result.missRate, - result.falseAlarmRate, - result.speakerErrorRate, - speakerCol, - result.rtfx)) - } - print(String(repeating: "-", count: 70)) - - let count = Float(results.count) - let avgDER = results.map { $0.der }.reduce(0, +) / count - let avgMiss = results.map { $0.missRate }.reduce(0, +) / count - let avgFA = results.map { $0.falseAlarmRate }.reduce(0, +) / count - let avgSE = results.map { $0.speakerErrorRate }.reduce(0, +) / count - let avgRTFx = results.map { $0.rtfx }.reduce(0, +) / count - - print( - String( - format: "AVERAGE %8.1f %8.1f %8.1f %8.1f - %8.1f", - avgDER, avgMiss, avgFA, avgSE, avgRTFx)) - print(String(repeating: "=", count: 70)) - - print("\n✅ Target Check:") - if avgDER < 15 { - print(" ✅ DER < 15% (achieved: \(String(format: "%.1f", avgDER))%)") - } else if avgDER < 20 { - print(" 🟡 DER < 20% (achieved: \(String(format: "%.1f", avgDER))%)") - } else { - print(" ❌ DER > 20% (achieved: \(String(format: "%.1f", avgDER))%)") - } - - if avgRTFx > 1 { - print(" ✅ RTFx > 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") - } else { - print(" ❌ RTFx < 1x (achieved: \(String(format: "%.1f", avgRTFx))x)") - } - } - - private static func saveJSONResults(results: [BenchmarkResult], to path: String) { - let jsonData = results.map { result in - resultToDict(result) - } - - do { - let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) - try data.write(to: URL(fileURLWithPath: path)) - print("💾 JSON results saved to: \(path)") - } catch { - print("❌ Failed to save JSON: \(error)") - } - } - - // MARK: - Progress Save/Load - - private static func resultToDict(_ result: BenchmarkResult) -> [String: Any] { - return [ - "meeting": result.meetingName, - "der": result.der, - "missRate": result.missRate, - "falseAlarmRate": result.falseAlarmRate, - "speakerErrorRate": result.speakerErrorRate, - "rtfx": result.rtfx, - "processingTime": result.processingTime, - "totalFrames": result.totalFrames, - "detectedSpeakers": result.detectedSpeakers, - "groundTruthSpeakers": result.groundTruthSpeakers, - "modelLoadTime": result.modelLoadTime, - "audioLoadTime": result.audioLoadTime, - ] - } - - private static func saveProgress(results: [BenchmarkResult], to path: String) { - let jsonData = results.map { resultToDict($0) } - do { - let data = try JSONSerialization.data(withJSONObject: jsonData, options: .prettyPrinted) - try data.write(to: URL(fileURLWithPath: path)) - } catch { - print("⚠️ Failed to save progress: \(error)") - } - } - - private static func loadProgress(from path: String) -> [BenchmarkResult]? { - guard FileManager.default.fileExists(atPath: path) else { return nil } - - do { - let data = try Data(contentsOf: URL(fileURLWithPath: path)) - guard let jsonArray = try JSONSerialization.jsonObject(with: data) as? [[String: Any]] else { - return nil - } - - return jsonArray.compactMap { dict -> BenchmarkResult? in - guard let meeting = dict["meeting"] as? String, - let der = (dict["der"] as? NSNumber)?.floatValue, - let missRate = (dict["missRate"] as? NSNumber)?.floatValue, - let falseAlarmRate = (dict["falseAlarmRate"] as? NSNumber)?.floatValue, - let speakerErrorRate = (dict["speakerErrorRate"] as? NSNumber)?.floatValue, - let rtfx = (dict["rtfx"] as? NSNumber)?.floatValue, - let processingTime = (dict["processingTime"] as? NSNumber)?.doubleValue, - let totalFrames = (dict["totalFrames"] as? NSNumber)?.intValue, - let detectedSpeakers = (dict["detectedSpeakers"] as? NSNumber)?.intValue, - let groundTruthSpeakers = (dict["groundTruthSpeakers"] as? NSNumber)?.intValue, - let modelLoadTime = (dict["modelLoadTime"] as? NSNumber)?.doubleValue, - let audioLoadTime = (dict["audioLoadTime"] as? NSNumber)?.doubleValue - else { - return nil - } - - return BenchmarkResult( - meetingName: meeting, - der: der, - missRate: missRate, - falseAlarmRate: falseAlarmRate, - speakerErrorRate: speakerErrorRate, - rtfx: rtfx, - processingTime: processingTime, - totalFrames: totalFrames, - detectedSpeakers: detectedSpeakers, - groundTruthSpeakers: groundTruthSpeakers, - modelLoadTime: modelLoadTime, - audioLoadTime: audioLoadTime - ) - } - } catch { - print("⚠️ Failed to load progress: \(error)") + print("Error processing \(meetingName): \(error)") return nil } } @@ -760,23 +475,11 @@ enum SortformerBenchmark { /// Load ground truth from RTTM file like Python does /// Format: SPEAKER 1 private static func loadRTTMGroundTruth(for meetingName: String, dataset: Dataset) -> [TimedSpeakerSegment] { - // Determine RTTM path based on dataset - let rttmPath: String - let homeDir = FileManager.default.homeDirectoryForCurrentUser - switch dataset { - case .ami: - rttmPath = "Streaming-Sortformer-Conversion/\(meetingName).rttm" - case .voxconverse: - rttmPath = - homeDir.appendingPathComponent( - "FluidAudioDatasets/voxconverse/rttm_repo/test/\(meetingName).rttm" - ).path - case .callhome: - rttmPath = - homeDir.appendingPathComponent( - "FluidAudioDatasets/callhome_eng/rttm/\(meetingName).rttm" - ).path + guard let rttmURL = DiarizationBenchmarkUtils.getRTTMURL(for: meetingName, dataset: dataset) else { + print(" [RTTM] No RTTM URL for \(meetingName)") + return [] } + let rttmPath = rttmURL.path guard FileManager.default.fileExists(atPath: rttmPath) else { print(" [RTTM] File not found: \(rttmPath)") diff --git a/Sources/FluidAudioCLI/Commands/SortformerCommand.swift b/Sources/FluidAudioCLI/Commands/SortformerCommand.swift index 0ef52e78..af412891 100644 --- a/Sources/FluidAudioCLI/Commands/SortformerCommand.swift +++ b/Sources/FluidAudioCLI/Commands/SortformerCommand.swift @@ -1,5 +1,4 @@ #if os(macOS) -import AVFoundation import FluidAudio import Foundation @@ -131,9 +130,10 @@ enum SortformerCommand { print( "[DEBUG] First 10 audio samples: \((0.. 0.5 { + let idx = frame * numSpeakers + spk + if idx < predictions.count, predictions[idx] > 0.5 { speakerActivity[spk] += result.config.frameDurationSeconds } } @@ -233,7 +234,7 @@ enum SortformerCommand { } private static func printUsage() { - logger.info( + print( """ Sortformer Command Usage: @@ -259,8 +260,7 @@ enum SortformerCommand { # Save results to file fluidaudio sortformer audio.wav --output results.json - """ - ) + """) } } #endif diff --git a/Tests/FluidAudioTests/ASR/Streaming/NeMoMelSpectrogramTests.swift b/Tests/FluidAudioTests/ASR/Streaming/AudioMelSpectrogramTests.swift similarity index 96% rename from Tests/FluidAudioTests/ASR/Streaming/NeMoMelSpectrogramTests.swift rename to Tests/FluidAudioTests/ASR/Streaming/AudioMelSpectrogramTests.swift index 89445e5d..74b3499c 100644 --- a/Tests/FluidAudioTests/ASR/Streaming/NeMoMelSpectrogramTests.swift +++ b/Tests/FluidAudioTests/ASR/Streaming/AudioMelSpectrogramTests.swift @@ -3,13 +3,13 @@ import XCTest @testable import FluidAudio -final class NeMoMelSpectrogramTests: XCTestCase { +final class AudioMelSpectrogramTests: XCTestCase { - private var mel: NeMoMelSpectrogram! + private var mel: AudioMelSpectrogram! override func setUp() { super.setUp() - mel = NeMoMelSpectrogram() + mel = AudioMelSpectrogram() } override func tearDown() { diff --git a/Tests/FluidAudioTests/Diarizer/DiarizationTestFixtures.swift b/Tests/FluidAudioTests/Diarizer/DiarizationTestFixtures.swift new file mode 100644 index 00000000..60538a1a --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/DiarizationTestFixtures.swift @@ -0,0 +1,136 @@ +import AVFoundation +import Foundation +import XCTest + +@testable import FluidAudio + +/// Shared fixture infrastructure for diarization tests. +/// +/// Generates a deterministic multi-segment waveform with silence gaps, writes it to a +/// temporary WAV file, and caches it for reuse across tests within the same process. +enum DiarizationTestFixtures { + static let fixtureSampleRate = 16_000 + + nonisolated(unsafe) private static var cachedFixtureAudioURL: URL? + + /// Returns a cached URL to the fixture WAV file, creating it on first access. + static func fixtureAudioFileURL() throws -> URL { + if let cached = cachedFixtureAudioURL, + FileManager.default.fileExists(atPath: cached.path) + { + return cached + } + + let url = FileManager.default.temporaryDirectory + .appendingPathComponent("diarization-fixture-\(UUID().uuidString)") + .appendingPathExtension("wav") + try writeFixtureAudio(to: url) + cachedFixtureAudioURL = url + return url + } + + /// Loads fixture audio resampled to the given sample rate, optionally limited to a duration. + static func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] { + let converter = AudioConverter(sampleRate: Double(sampleRate)) + let audio = try converter.resampleAudioFile(try fixtureAudioFileURL()) + guard let limitSeconds else { + return audio + } + let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate))) + return Array(audio.prefix(sampleCount)) + } + + /// Loads a slice of fixture audio at the given sample rate. + static func fixtureAudio( + sampleRate: Int, startSeconds: Double, durationSeconds: Double + ) throws -> [Float] { + let converter = AudioConverter(sampleRate: Double(sampleRate)) + let audio = try converter.resampleAudioFile(try fixtureAudioFileURL()) + let startSample = min(audio.count, Int(startSeconds * Double(sampleRate))) + let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate))) + return Array(audio[startSample.. [[Float]] { + var chunks: [[Float]] = [] + var start = 0 + var index = 0 + while start < samples.count { + let size = sizes[index % sizes.count] + let stop = min(samples.count, start + size) + chunks.append(Array(samples[start.. [Float] { + let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [ + (1.0, 0.20, 220), + (0.35, 0.00, 0), + (1.1, 0.32, 330), + (0.25, 0.00, 0), + (1.0, 0.28, 180), + (0.40, 0.00, 0), + (1.3, 0.36, 260), + (0.30, 0.00, 0), + (1.1, 0.24, 410), + ] + + var output: [Float] = [] + for (duration, amplitude, frequency) in segments { + let frameCount = Int(duration * sampleRate) + guard amplitude > 0, frequency > 0 else { + output.append(contentsOf: repeatElement(0, count: frameCount)) + continue + } + + for frame in 0.. LSEENDInferenceHelper { @@ -259,115 +257,10 @@ final class LSEENDIntegrationTests: XCTestCase { return engine } - private func fixtureAudio(sampleRate: Int, limitSeconds: Double? = nil) throws -> [Float] { - let converter = AudioConverter(sampleRate: Double(sampleRate)) - let audio = try converter.resampleAudioFile(try fixtureAudioFileURL()) - guard let limitSeconds else { - return audio - } - let sampleCount = min(audio.count, Int(limitSeconds * Double(sampleRate))) - return Array(audio.prefix(sampleCount)) - } - - private func fixtureAudioFileURL() throws -> URL { - if let cached = Self.cachedFixtureAudioURL, - FileManager.default.fileExists(atPath: cached.path) - { - return cached - } - - let url = FileManager.default.temporaryDirectory - .appendingPathComponent("lseend-fixture-\(UUID().uuidString)") - .appendingPathExtension("wav") - try writeFixtureAudio(to: url) - Self.cachedFixtureAudioURL = url - return url - } - - private func writeFixtureAudio(to url: URL) throws { - let sampleRate = Double(Self.fixtureSampleRate) - let samples = makeFixtureSamples(sampleRate: sampleRate) - let format = AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: sampleRate, - channels: 1, - interleaved: false - )! - guard - let buffer = AVAudioPCMBuffer( - pcmFormat: format, - frameCapacity: AVAudioFrameCount(samples.count) - ) - else { - XCTFail("Failed to allocate fixture audio buffer") - return - } - - buffer.frameLength = AVAudioFrameCount(samples.count) - samples.withUnsafeBufferPointer { source in - guard let destination = buffer.floatChannelData?[0] else { return } - destination.update(from: source.baseAddress!, count: samples.count) - } - - let file = try AVAudioFile( - forWriting: url, - settings: format.settings, - commonFormat: .pcmFormatFloat32, - interleaved: false - ) - try file.write(from: buffer) - } - - private func makeFixtureSamples(sampleRate: Double) -> [Float] { - let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [ - (1.0, 0.20, 220), - (0.35, 0.00, 0), - (1.1, 0.32, 330), - (0.25, 0.00, 0), - (1.0, 0.28, 180), - (0.40, 0.00, 0), - (1.3, 0.36, 260), - (0.30, 0.00, 0), - (1.1, 0.24, 410), - ] - - var output: [Float] = [] - for (duration, amplitude, frequency) in segments { - let frameCount = Int(duration * sampleRate) - guard amplitude > 0, frequency > 0 else { - output.append(contentsOf: repeatElement(0, count: frameCount)) - continue - } - - for frame in 0.. Double { Double(samples.count) / Double(sampleRate) } - private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] { - var chunks: [[Float]] = [] - var start = 0 - var index = 0 - while start < samples.count { - let size = sizes[index % sizes.count] - let stop = min(samples.count, start + size) - chunks.append(Array(samples[start.. Bool { - let mirror = Mirror(reflecting: diarizer) - guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else { - XCTFail("Expected LS-EEND diarizer to expose _session via reflection") - return false - } - - let optionalMirror = Mirror(reflecting: sessionValue) - XCTAssertEqual(optionalMirror.displayStyle, .optional) - return optionalMirror.children.count == 1 - } } diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift new file mode 100644 index 00000000..1667413e --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDMatrixTests.swift @@ -0,0 +1,326 @@ +import XCTest + +@testable import FluidAudio + +final class LSEENDMatrixTests: XCTestCase { + + // MARK: - Init (validated) + + func testInitWithMatchingDimensions() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) + XCTAssertEqual(m.rows, 2) + XCTAssertEqual(m.columns, 3) + XCTAssertEqual(m.values, [1, 2, 3, 4, 5, 6]) + } + + func testInitThrowsOnCountMismatch() { + XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3])) { error in + guard case LSEENDError.invalidMatrixShape = error else { + return XCTFail("Expected invalidMatrixShape, got \(error)") + } + } + } + + func testInitThrowsOnNegativeRows() { + XCTAssertThrowsError(try LSEENDMatrix(rows: -1, columns: 3, values: [])) { error in + guard case LSEENDError.invalidMatrixShape = error else { + return XCTFail("Expected invalidMatrixShape, got \(error)") + } + } + } + + func testInitThrowsOnNegativeColumns() { + XCTAssertThrowsError(try LSEENDMatrix(rows: 2, columns: -1, values: [])) { error in + guard case LSEENDError.invalidMatrixShape = error else { + return XCTFail("Expected invalidMatrixShape, got \(error)") + } + } + } + + func testInitWithZeroDimensions() throws { + let m = try LSEENDMatrix(rows: 0, columns: 5, values: []) + XCTAssertEqual(m.rows, 0) + XCTAssertEqual(m.columns, 5) + XCTAssertTrue(m.isEmpty) + } + + // MARK: - Factory Methods + + func testZeros() { + let m = LSEENDMatrix.zeros(rows: 3, columns: 2) + XCTAssertEqual(m.rows, 3) + XCTAssertEqual(m.columns, 2) + XCTAssertEqual(m.values, [Float](repeating: 0, count: 6)) + } + + func testEmpty() { + let m = LSEENDMatrix.empty(columns: 4) + XCTAssertEqual(m.rows, 0) + XCTAssertEqual(m.columns, 4) + XCTAssertTrue(m.isEmpty) + } + + // MARK: - isEmpty + + func testIsEmptyZeroRows() { + XCTAssertTrue(LSEENDMatrix.empty(columns: 3).isEmpty) + } + + func testIsEmptyZeroColumns() { + let m = LSEENDMatrix(validatingRows: 3, columns: 0, values: []) + XCTAssertTrue(m.isEmpty) + } + + func testIsEmptyFalseForPopulatedMatrix() throws { + let m = try LSEENDMatrix(rows: 1, columns: 1, values: [42]) + XCTAssertFalse(m.isEmpty) + } + + // MARK: - Subscript + + func testSubscriptGet() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [10, 20, 30, 40, 50, 60]) + XCTAssertEqual(m[0, 0], 10) + XCTAssertEqual(m[0, 2], 30) + XCTAssertEqual(m[1, 0], 40) + XCTAssertEqual(m[1, 2], 60) + } + + func testSubscriptSet() throws { + var m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + m[1, 0] = 99 + XCTAssertEqual(m[1, 0], 99) + XCTAssertEqual(m.values, [1, 2, 99, 4]) + } + + // MARK: - row() + + func testRow() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + XCTAssertEqual(Array(m.row(0)), [1, 2]) + XCTAssertEqual(Array(m.row(1)), [3, 4]) + XCTAssertEqual(Array(m.row(2)), [5, 6]) + } + + // MARK: - prefixingColumns + + func testPrefixingColumns() throws { + let m = try LSEENDMatrix(rows: 2, columns: 4, values: [1, 2, 3, 4, 5, 6, 7, 8]) + let prefix = m.prefixingColumns(2) + XCTAssertEqual(prefix.rows, 2) + XCTAssertEqual(prefix.columns, 2) + XCTAssertEqual(prefix.values, [1, 2, 5, 6]) + } + + func testPrefixingColumnsEqualToWidth() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) + let same = m.prefixingColumns(3) + XCTAssertEqual(same, m) + } + + func testPrefixingColumnsGreaterThanWidth() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) + let same = m.prefixingColumns(10) + XCTAssertEqual(same, m) + } + + func testPrefixingColumnsZero() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) + let empty = m.prefixingColumns(0) + XCTAssertTrue(empty.isEmpty) + } + + func testPrefixingColumnsOnEmptyMatrix() { + let m = LSEENDMatrix.empty(columns: 4) + let result = m.prefixingColumns(2) + XCTAssertTrue(result.isEmpty) + XCTAssertEqual(result.columns, 2) + } + + // MARK: - rowMajorRows + + func testRowMajorRows() throws { + let m = try LSEENDMatrix(rows: 2, columns: 3, values: [1, 2, 3, 4, 5, 6]) + let rows = m.rowMajorRows() + XCTAssertEqual(rows, [[1, 2, 3], [4, 5, 6]]) + } + + func testRowMajorRowsEmpty() { + let m = LSEENDMatrix.empty(columns: 3) + XCTAssertEqual(m.rowMajorRows(), []) + } + + // MARK: - appendingRows + + func testAppendingRows() throws { + let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let b = try LSEENDMatrix(rows: 1, columns: 2, values: [5, 6]) + let result = a.appendingRows(b) + XCTAssertEqual(result.rows, 3) + XCTAssertEqual(result.columns, 2) + XCTAssertEqual(result.values, [1, 2, 3, 4, 5, 6]) + } + + func testAppendingRowsToEmpty() throws { + let a = LSEENDMatrix.empty(columns: 2) + let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + XCTAssertEqual(a.appendingRows(b), b) + } + + func testAppendingEmptyRows() throws { + let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let b = LSEENDMatrix.empty(columns: 2) + XCTAssertEqual(a.appendingRows(b), a) + } + + // MARK: - droppingFirstRows + + func testDroppingFirstRows() throws { + let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8]) + let dropped = m.droppingFirstRows(2) + XCTAssertEqual(dropped.rows, 2) + XCTAssertEqual(dropped.columns, 2) + XCTAssertEqual(dropped.values, [5, 6, 7, 8]) + } + + func testDroppingAllRows() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let dropped = m.droppingFirstRows(3) + XCTAssertEqual(dropped.rows, 0) + XCTAssertTrue(dropped.isEmpty) + } + + func testDroppingMoreThanTotalRows() throws { + let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let dropped = m.droppingFirstRows(100) + XCTAssertEqual(dropped.rows, 0) + XCTAssertTrue(dropped.isEmpty) + } + + func testDroppingZeroRows() throws { + let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let same = m.droppingFirstRows(0) + XCTAssertEqual(same, m) + } + + func testDroppingNegativeCount() throws { + let m = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let same = m.droppingFirstRows(-5) + XCTAssertEqual(same, m) + } + + // MARK: - slicingRows + + func testSlicingRows() throws { + let m = try LSEENDMatrix(rows: 5, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + let slice = m.slicingRows(start: 1, end: 4) + XCTAssertEqual(slice.rows, 3) + XCTAssertEqual(slice.values, [3, 4, 5, 6, 7, 8]) + } + + func testSlicingRowsFullRange() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let slice = m.slicingRows(start: 0, end: 3) + XCTAssertEqual(slice, m) + } + + func testSlicingRowsEmptyRange() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let slice = m.slicingRows(start: 2, end: 2) + XCTAssertTrue(slice.isEmpty) + XCTAssertEqual(slice.columns, 2) + } + + func testSlicingRowsClampsOutOfBounds() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let slice = m.slicingRows(start: -5, end: 100) + XCTAssertEqual(slice, m) + } + + func testSlicingRowsInvertedRange() throws { + let m = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let slice = m.slicingRows(start: 3, end: 1) + XCTAssertTrue(slice.isEmpty) + } + + // MARK: - applyingSigmoid + + func testSigmoidZero() throws { + let m = try LSEENDMatrix(rows: 1, columns: 1, values: [0]) + let s = m.applyingSigmoid() + XCTAssertEqual(s[0, 0], 0.5, accuracy: 1e-6) + } + + func testSigmoidLargePositive() throws { + let m = try LSEENDMatrix(rows: 1, columns: 1, values: [20]) + let s = m.applyingSigmoid() + XCTAssertEqual(s[0, 0], 1.0, accuracy: 1e-5) + } + + func testSigmoidLargeNegative() throws { + let m = try LSEENDMatrix(rows: 1, columns: 1, values: [-20]) + let s = m.applyingSigmoid() + XCTAssertEqual(s[0, 0], 0.0, accuracy: 1e-5) + } + + func testSigmoidPreservesShape() throws { + let m = try LSEENDMatrix(rows: 3, columns: 4, values: [Float](repeating: 0, count: 12)) + let s = m.applyingSigmoid() + XCTAssertEqual(s.rows, 3) + XCTAssertEqual(s.columns, 4) + XCTAssertEqual(s.values.count, 12) + } + + func testSigmoidDoesNotMutateOriginal() throws { + let m = try LSEENDMatrix(rows: 1, columns: 2, values: [0, 0]) + _ = m.applyingSigmoid() + XCTAssertEqual(m.values, [0, 0]) + } + + func testSigmoidOnEmpty() { + let m = LSEENDMatrix.empty(columns: 3) + let s = m.applyingSigmoid() + XCTAssertTrue(s.isEmpty) + } + + // MARK: - Equatable + + func testEqualMatrices() throws { + let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + XCTAssertEqual(a, b) + } + + func testUnequalValues() throws { + let a = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 4]) + let b = try LSEENDMatrix(rows: 2, columns: 2, values: [1, 2, 3, 5]) + XCTAssertNotEqual(a, b) + } + + // MARK: - Roundtrip: append then drop + + func testAppendThenDropRecoversOriginal() throws { + let original = try LSEENDMatrix(rows: 3, columns: 2, values: [1, 2, 3, 4, 5, 6]) + let extra = try LSEENDMatrix(rows: 2, columns: 2, values: [7, 8, 9, 10]) + let combined = original.appendingRows(extra) + let recovered = combined.slicingRows(start: 0, end: 3) + XCTAssertEqual(recovered, original) + } + + func testSliceThenAppendRecombines() throws { + let m = try LSEENDMatrix(rows: 4, columns: 2, values: [1, 2, 3, 4, 5, 6, 7, 8]) + let head = m.slicingRows(start: 0, end: 2) + let tail = m.slicingRows(start: 2, end: 4) + let recombined = head.appendingRows(tail) + XCTAssertEqual(recombined, m) + } + + func testDropThenPrefixColumnsCommutes() throws { + let m = try LSEENDMatrix(rows: 4, columns: 4, values: (0..<16).map { Float($0) }) + + let dropFirst = m.droppingFirstRows(2).prefixingColumns(2) + let prefixFirst = m.prefixingColumns(2).droppingFirstRows(2) + + XCTAssertEqual(dropFirst, prefixFirst) + } +} diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTests.swift index d971ef2d..81d58824 100644 --- a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTests.swift @@ -12,7 +12,7 @@ final class SortformerTests: XCTestCase { // Create 5 seconds of deterministic random audio let sampleRate = 16000 let audioCount = sampleRate * 5 - srand48(Int(Date().timeIntervalSince1970 * 1e6)) + srand48(42) let audio = (0.. SortformerModels { @@ -163,11 +160,13 @@ final class SpeakerEnrollmentTests: XCTestCase { let diarizer = SortformerDiarizer(config: config) let models = try await loadSortformerModelsForTest(config: config) diarizer.initialize(models: models) - let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") - XCTAssertNotNil(speaker) + try XCTSkipIf( + speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") XCTAssertEqual(speaker?.name, "Alice") XCTAssertEqual(diarizer.numFramesProcessed, 0) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0) @@ -184,14 +183,17 @@ final class SpeakerEnrollmentTests: XCTestCase { let diarizer = SortformerDiarizer(config: config) let models = try await loadSortformerModelsForTest(config: config) diarizer.initialize(models: models) - let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) - let liveAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0) + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) + let liveAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 3.0) let speaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") - XCTAssertNotNil(speaker) + try XCTSkipIf( + speaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") var update: DiarizerTimelineUpdate? - for chunk in chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) { + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [7_680, 9_600, 11_520]) { diarizer.addAudio(chunk) if let next = try diarizer.process() { update = next @@ -217,17 +219,21 @@ final class SpeakerEnrollmentTests: XCTestCase { let diarizer = SortformerDiarizer(config: config) let models = try await loadSortformerModelsForTest(config: config) diarizer.initialize(models: models) - let speakerAAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0) - let speakerBAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0) + let speakerAAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 3.0) + let speakerBAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 3.4, durationSeconds: 3.0) let speakerA = try diarizer.enrollSpeaker(withAudio: speakerAAudio, named: "Alice") - XCTAssertNotNil(speakerA) + try XCTSkipIf( + speakerA == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") let stateAfterA = diarizer.state let cachedLengthAfterA = stateAfterA.spkcacheLength + stateAfterA.fifoLength let speakerB = try diarizer.enrollSpeaker(withAudio: speakerBAudio, named: "Bob") - XCTAssertNotNil(speakerB) + try XCTSkipIf( + speakerB == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") let stateAfterB = diarizer.state XCTAssertGreaterThanOrEqual( @@ -250,9 +256,12 @@ final class SpeakerEnrollmentTests: XCTestCase { let diarizer = SortformerDiarizer(config: config) let models = try await loadSortformerModelsForTest(config: config) diarizer.initialize(models: models) - let enrollmentAudio = try fixtureAudio(sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) let firstSpeaker = try diarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") + try XCTSkipIf( + firstSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") let secondSpeaker = try diarizer.enrollSpeaker( withAudio: enrollmentAudio, named: "Bob", @@ -287,7 +296,7 @@ final class SpeakerEnrollmentTests: XCTestCase { let engine = try await loadLseendEngineForTest() let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) diarizer.initialize(engine: engine) - let enrollmentAudio = try fixtureAudio( + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") @@ -298,7 +307,7 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(diarizer.numFramesProcessed, 0) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0) XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 }) - XCTAssertTrue(hasActiveLseendSession(diarizer)) + XCTAssertTrue(diarizer.hasActiveSession) } func testLseendEnrollSpeakerFollowedByStreamingProcessingStartsAtFrameZero() async throws { @@ -307,14 +316,15 @@ final class SpeakerEnrollmentTests: XCTestCase { let engine = try await loadLseendEngineForTest() let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) diarizer.initialize(engine: engine) - let enrollmentAudio = try fixtureAudio( + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) - let liveAudio = try fixtureAudio(sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0) + let liveAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0) let speaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") var firstUpdate: DiarizerTimelineUpdate? - for chunk in chunk(liveAudio, sizes: [977, 1231, 1607]) { + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [977, 1231, 1607]) { if let update = try diarizer.process(samples: chunk) { firstUpdate = update break @@ -338,9 +348,9 @@ final class SpeakerEnrollmentTests: XCTestCase { let engine = try await loadLseendEngineForTest() let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) diarizer.initialize(engine: engine) - let speakerAAudio = try fixtureAudio( + let speakerAAudio = try DiarizationTestFixtures.fixtureAudio( sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) - let speakerBAudio = try fixtureAudio( + let speakerBAudio = try DiarizationTestFixtures.fixtureAudio( sampleRate: engine.targetSampleRate, startSeconds: 3.0, durationSeconds: 3.0) let speakerA = try diarizer.enrollSpeaker(withSamples: speakerAAudio, named: "Alice") @@ -348,7 +358,7 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(diarizer.numFramesProcessed, 0) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, 0) - XCTAssertTrue(hasActiveLseendSession(diarizer)) + XCTAssertTrue(diarizer.hasActiveSession) let expectedNames = Set([speakerA?.name, speakerB?.name].compactMap { $0 }) XCTAssertEqual(Set(namedSpeakerNames(in: diarizer.timeline)), expectedNames) } @@ -359,11 +369,12 @@ final class SpeakerEnrollmentTests: XCTestCase { let engine = try await loadLseendEngineForTest() let diarizer = LSEENDDiarizer(computeUnits: .cpuOnly) diarizer.initialize(engine: engine) - let enrollmentAudio = try fixtureAudio( + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( sampleRate: engine.targetSampleRate, startSeconds: 0.0, durationSeconds: 3.0) let firstSpeaker = try diarizer.enrollSpeaker(withSamples: enrollmentAudio, named: "Alice") - try XCTSkipIf(firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.") + try XCTSkipIf( + firstSpeaker == nil, "Fixture did not produce a confident LS-EEND speaker segment on this host.") let secondSpeaker = try diarizer.enrollSpeaker( withSamples: enrollmentAudio, named: "Bob", @@ -375,95 +386,6 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(namedSpeakerNames(in: diarizer.timeline), ["Alice"]) } - private func fixtureAudio(sampleRate: Int, startSeconds: Double = 0.0, durationSeconds: Double) throws -> [Float] { - let converter = AudioConverter(sampleRate: Double(sampleRate)) - let audio = try converter.resampleAudioFile(try fixtureAudioFileURL()) - let startSample = min(audio.count, Int(startSeconds * Double(sampleRate))) - let endSample = min(audio.count, startSample + Int(durationSeconds * Double(sampleRate))) - return Array(audio[startSample.. URL { - if let cached = Self.cachedFixtureAudioURL, - FileManager.default.fileExists(atPath: cached.path) - { - return cached - } - - let url = FileManager.default.temporaryDirectory - .appendingPathComponent("speaker-enrollment-fixture-\(UUID().uuidString)") - .appendingPathExtension("wav") - try writeFixtureAudio(to: url) - Self.cachedFixtureAudioURL = url - return url - } - - private func writeFixtureAudio(to url: URL) throws { - let sampleRate = Double(Self.fixtureSampleRate) - let samples = makeFixtureSamples(sampleRate: sampleRate) - let format = AVAudioFormat( - commonFormat: .pcmFormatFloat32, - sampleRate: sampleRate, - channels: 1, - interleaved: false - )! - guard - let buffer = AVAudioPCMBuffer( - pcmFormat: format, - frameCapacity: AVAudioFrameCount(samples.count) - ) - else { - XCTFail("Failed to allocate fixture audio buffer") - return - } - - buffer.frameLength = AVAudioFrameCount(samples.count) - samples.withUnsafeBufferPointer { source in - guard let destination = buffer.floatChannelData?[0] else { return } - destination.update(from: source.baseAddress!, count: samples.count) - } - - let file = try AVAudioFile( - forWriting: url, - settings: format.settings, - commonFormat: .pcmFormatFloat32, - interleaved: false - ) - try file.write(from: buffer) - } - - private func makeFixtureSamples(sampleRate: Double) -> [Float] { - let segments: [(duration: Double, amplitude: Float, frequency: Double)] = [ - (1.0, 0.20, 220), - (0.35, 0.00, 0), - (1.1, 0.32, 330), - (0.25, 0.00, 0), - (1.0, 0.28, 180), - (0.40, 0.00, 0), - (1.3, 0.36, 260), - (0.30, 0.00, 0), - (1.1, 0.24, 410), - ] - - var output: [Float] = [] - for (duration, amplitude, frequency) in segments { - let frameCount = Int(duration * sampleRate) - guard amplitude > 0, frequency > 0 else { - output.append(contentsOf: repeatElement(0, count: frameCount)) - continue - } - - for frame in 0.. [Int] { timeline.speakers.values .filter { $0.name != nil } @@ -476,29 +398,4 @@ final class SpeakerEnrollmentTests: XCTestCase { .compactMap(\.name) .sorted() } - - private func chunk(_ samples: [Float], sizes: [Int]) -> [[Float]] { - var chunks: [[Float]] = [] - var start = 0 - var index = 0 - while start < samples.count { - let size = sizes[index % sizes.count] - let stop = min(samples.count, start + size) - chunks.append(Array(samples[start.. Bool { - let mirror = Mirror(reflecting: diarizer) - guard let sessionValue = mirror.children.first(where: { $0.label == "_session" })?.value else { - XCTFail("Expected LS-EEND diarizer to expose _session via reflection") - return false - } - - let optionalMirror = Mirror(reflecting: sessionValue) - return optionalMirror.displayStyle == .optional && optionalMirror.children.count == 1 - } }