diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index ba84725a..bed3e912 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -680,7 +680,7 @@ public enum ModelNames { public static let flowlmStepV2 = "flowlm_stepv2" public static let flowDecoder = "flow_decoder" public static let mimiDecoder = "mimi_decoder" - public static let mimiEncoder = "mimi_encoder" + public static let mimiEncoder = "mimi_encoderv2" public static let condStepFile = condStep + ".mlmodelc" public static let flowlmStepFile = flowlmStep + ".mlmodelc" diff --git a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift index 51d2f2cf..c06a1367 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsConstantsLoader.swift @@ -5,6 +5,11 @@ public struct PocketTtsConstantsBundle: Sendable { public let bosEmbedding: [Float] public let textEmbedTable: [Float] public let tokenizer: SentencePieceTokenizer + /// `flow_lm.bos_before_voice` (1024 floats) — prepended to the v1 + /// audio_prompt during `cond_step` prefill. `nil` when the language + /// pack predates the FluidAudio #592 fix and omits the file; v2 + /// (snapshot) voices don't need it, so loading stays best-effort. + public let bosBeforeVoice: [Float]? } /// Pre-loaded voice conditioning data. @@ -100,12 +105,15 @@ public enum PocketTtsConstantsLoader { throw LoadError.tokenizerLoadFailed(error.localizedDescription) } + let bosBeforeVoice = try loadBosBeforeVoiceIfPresent(in: constantsDir) + logger.info("Loaded PocketTTS constants from \(directory.lastPathComponent)") return PocketTtsConstantsBundle( bosEmbedding: bosEmb, textEmbedTable: embedTable, - tokenizer: tokenizer + tokenizer: tokenizer, + bosBeforeVoice: bosBeforeVoice ) } @@ -259,6 +267,37 @@ public enum PocketTtsConstantsLoader { return PocketTtsVoiceCacheSnapshot(layers: layers, cacheSeqLen: seqLen) } + // MARK: - Internal helpers + + /// Load `bos_before_voice.bin` from `constantsDir` if it exists. + /// + /// `bos_before_voice.bin` ships with language packs updated for the + /// FluidAudio #592 fix (pocket-tts 2.0.0 `flow_lm.bos_before_voice`). + /// Older packs and snapshot-only callers don't need it, so a missing + /// file resolves to `nil` rather than throwing — the v1 prefill path + /// enforces presence at use time. + /// + /// Exposed at internal access for unit tests; production code goes + /// through `load(from:)`. + static func loadBosBeforeVoiceIfPresent(in constantsDir: URL) throws -> [Float]? { + let url = constantsDir.appendingPathComponent("bos_before_voice.bin") + guard FileManager.default.fileExists(atPath: url.path) else { + // Snapshot-voice users never need this file, so absence is the + // expected steady state for pre-#592 caches. Log at debug to + // avoid noise; the v1 cloned-voice path surfaces a clear error + // at `prefillKVCache` use time when it actually matters. + logger.debug( + "PocketTTS constants_bin/bos_before_voice.bin not present; cloned-voice v1 prefill will fail until the language pack is updated" + ) + return nil + } + return try loadFloatArray( + from: url, + expectedCount: PocketTtsConstants.embeddingDim, + name: "bos_before_voice" + ) + } + // MARK: - Private /// Load a raw Float32 binary file into a [Float] array. diff --git a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift index f49dcdb0..c7c86d37 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Assets/PocketTtsResourceDownloader.swift @@ -48,9 +48,27 @@ public enum PocketTtsResourceDownloader { atPath: languageRoot.appendingPathComponent(model).path) } - guard !allPresent else { + if allPresent { logger.info( "PocketTTS \(language.rawValue) (\(precision)) models found in cache") + // Pre-#592 caches lack `constants_bin/bos_before_voice.bin`. The + // language-pack files are otherwise complete, so try to fetch just + // the missing constant rather than re-downloading the whole subdir. + // + // Best-effort: shipped snapshot voices don't need this file at all, + // and the v1 cloned-voice prefill path enforces presence at use + // time (PocketTtsConstantsLoader returns nil gracefully). Failing + // the fetch here — e.g. offline, or before the file lands on HF — + // must not block users who only synthesize with shipped voices. + do { + try await ensureBosBeforeVoice(language: language, languageRoot: languageRoot) + } catch { + logger.warning( + "Failed to backfill bos_before_voice.bin for \(language.rawValue): " + + "\(error.localizedDescription). Cloned-voice v1 prefill will fail " + + "until this file is available; shipped snapshot voices are unaffected." + ) + } return languageRoot } @@ -71,6 +89,34 @@ public enum PocketTtsResourceDownloader { return languageRoot } + /// Backfill `constants_bin/bos_before_voice.bin` for cached language packs + /// that were downloaded before the FluidAudio #592 fix. New downloads pick + /// it up via `downloadSubdirectory` — this helper exists only to upgrade + /// older caches without a full re-download. + private static func ensureBosBeforeVoice( + language: PocketTtsLanguage, + languageRoot: URL + ) async throws { + let constantsDir = languageRoot.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir) + let bosURL = constantsDir.appendingPathComponent("bos_before_voice.bin") + if FileManager.default.fileExists(atPath: bosURL.path) { + return + } + try FileManager.default.createDirectory( + at: constantsDir, withIntermediateDirectories: true) + let remotePath = "\(language.repoSubdirectory)/constants_bin/bos_before_voice.bin" + let remoteURL = try ModelRegistry.resolveModel(Repo.pocketTts.remotePath, remotePath) + logger.info( + "Backfilling bos_before_voice.bin for cached \(language.rawValue) pack...") + let data = try await AssetDownloader.fetchData( + from: remoteURL, + description: "bos_before_voice.bin (\(language.rawValue))", + logger: logger + ) + try data.write(to: bosURL, options: [.atomic]) + logger.info("Wrote bos_before_voice.bin (\(data.count) bytes)") + } + /// Delete the FlowLM `.mlmodelc` and `.mlpackage` directories that don't /// match the requested precision. Idempotent — silently skips paths that /// don't exist. diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift index f6c5de75..4b27a6a9 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift @@ -121,11 +121,20 @@ extension PocketTtsSynthesizer { /// Prefill a KV cache state with voice conditioning tokens. /// - /// Processes all voice tokens from the voice data, writing K/V projections - /// into the cache starting at the current position. + /// Prepends a single `bos_before_voice` token to match pocket-tts 2.0.0's + /// `flow_lm.bos_before_voice` prefix (see FluidAudio #592 — without it + /// `cond_step` diverges from the deployed flowlm/flow_decoder weights and + /// the LM emits EOS within a few steps, producing garbled audio). Then + /// processes all voice tokens from `voiceData.audioPrompt`, writing K/V + /// projections into the cache starting at the current position. + /// + /// `bosBeforeVoice` must be provided whenever `voiceData.audioPrompt` + /// has content (i.e. cloned voices); shipped v2 voices skip this path + /// entirely via `cacheSnapshot`. static func prefillKVCacheVoice( state: KVCacheState, voiceData: PocketTtsVoiceData, + bosBeforeVoice: [Float]?, model: MLModel, layerKeys: PocketTtsLayerKeys ) async throws -> KVCacheState { @@ -133,6 +142,33 @@ extension PocketTtsSynthesizer { let dim = PocketTtsConstants.embeddingDim let voiceTokenCount = voiceData.promptLength + guard voiceTokenCount > 0 else { + // Nothing to prefill (e.g. session warmup with empty cloned + // voice). Skip the BOS prepend too — runtime callers that go + // through `prefillKVCache` only hit this branch when both + // `cacheSnapshot == nil` and `promptLength == 0`, which is a + // no-op. + return state + } + + guard let bosBeforeVoice else { + throw PocketTTSError.processingFailed( + "PocketTTS v1 cloned-voice prefill requires bos_before_voice constant. " + + "Re-download the language pack to get constants_bin/bos_before_voice.bin " + + "(added in the FluidAudio #592 fix)." + ) + } + guard bosBeforeVoice.count == dim else { + throw PocketTTSError.processingFailed( + "bos_before_voice has \(bosBeforeVoice.count) floats, expected \(dim)" + ) + } + + let bosToken = try createConditioningToken( + from: bosBeforeVoice, offset: 0, dim: dim) + try await runCondStep( + conditioning: bosToken, state: &state, model: model, layerKeys: layerKeys) + for tokenIdx in 0.. KVCacheState { @@ -268,7 +306,9 @@ extension PocketTtsSynthesizer { } else { let emptyState = try emptyKVCacheState(layers: layerKeys.layerCount) state = try await prefillKVCacheVoice( - state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys + state: emptyState, voiceData: voiceData, + bosBeforeVoice: bosBeforeVoice, + model: model, layerKeys: layerKeys ) } state = try await prefillKVCacheText( diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift index 7604da2c..f0549f0c 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer.swift @@ -273,8 +273,8 @@ public struct PocketTtsSynthesizer { // - Shipped voices (cacheSnapshot != nil): drop pre-baked K/V into // cache, skip cond_step entirely (`promptLength == 0`, so the // loop in `prefillKVCacheVoice` would be a no-op anyway). - // - Cloned voices (flat audio prompt): feed every voice token - // through cond_step. + // - Cloned voices (flat audio prompt): feed `bos_before_voice` + // plus every voice token through cond_step. let voiceKVSnapshot: KVCacheState if let snapshot = voiceData.cacheSnapshot { voiceKVSnapshot = try kvCacheStateFromSnapshot( @@ -282,8 +282,9 @@ public struct PocketTtsSynthesizer { } else { let emptyState = try emptyKVCacheState(layers: condLayerKeys.layerCount) voiceKVSnapshot = try await prefillKVCacheVoice( - state: emptyState, voiceData: voiceData, model: condModel, - layerKeys: condLayerKeys + state: emptyState, voiceData: voiceData, + bosBeforeVoice: constants.bosBeforeVoice, + model: condModel, layerKeys: condLayerKeys ) } @@ -438,6 +439,7 @@ public struct PocketTtsSynthesizer { var kvState = try await PocketTtsSynthesizer.prefillKVCache( voiceData: voiceData, textEmbeddings: textEmbeddings, + bosBeforeVoice: constants.bosBeforeVoice, model: condModel, layerKeys: condLayerKeys ) diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsVoiceCloner.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsVoiceCloner.swift index 817f8b88..dfd70189 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsVoiceCloner.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsVoiceCloner.swift @@ -1,3 +1,4 @@ +import Accelerate @preconcurrency import AVFoundation @preconcurrency import CoreML import Foundation @@ -19,14 +20,21 @@ public enum PocketTtsVoiceCloner { /// Frame size for the encoder (1920 samples = 80ms). public static let frameSize: Int = PocketTtsConstants.samplesPerFrame - /// Maximum voice prompt frames (caps at ~20s to leave KV cache room for text tokens). - public static let maxVoiceFrames: Int = 250 + /// Fixed encoder input length in samples (10s @ 24kHz). `mimi_encoderv2` has + /// `hasShapeFlexibility: "0"` and accepts exactly this many samples. + public static let encoderInputSamples: Int = 240_000 + + /// Maximum voice prompt frames produced by the encoder for one forward pass + /// (`encoderInputSamples / frameSize`). The encoder output shape is fixed at + /// `[1, 125, 1024]`, so 125 is the hard ceiling. + public static let maxVoiceFrames: Int = 125 /// Minimum audio duration in seconds for voice cloning. public static let minDurationSeconds: Double = 1.0 - /// Maximum audio duration in seconds for voice cloning. - public static let maxDurationSeconds: Double = 30.0 + /// Maximum audio duration in seconds for voice cloning (matches + /// `encoderInputSamples`). Audio longer than this is truncated. + public static let maxDurationSeconds: Double = 10.0 // MARK: - Voice Cloning @@ -49,22 +57,24 @@ public enum PocketTtsVoiceCloner { + "(minimum \(minDurationSeconds)s required)" ) } - guard durationSeconds <= maxDurationSeconds else { - throw PocketTTSError.processingFailed( - "Audio too long for voice cloning: \(String(format: "%.1f", durationSeconds))s " - + "(maximum \(maxDurationSeconds)s allowed)" - ) - } - // Pad audio to frame boundary - let paddedSamples = padToFrameBoundary(samples) + // mimi_encoderv2 has a fixed input shape [1, 1, 240000]. Pad shorter + // audio with zeros; truncate longer audio. Track the real sample count + // so we can drop encoded-zero-padding frames from the output. + let realSampleCount = min(samples.count, encoderInputSamples) + let encoderInput = makeEncoderInputBuffer(samples) - logger.info("Encoding \(paddedSamples.count) samples (\(String(format: "%.1f", durationSeconds))s)") + logger.info( + "Encoding \(realSampleCount) samples (\(String(format: "%.1f", durationSeconds))s) " + + "padded/truncated to \(encoderInputSamples)" + ) - // Create input tensor [1, 1, T] - let audioArray = try MLMultiArray(shape: [1, 1, NSNumber(value: paddedSamples.count)], dataType: .float32) - for (i, sample) in paddedSamples.enumerated() { - audioArray[[0, 0, NSNumber(value: i)]] = NSNumber(value: sample) + // Create input tensor [1, 1, 240000] + let audioArray = try MLMultiArray( + shape: [1, 1, NSNumber(value: encoderInputSamples)], dataType: .float32) + let dst = audioArray.dataPointer.bindMemory(to: Float.self, capacity: encoderInputSamples) + encoderInput.withUnsafeBufferPointer { src in + dst.update(from: src.baseAddress!, count: encoderInputSamples) } // Run encoder @@ -78,7 +88,8 @@ public enum PocketTtsVoiceCloner { let numFrames = conditioning.shape[1].intValue let embDim = conditioning.shape[2].intValue - let usableFrames = min(numFrames, maxVoiceFrames) + let usableFrames = usableFrameCount( + realSampleCount: realSampleCount, availableFrames: numFrames) logger.info("Encoded to \(numFrames) frames, using \(usableFrames)") // Extract conditioning with bulk memory copy (no zero-padding) @@ -168,28 +179,67 @@ public enum PocketTtsVoiceCloner { // MARK: - Private Helpers - private static func padToFrameBoundary(_ samples: [Float]) -> [Float] { - let length = samples.count - let padLength = (frameSize - (length % frameSize)) % frameSize - if padLength > 0 { - return samples + [Float](repeating: 0, count: padLength) + /// Build a fixed-length `encoderInputSamples`-sized buffer: copy the first + /// `encoderInputSamples` of `samples` (truncating overflow), zero-pad the + /// remainder. `mimi_encoderv2`'s input shape is non-flexible at runtime. + /// + /// Exposed at internal access for unit tests; production callers go + /// through `cloneVoice(from:using:)`. + static func makeEncoderInputBuffer(_ samples: [Float]) -> [Float] { + var buffer = [Float](repeating: 0, count: encoderInputSamples) + let copyCount = min(samples.count, encoderInputSamples) + if copyCount > 0 { + buffer.replaceSubrange(0.. Int { + let realFrames = (realSampleCount + frameSize - 1) / frameSize + return min(availableFrames, realFrames, maxVoiceFrames) + } + + /// Extract conditioning floats from MLMultiArray `[1, frames, embDim]`. + /// + /// Both dtype paths assume contiguous storage starting at the array's + /// base pointer: the encoder writes `[1, 125, 1024]` in row-major order + /// and we read the leading `frames` rows. The Float32 path is a bulk + /// `UnsafeBufferPointer` copy; the Float16 path uses + /// `vDSP.convertElements` (vectorized fp16→fp32 conversion) so + /// `mimi_encoderv2`'s Float16 output doesn't have to pay 128 k + /// MLMultiArray subscript calls per clone. Falls back to NSNumber + /// subscripting on x86 hosts where Swift `Float16` isn't available. private static func extractConditioning( _ conditioning: MLMultiArray, frames: Int, embDim: Int ) -> [Float] { let count = frames * embDim if conditioning.dataType == .float16 { - return (0.. Data in + Data(buffer: buffer) + } + let url = tmpDir.appendingPathComponent("bos_before_voice.bin") + try data.write(to: url) + + let loaded = try PocketTtsConstantsLoader.loadBosBeforeVoiceIfPresent(in: tmpDir) + XCTAssertNotNil(loaded) + XCTAssertEqual(loaded?.count, dim) + XCTAssertEqual(loaded ?? [], expected) + } + + func testBosBeforeVoiceThrowsOnWrongSize() throws { + // Truncated file (1023 floats instead of 1024) must be rejected, + // not silently zero-padded. + let bad: [Float] = Array(repeating: 0, count: PocketTtsConstants.embeddingDim - 1) + let data = bad.withUnsafeBufferPointer { buffer -> Data in + Data(buffer: buffer) + } + let url = tmpDir.appendingPathComponent("bos_before_voice.bin") + try data.write(to: url) + + XCTAssertThrowsError( + try PocketTtsConstantsLoader.loadBosBeforeVoiceIfPresent(in: tmpDir) + ) { error in + guard + let loadError = error as? PocketTtsConstantsLoader.LoadError, + case .invalidSize(let name, let expected, let actual) = loadError + else { + XCTFail("Expected LoadError.invalidSize, got \(error)") + return + } + XCTAssertEqual(name, "bos_before_voice") + XCTAssertEqual(expected, PocketTtsConstants.embeddingDim) + XCTAssertEqual(actual, PocketTtsConstants.embeddingDim - 1) + } + } +} diff --git a/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsVoiceClonerTests.swift b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsVoiceClonerTests.swift new file mode 100644 index 00000000..4e0681d5 --- /dev/null +++ b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsVoiceClonerTests.swift @@ -0,0 +1,107 @@ +import Foundation +import XCTest + +@testable import FluidAudio + +/// Pure-logic unit tests for `PocketTtsVoiceCloner`'s pad/truncate and +/// frame-trim helpers. The full `cloneVoice(from:using:)` entry point +/// needs an `MLModel`, so these tests drive the smaller internal +/// helpers (`makeEncoderInputBuffer`, `usableFrameCount`) which the +/// production path delegates to. +final class PocketTtsVoiceClonerTests: XCTestCase { + + // MARK: - makeEncoderInputBuffer + + func testEncoderInputBufferPadsShorterAudio() { + // 7.5 s of audio @ 24 kHz = 180_000 samples; encoder wants 240_000. + let realCount = 180_000 + let input = (0..