feat: add PocketTTS backend for lightweight text-to-speech (#273)

## Summary
- Add PocketTTS as a new TTS backend — flow-matching language model with
autoregressive streaming synthesis
- Pure Swift implementation using 4 CoreML models (cond_step,
flowlm_step, flow_decoder, mimi_decoder)
- iOS 17 compatible — no `scaled_dot_product_attention` ops (avoids BNNS
crash)
- Add audio post-processor with de-esser for reducing sibilant harshness

## Test plan
- [x] Short sentence: WER 0, 3.44s audio
- [x] Long sentence: WER 0, 6.64s audio
- [x] Fresh HuggingFace download works end-to-end
- [x] iOS build succeeds (`xcodebuild -destination 'generic/platform=iOS'`)
- [x] macOS build succeeds (`swift build -c release`)
Alex
2026-02-02 23:49:47 -05:00
committed by GitHub
parent 980a7e5256
commit 9fcdf2f32c
24 changed files with 2319 additions and 137 deletions
+3 -1
@@ -43,7 +43,8 @@ TDT models process audio in chunks (~15s with overlap) as batch operations. Fast
| Model | Description | Context |
|-------|-------------|---------|
| **Kokoro TTS** | Text-to-speech synthesis (82M params), 48 voices, minimal RAM usage on iOS. | First TTS backend added. |
| **Kokoro TTS** | Text-to-speech synthesis (82M params), 48 voices, minimal RAM usage on iOS. Generates all frames at once via flow matching over mel spectrograms + Vocos vocoder. Requires espeak for phonemization. | First TTS backend added. |
| **PocketTTS** | Second TTS backend (~155M params). Upgrade over Kokoro with much better dynamic audio chunking. No espeak dependency. | |
## Model Sources
@@ -58,3 +59,4 @@ TDT models process audio in chunks (~15s with overlap) as batch operations. Fast
| Diarization (Pyannote) | [FluidInference/speaker-diarization-coreml](https://huggingface.co/FluidInference/speaker-diarization-coreml) |
| Sortformer | [FluidInference/diar-streaming-sortformer-coreml](https://huggingface.co/FluidInference/diar-streaming-sortformer-coreml) |
| Kokoro TTS | [FluidInference/kokoro-82m-coreml](https://huggingface.co/FluidInference/kokoro-82m-coreml) |
| PocketTTS | [FluidInference/pocket-tts-coreml](https://huggingface.co/FluidInference/pocket-tts-coreml) |
+110
@@ -0,0 +1,110 @@
# Kokoro: High-Quality Text-to-Speech
## Overview
Kokoro is a high-quality, English-only TTS backend. It generates the entire audio representation in one pass (all frames at once) using flow matching over mel spectrograms, then converts to audio with the Vocos vocoder.
## Quick Start
### CLI
```bash
swift run fluidaudio tts "Welcome to FluidAudio text to speech" \
--output ~/Desktop/demo.wav \
--voice af_heart
```
The first invocation downloads Kokoro models, phoneme dictionaries, and voice embeddings; later runs reuse the cached assets.
### Swift
```swift
import FluidAudioTTS
let manager = KokoroTtsManager()  // renamed from TtSManager in this commit
try await manager.initialize()
let audioData = try await manager.synthesize(text: "Hello from FluidAudio!")
let outputURL = URL(fileURLWithPath: "/tmp/demo.wav")
try audioData.write(to: outputURL)
```
Swap in `manager.initialize(models:)` when you want to preload only the long-form `.fifteenSecond` variant.
## Inspecting Chunk Metadata
```swift
let manager = KokoroTtsManager()  // renamed from TtSManager in this commit
try await manager.initialize()
let detailed = try await manager.synthesizeDetailed(
text: "FluidAudio can report chunk splits for you.",
variantPreference: .fifteenSecond
)
for chunk in detailed.chunks {
print("Chunk #\(chunk.index) -> variant: \(chunk.variant), tokens: \(chunk.tokenCount)")
print(" text: \(chunk.text)")
}
```
`KokoroSynthesizer.SynthesisResult` also exposes `diagnostics` for per-run variant and audio footprint totals.
## SSML Support
Kokoro supports a subset of SSML tags for controlling pronunciation. See [SSML.md](SSML.md) for details.
## How It Differs From PocketTTS
| | Kokoro | PocketTTS |
|---|---|---|
| Text input | Phonemes (IPA via espeak) | Raw text (SentencePiece) |
| Voice conditioning | Style embedding vector | 125 audio prompt tokens |
| Generation | All frames at once | Frame-by-frame autoregressive |
| Flow matching target | Mel spectrogram | 32-dim latent per frame |
| Audio synthesis | Vocos vocoder | Mimi streaming codec |
| Latency to first audio | Must wait for full generation | ~80ms after prefill |
Kokoro parallelizes across time (fast total, but must wait for everything). PocketTTS is sequential across time (slower total, but audio starts immediately).
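The ~80 ms figure follows from the frame size. Below is a minimal sketch of the arithmetic, assuming the values given in this PR's PocketTTS notes (1920 samples per Mimi frame, 24 kHz mono output). The constant and function names are illustrative and are not taken from FluidAudio.

```swift
// Illustrative constants; the names are hypothetical, the values come from the PocketTTS notes.
let samplesPerFrame = 1920      // samples produced by one Mimi decoder call
let sampleRateHz = 24_000.0     // output sample rate in Hz

/// Duration of a single generated frame, in milliseconds.
let frameDurationMs = Double(samplesPerFrame) / sampleRateHz * 1000.0

/// Audio duration in seconds for a given number of generated frames.
func audioSeconds(frames: Int) -> Double {
    Double(frames * samplesPerFrame) / sampleRateHz
}
```

Under these assumptions, the durations in the test plan are whole multiples of one frame: 3.44 s corresponds to 43 frames and 6.64 s to 83 frames.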
## Enable TTS in Your Project
### App/Library Development (Xcode & SwiftPM)
When adding FluidAudio to your Xcode project or Package.swift, select the **`FluidAudioWithTTS`** product:
**Xcode:**
1. File > Add Package Dependencies
2. Enter FluidAudio repository URL
3. Choose **`FluidAudioWithTTS`**
4. Add it to your app target
**Package.swift:**
```swift
dependencies: [
.package(url: "https://github.com/FluidInference/FluidAudio.git", from: "0.7.7"),
],
targets: [
.target(
name: "YourTarget",
dependencies: [
.product(name: "FluidAudioWithTTS", package: "FluidAudio")
]
)
]
```
**Import in your code:**
```swift
import FluidAudio // Core functionality (ASR, diarization, VAD)
import FluidAudioTTS // TTS features
```
### CLI Development
TTS support is enabled by default in the CLI:
```bash
swift run fluidaudio tts "Welcome to FluidAudio" --output ~/Desktop/demo.wav
```
+112
@@ -0,0 +1,112 @@
# PocketTTS Swift Inference
How the Swift code generates speech from text.
## Files
| File | Role |
|------|------|
| `PocketTtsManager.swift` | Public API — `initialize()`, `synthesize()`, `synthesizeToFile()` |
| `PocketTtsModelStore.swift` | Loads and stores the 4 CoreML models + constants + voice data |
| `PocketTtsSynthesizer.swift` | Main synthesis loop — chunking, prefill, generation, output |
| `PocketTtsSynthesizer+KVCache.swift` | KV cache state, `prefillKVCache()`, `runCondStep()`, `runFlowLMStep()` |
| `PocketTtsSynthesizer+Flow.swift` | Flow decoder loop, `denormalize()`, `quantize()`, SeededRNG |
| `PocketTtsSynthesizer+Mimi.swift` | Mimi decoder state, `runMimiDecoder()`, `loadMimiInitialState()` |
| `PocketTtsConstantsLoader.swift` | Loads binary constants (embeddings, tokenizer, quantizer weights) |
| `PocketTtsConstants.swift` | All numeric constants (dimensions, thresholds, etc.) |
## Call Flow
```
PocketTtsManager.synthesize(text:)
|
v
PocketTtsSynthesizer.synthesize(text:voice:temperature:)
|
|-- chunkText() split text into <=50 token chunks
|-- loadMimiInitialState() load 23 streaming state tensors from disk
|
|-- FOR EACH CHUNK:
| |
| |-- tokenizer.encode() SentencePiece text → token IDs
| |-- embedTokens() table lookup: token ID → [1024] vector
| |-- prefillKVCache() feed 125 voice + N text tokens through cond_step
| | |
| | |-- emptyKVCacheState() fresh cache (6 layers × [2,1,512,16,64])
| | |-- runCondStep() × ~141 one token per call, updates cache
| |
| |-- GENERATE LOOP (until EOS or max frames):
| | |
| | |-- runFlowLMStep() → transformer_out [1,1024] + eos_logit
| | |-- flowDecode() → 32-dim latent
| | | |-- randn(32) * sqrt(temperature)
| | | |-- runFlowDecoderStep() × 8 Euler steps
| | | |-- latent += velocity * dt each step
| | |
| | |-- denormalize() latent * std + mean
| | |-- quantize() matmul [32] × [32,512] → [512]
| | |-- runMimiDecoder() [512] → 1920 audio samples
| | | updates 23 streaming state tensors
| | |
| | |-- createSequenceFromLatent() feed latent back for next frame
|
|-- concatenate all frames
|-- applyTtsPostProcessing() (optional de-essing)
|-- AudioWAV.data() wrap in WAV header (24kHz mono)
```
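The `flowDecode()` branch of the diagram can be read as a fixed-step Euler integration. The following is a minimal sketch under stated assumptions, not the PR's code: `VelocityField` is a hypothetical closure standing in for the CoreML `flow_decoder` call, the time variable is assumed to run from 0 to 1 in equal steps, and the noise is passed in explicitly so the integration itself stays deterministic.

```swift
import Foundation

/// Hypothetical stand-in for the flow decoder model: maps (latent, time) to a velocity.
typealias VelocityField = (_ latent: [Float], _ t: Float) -> [Float]

/// Draw `count` standard-normal samples using the Box-Muller transform.
func gaussianNoise<G: RandomNumberGenerator>(count: Int, using rng: inout G) -> [Float] {
    var samples: [Float] = []
    samples.reserveCapacity(count)
    for _ in 0..<count {
        // u1 is kept strictly positive so that log(u1) is finite.
        let u1 = Float.random(in: Float.leastNonzeroMagnitude...1, using: &rng)
        let u2 = Float.random(in: 0..<1, using: &rng)
        samples.append(sqrt(-2 * log(u1)) * cos(2 * Float.pi * u2))
    }
    return samples
}

/// Scale the noise by sqrt(temperature), then take `steps` Euler steps of size 1/steps.
func eulerFlowDecode(
    noise: [Float],
    temperature: Float,
    steps: Int = 8,
    velocity: VelocityField
) -> [Float] {
    let scale = temperature.squareRoot()
    var latent = noise.map { $0 * scale }
    let dt = 1 / Float(steps)
    for step in 0..<steps {
        let t = Float(step) * dt              // assumed time convention: t in [0, 1)
        let v = velocity(latent, t)
        for i in latent.indices {
            latent[i] += v[i] * dt            // latent += velocity * dt
        }
    }
    return latent
}
```

With a constant unit velocity and zero noise, eight steps of size 1/8 move every component from 0 to 1, which is a quick sanity check on the step size.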
## Key State
### KV Cache (`KVCacheState`)
- 6 cache tensors `[2, 1, 512, 16, 64]` + 6 position counters
- Written during prefill (voice + text tokens)
- Read and extended during generation (one position per frame)
- **Reset per chunk** — each chunk gets a fresh cache
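As a rough size check, the cache footprint can be computed from the shape above. This sketch assumes the cache elements are stored as 32-bit floats, which the notes do not state explicitly; the variable names are illustrative.

```swift
// Shape of one per-layer cache tensor, as listed above.
let cacheShape = [2, 1, 512, 16, 64]
let layerCount = 6

let elementsPerLayer = cacheShape.reduce(1, *)
let totalElements = layerCount * elementsPerLayer

// Assumption: each element is a 32-bit float.
let totalBytes = totalElements * MemoryLayout<Float>.size
let totalMiB = Double(totalBytes) / (1024 * 1024)
```

Under that assumption, each layer holds 1,048,576 elements and the six tensors together occupy 24 MiB, which is allocated fresh for every chunk.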
### Mimi State (`MimiState`)
- 23 tensors: convolution history, attention caches, overlap-add buffers
- Loaded once from `mimi_init_state/*.bin` files via `manifest.json`
- Updated after every `runMimiDecoder()` call — outputs feed back as next input
- **Continuous across chunks** — never reset, keeps audio seamless
## Text Chunking
Long text is split into chunks of <=50 tokens so that the voice prompt, the text, and the generated frames all fit in the KV cache (512 positions, of which ~125 go to the voice prompt and ~25 to overhead).
Splitting priority:
1. Sentence boundaries (`.!?`)
2. Clause boundaries (`,;:`)
3. Word boundaries (fallback)
`normalizeText()` also capitalizes, adds terminal punctuation, and pads short text with leading spaces for better prosody.
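The splitting priority can be illustrated with a small sketch. It is not the PR's `chunkText()`: the token count is approximated by whitespace-separated words instead of SentencePiece tokens, and `approxTokenCount`, `splitKeeping`, and `chunkTextSketch` are hypothetical names introduced here.

```swift
import Foundation

/// Assumption: whitespace-separated words stand in for SentencePiece tokens.
func approxTokenCount(_ text: String) -> Int {
    text.split(whereSeparator: { $0.isWhitespace }).count
}

/// Split `text` after any delimiter character, keeping the delimiter with the preceding piece.
func splitKeeping(_ text: String, after delimiters: Set<Character>) -> [String] {
    var pieces: [String] = []
    var current = ""
    for ch in text {
        current.append(ch)
        if delimiters.contains(ch) {
            pieces.append(current)
            current = ""
        }
    }
    if !current.isEmpty { pieces.append(current) }
    return pieces
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
}

func chunkTextSketch(_ text: String, maxTokens: Int = 50) -> [String] {
    // Priority order: sentence boundaries, clause boundaries, then word boundaries.
    let levels: [Set<Character>] = [[".", "!", "?"], [",", ";", ":"], [" "]]

    // Break a piece with progressively finer delimiters until it fits the budget.
    func fit(_ piece: String, level: Int) -> [String] {
        if approxTokenCount(piece) <= maxTokens || level >= levels.count {
            return [piece]
        }
        return splitKeeping(piece, after: levels[level])
            .flatMap { fit($0, level: level + 1) }
    }

    // Greedily pack the fitted pieces back together without exceeding the budget.
    var chunks: [String] = []
    var current = ""
    for piece in fit(text, level: 0) {
        let candidate = current.isEmpty ? piece : current + " " + piece
        if approxTokenCount(candidate) <= maxTokens {
            current = candidate
        } else {
            if !current.isEmpty { chunks.append(current) }
            current = piece
        }
    }
    if !current.isEmpty { chunks.append(current) }
    return chunks
}
```

The greedy packing step is a design choice of this sketch: it keeps chunks as long as the budget allows, so short sentences are merged instead of being synthesized one by one.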
## EOS Detection
`runFlowLMStep()` returns an `eos_logit`. When it exceeds `-4.0`, the code generates a few extra frames (3 for short text, 1 for long) then stops.
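The stopping rule can be sketched as a small counting function. This is an illustration under assumptions, not the PR's loop: the per-frame logits are supplied as an array, the frame on which EOS is detected is counted as generated, and `frameCount` is a hypothetical name.

```swift
/// Count how many frames are generated before stopping.
/// - Parameters:
///   - eosLogits: the EOS logit observed at each generation step (illustrative input).
///   - extraFrames: frames to generate after EOS is first detected (3 for short text, 1 for long).
///   - maxFrames: hard cap on the number of frames.
///   - threshold: EOS is detected when the logit exceeds this value.
func frameCount(
    eosLogits: [Float],
    extraFrames: Int,
    maxFrames: Int,
    threshold: Float = -4.0
) -> Int {
    var generated = 0
    var remainingAfterEOS: Int? = nil
    for logit in eosLogits.prefix(maxFrames) {
        generated += 1
        if let remaining = remainingAfterEOS {
            // Already past EOS: stop once the extra frames are used up.
            if remaining <= 1 { break }
            remainingAfterEOS = remaining - 1
        } else if logit > threshold {
            // First detection: schedule the extra frames, or stop immediately if none.
            if extraFrames == 0 { break }
            remainingAfterEOS = extraFrames
        }
    }
    return generated
}
```

The extra frames give the decoder a short tail after the detected end of speech; the `maxFrames` cap still applies if the logit never crosses the threshold.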
## CoreML Details
- All 4 models loaded with `.cpuAndGPU` compute units (ANE float16 causes artifacts in Mimi state feedback)
- Models compiled from `.mlpackage` to `.mlmodelc` on first load, cached on disk
- `PocketTtsModelStore` is an actor — thread-safe access to loaded models
- Voice data cached per voice name to avoid reloading
## Usage
```swift
import FluidAudioTTS
let manager = PocketTtsManager()
try await manager.initialize()
let audioData = try await manager.synthesize(text: "Hello, world!")
try await manager.synthesizeToFile(
text: "Hello, world!",
outputURL: URL(fileURLWithPath: "/tmp/output.wav")
)
```
## License
CC-BY-4.0, inherited from [kyutai/pocket-tts](https://huggingface.co/kyutai/pocket-tts).
-108
@@ -1,108 +0,0 @@
# Text-To-Speech (TTS) Code Examples
> **⚠️ Beta:** The TTS system is currently in beta and only supports American English. Additional language support is planned for future releases.
Quick recipes for running the Kokoro synthesis stack.
## Enable TTS in Your Project
### For App/Library Development (Xcode & SwiftPM)
When adding FluidAudio to your Xcode project or Package.swift, select the **`FluidAudioWithTTS`** product to include text-to-speech capabilities:
**Xcode:**
1. File → Add Package Dependencies
2. Enter FluidAudio repository URL
3. In the package product selection dialog, choose **`FluidAudioWithTTS`**
4. Add it to your app target
**Package.swift:**
```swift
dependencies: [
.package(url: "https://github.com/FluidInference/FluidAudio.git", from: "0.7.7"),
],
targets: [
.target(
name: "YourTarget",
dependencies: [
.product(name: "FluidAudioWithTTS", package: "FluidAudio")
]
)
]
```
**Import in your code:**
```swift
import FluidAudio // Core functionality (ASR, diarization, VAD)
import FluidAudioTTS // TTS features
```
### For CLI Development
When developing or running the FluidAudio CLI, TTS support is enabled by default.
**Terminal:**
```bash
swift run fluidaudio tts "Welcome to FluidAudio" --output ~/Desktop/demo.wav
# Or explicitly build/test the CLI with TTS
swift build
swift test
```
## CLI quick start
```bash
swift run fluidaudio tts "Welcome to FluidAudio text to speech" \
--output ~/Desktop/demo.wav \
--voice af_heart
```
The first invocation downloads Kokoro models, phoneme dictionaries, and voice embeddings; later runs reuse the
cached assets.
## Swift async usage
```swift
import FluidAudio
import Foundation
@main
struct DemoTTS {
static func main() async {
let manager = TtSManager()
do {
try await manager.initialize()
let audioData = try await manager.synthesize(text: "Hello from FluidAudio!")
let outputURL = URL(fileURLWithPath: "/tmp/fluidaudio-demo.wav")
try audioData.write(to: outputURL)
print("Saved synthesized audio to: \(outputURL.path)")
} catch {
print("Synthesis failed: \(error)")
}
}
}
```
Swap in `manager.initialize(models:)` when you want to preload only the long-form `.fifteenSecond` variant.
## Inspecting chunk metadata
```swift
let manager = TtSManager()
try await manager.initialize()
let detailed = try await manager.synthesizeDetailed(
text: "FluidAudio can report chunk splits for you.",
variantPreference: .fifteenSecond
)
for chunk in detailed.chunks {
print("Chunk #\(chunk.index) -> variant: \(chunk.variant), tokens: \(chunk.tokenCount)")
print(" text: \(chunk.text)")
}
```
`KokoroSynthesizer.SynthesisResult` also exposes `diagnostics` for per-run variant and audio footprint totals.
+17 -9
@@ -201,7 +201,7 @@ public class DownloadUtils {
}
// Get all files recursively using HuggingFace API
var filesToDownload: [String] = []
var filesToDownload: [(path: String, size: Int)] = []
func listDirectory(path: String) async throws {
let apiPath = path.isEmpty ? "tree/main" : "tree/main/\(path)"
@@ -256,7 +256,8 @@ public class DownloadUtils {
|| itemPath.hasSuffix(".json") || itemPath.hasSuffix(".txt")
}
if shouldInclude {
filesToDownload.append(itemPath)
let fileSize = item["size"] as? Int ?? -1
filesToDownload.append((path: itemPath, size: fileSize))
}
}
}
@@ -267,11 +268,11 @@ public class DownloadUtils {
logger.info("Found \(filesToDownload.count) files to download")
// Download each file
for (index, filePath) in filesToDownload.enumerated() {
for (index, file) in filesToDownload.enumerated() {
// Strip subPath prefix when saving locally
var localPath = filePath
if let sub = subPath, filePath.hasPrefix("\(sub)/") {
localPath = String(filePath.dropFirst(sub.count + 1))
var localPath = file.path
if let sub = subPath, file.path.hasPrefix("\(sub)/") {
localPath = String(file.path.dropFirst(sub.count + 1))
}
let destPath = repoPath.appendingPathComponent(localPath)
@@ -286,8 +287,15 @@ public class DownloadUtils {
withIntermediateDirectories: true
)
// HuggingFace returns 500 for 0-byte files, so create an empty file locally
if file.size == 0 {
FileManager.default.createFile(atPath: destPath.path, contents: Data())
continue
}
// Download file (use original path for HuggingFace URL)
let encodedFilePath = filePath.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? filePath
let encodedFilePath =
file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedFilePath)
let request = authorizedRequest(url: fileURL)
@@ -300,12 +308,12 @@ public class DownloadUtils {
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
throw HuggingFaceDownloadError.rateLimited(
statusCode: httpResponse.statusCode,
message: "Rate limited while downloading \(filePath)")
message: "Rate limited while downloading \(file.path)")
}
guard (200..<300).contains(httpResponse.statusCode) else {
throw HuggingFaceDownloadError.downloadFailed(
path: filePath,
path: file.path,
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
)
}
+31
@@ -12,6 +12,7 @@ public enum Repo: String, CaseIterable {
case diarizer = "FluidInference/speaker-diarization-coreml"
case kokoro = "FluidInference/kokoro-82m-coreml"
case sortformer = "FluidInference/diar-streaming-sortformer-coreml"
case pocketTts = "FluidInference/pocket-tts-coreml"
/// Repository slug (without owner)
public var name: String {
@@ -36,6 +37,8 @@ public enum Repo: String, CaseIterable {
return "kokoro-82m-coreml"
case .sortformer:
return "diar-streaming-sortformer-coreml"
case .pocketTts:
return "pocket-tts-coreml"
}
}
@@ -78,6 +81,8 @@ public enum Repo: String, CaseIterable {
return "parakeet-eou-streaming/320ms"
case .sortformer:
return "sortformer"
case .pocketTts:
return "pocket-tts"
default:
return name
}
@@ -263,6 +268,30 @@ public enum ModelNames {
}
}
/// PocketTTS model names (flow-matching language model TTS)
public enum PocketTTS {
public static let condStep = "cond_step"
public static let flowlmStep = "flowlm_step"
public static let flowDecoder = "flow_decoder"
public static let mimiDecoder = "mimi_decoder_v2"
public static let condStepFile = condStep + ".mlmodelc"
public static let flowlmStepFile = flowlmStep + ".mlmodelc"
public static let flowDecoderFile = flowDecoder + ".mlmodelc"
public static let mimiDecoderFile = mimiDecoder + ".mlmodelc"
/// Directory containing binary constants, tokenizer, and voice data.
public static let constantsBinDir = "constants_bin"
public static let requiredModels: Set<String> = [
condStepFile,
flowlmStepFile,
flowDecoderFile,
mimiDecoderFile,
constantsBinDir,
]
}
/// TTS model names
public enum TTS {
@@ -328,6 +357,8 @@ public enum ModelNames {
return ModelNames.Diarizer.requiredModels
case .kokoro:
return ModelNames.TTS.requiredModels
case .pocketTts:
return ModelNames.PocketTTS.requiredModels
case .sortformer:
return ModelNames.Sortformer.requiredModels
}
+141 -4
@@ -134,6 +134,7 @@ public struct TTS {
var text: String? = nil
var benchmarkMode = false
var deEss = true
var backend: TtsBackend = .kokoro
var i = 0
while i < arguments.count {
@@ -180,6 +181,19 @@ public struct TTS {
lexiconPath = arguments[i + 1]
i += 1
}
case "--backend":
if i + 1 < arguments.count {
let value = arguments[i + 1].lowercased()
switch value {
case "kokoro":
backend = .kokoro
case "pocket", "pockettts":
backend = .pocketTts
default:
logger.warning("Unknown backend '\(arguments[i + 1])'; using kokoro")
}
i += 1
}
case "--auto-download":
// No-op: downloads are always ensured by the CLI
()
@@ -214,12 +228,19 @@ public struct TTS {
return
}
if backend == .pocketTts {
await runPocketTts(
text: text, output: output, voice: voice, deEss: deEss,
metricsPath: metricsPath)
return
}
do {
// Timing buckets
let tStart = Date()
let customLexicon = try loadCustomLexicon(from: lexiconPath)
let manager = TtSManager(customLexicon: customLexicon)
let manager = KokoroTtsManager(customLexicon: customLexicon)
let requestedVoice = voice.trimmingCharacters(in: .whitespacesAndNewlines)
let voiceOverride = requestedVoice.isEmpty ? nil : requestedVoice
let preloadVoices = voiceOverride.map { Set([$0]) }
@@ -449,6 +470,121 @@ public struct TTS {
}
}
private static func runPocketTts(
text: String, output: String, voice: String, deEss: Bool,
metricsPath: String?
) async {
do {
let tStart = Date()
let pocketVoice =
voice == TtsConstants.recommendedVoice
? PocketTtsConstants.defaultVoice : voice
let manager = PocketTtsManager(defaultVoice: pocketVoice)
let tLoad0 = Date()
try await manager.initialize()
let tLoad1 = Date()
let tSynth0 = Date()
let wav = try await manager.synthesize(
text: text, voice: pocketVoice, deEss: deEss)
let tSynth1 = Date()
let outURL = {
let expanded = (output as NSString).expandingTildeInPath
if expanded.hasPrefix("/") {
return URL(fileURLWithPath: expanded)
}
let cwd = URL(
fileURLWithPath: FileManager.default.currentDirectoryPath,
isDirectory: true)
return cwd.appendingPathComponent(expanded)
}()
try FileManager.default.createDirectory(
at: outURL.deletingLastPathComponent(),
withIntermediateDirectories: true)
try wav.write(to: outURL)
let loadS = tLoad1.timeIntervalSince(tLoad0)
let synthS = tSynth1.timeIntervalSince(tSynth0)
let totalS = tSynth1.timeIntervalSince(tStart)
let sampleRate = Double(PocketTtsConstants.audioSampleRate)
let payload = max(0, wav.count - 44)
let audioSecs = Double(payload) / (sampleRate * 2.0)
let rtfx = synthS > 0 ? audioSecs / synthS : 0
logger.info("PocketTTS synthesis complete")
logger.info(" Load: \(String(format: "%.3f", loadS))s")
logger.info(" Synthesis: \(String(format: "%.3f", synthS))s")
logger.info(" Audio: \(String(format: "%.3f", audioSecs))s")
logger.info(" RTFx: \(String(format: "%.2f", rtfx))x")
logger.info(" Total: \(String(format: "%.3f", totalS))s")
logger.info(" Output: \(outURL.path)")
// ASR round-trip evaluation
if metricsPath != nil {
logger.info("--- Running ASR for TTS→STT evaluation ---")
var asrHypothesis: String? = nil
var werValue: Double? = nil
do {
let asrModels = try await AsrModels.downloadAndLoad()
let asr = AsrManager()
try await asr.initialize(models: asrModels)
let transcription = try await asr.transcribe(outURL)
asrHypothesis = transcription.text
let werMetrics = WERCalculator.calculateWERMetrics(
hypothesis: transcription.text, reference: text)
werValue = werMetrics.wer
logger.info("Reference: \(text)")
logger.info("Hypothesis: \(transcription.text)")
logger.info(String(format: "WER: %.1f%%", werValue! * 100))
asr.cleanup()
} catch {
logger.warning("ASR evaluation failed: \(error.localizedDescription)")
}
if let metricsPath {
var metricsDict: [String: Any] = [
"backend": "pockettts",
"text": text,
"voice": pocketVoice,
"output": outURL.path,
"model_load_time_s": loadS,
"inference_time_s": synthS,
"audio_duration_s": audioSecs,
"realtime_speed": rtfx,
"total_time_s": totalS,
]
if let asrHypothesis {
metricsDict["asr_hypothesis"] = asrHypothesis
}
if let werValue {
metricsDict["wer"] = werValue
}
let artifactsRoot = try ensureArtifactsRoot()
let mURL = resolveOutputURL(
metricsPath, artifactsRoot: artifactsRoot, expectsDirectory: false)
try FileManager.default.createDirectory(
at: mURL.deletingLastPathComponent(), withIntermediateDirectories: true)
let json = try JSONSerialization.data(
withJSONObject: metricsDict, options: [.prettyPrinted])
try json.write(to: mURL)
logger.info("Metrics saved: \(mURL.path)")
}
}
} catch {
logger.error("PocketTTS Error: \(error)")
print("PocketTTS failed: \(error)")
exit(1)
}
}
private static func printUsage() {
print(
"""
@@ -456,8 +592,9 @@ public struct TTS {
Options:
--output, -o Output WAV path (default: output.wav)
--voice, -v Voice name (default: af_heart)
--lexicon, -l Custom pronunciation lexicon file (word=phonemes format)
--voice, -v Voice name (default: af_heart for Kokoro, alba for PocketTTS)
--backend TTS backend: kokoro (default) or pocket
--lexicon, -l Custom pronunciation lexicon file (word=phonemes format, Kokoro only)
--benchmark Run a predefined benchmarking suite with multiple sentences
--variant Force Kokoro 5s or 15s model (values: 5s,15s)
--metrics Write timing metrics to a JSON file (also runs ASR for evaluation)
@@ -496,7 +633,7 @@ extension TTS {
) async {
do {
let customLexicon = try loadCustomLexicon(from: lexiconPath)
let manager = TtSManager(customLexicon: customLexicon)
let manager = KokoroTtsManager(customLexicon: customLexicon)
let requestedVoice = voice.trimmingCharacters(in: .whitespacesAndNewlines)
let normalizedVoice = requestedVoice.isEmpty ? nil : requestedVoice
let preloadVoices = normalizedVoice.map { Set([$0]) }
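The audio duration and RTFx values logged by `runPocketTts` in this file come from simple arithmetic on the WAV byte count. The sketch below restates that calculation as standalone functions; it assumes a canonical 44-byte WAV header and 16-bit mono PCM, as the code's constants imply, and the function names are hypothetical.

```swift
/// Duration of a WAV buffer in seconds.
/// Assumptions: 44-byte header, 16-bit (2-byte) mono PCM samples.
func wavDurationSeconds(byteCount: Int, sampleRate: Double) -> Double {
    let payloadBytes = max(0, byteCount - 44)
    return Double(payloadBytes) / (sampleRate * 2.0)
}

/// Real-time factor: seconds of audio produced per second of synthesis time.
/// Returns 0 when the synthesis time is not positive, matching the guard in the CLI code.
func realTimeFactor(audioSeconds: Double, synthesisSeconds: Double) -> Double {
    synthesisSeconds > 0 ? audioSeconds / synthesisSeconds : 0
}
```

A value above 1 means synthesis ran faster than real time. The result would be off if the WAV writer emitted a longer header or a different sample format.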
@@ -55,7 +55,7 @@ public struct KokoroSynthesizer {
static func currentModelCache() throws -> KokoroModelCache {
guard let cache = Context.modelCache else {
throw TTSError.processingFailed(
"KokoroSynthesizer requires a model cache context. Use TtSManager or withModelCache(_:operation:)."
"KokoroSynthesizer requires a model cache context. Use KokoroTtsManager or withModelCache(_:operation:)."
)
}
return cache
@@ -64,7 +64,7 @@ public struct KokoroSynthesizer {
static func currentLexiconAssets() throws -> LexiconAssetManager {
guard let assets = Context.lexiconAssets else {
throw TTSError.processingFailed(
"KokoroSynthesizer requires lexicon assets context. Use TtSManager or withLexiconAssets(_:operation:)."
"KokoroSynthesizer requires lexicon assets context. Use KokoroTtsManager or withLexiconAssets(_:operation:)."
)
}
return assets
@@ -2,20 +2,20 @@ import FluidAudio
import Foundation
import OSLog
/// Manages text-to-speech synthesis using the Kokoro CoreML model.
/// Manages text-to-speech synthesis using Kokoro CoreML models.
///
/// - Note: **Beta:** The TTS system is currently in beta and only supports American English.
/// Additional language support is planned for future releases.
///
/// Example usage:
/// ```swift
/// let manager = TtSManager()
/// let manager = KokoroTtsManager()
/// try await manager.initialize()
/// let audioData = try await manager.synthesize(text: "Hello, world!")
/// ```
public final class TtSManager {
public final class KokoroTtsManager {
private let logger = AppLogger(category: "TtSManager")
private let logger = AppLogger(category: "KokoroTtsManager")
private let modelCache: KokoroModelCache
private let lexiconAssets: LexiconAssetManager
@@ -80,7 +80,7 @@ public final class TtSManager {
try await KokoroSynthesizer.loadSimplePhonemeDictionary()
try await modelCache.loadModelsIfNeeded(variants: models.availableVariants)
isInitialized = true
logger.notice("TtSManager initialized with provided models")
logger.notice("KokoroTtsManager initialized with provided models")
}
public func initialize(preloadVoices: Set<String>? = nil) async throws {
@@ -0,0 +1,117 @@
import FluidAudio
import Foundation
import OSLog
/// Pre-loaded binary constants for PocketTTS inference.
public struct PocketTtsConstantsBundle: Sendable {
public let bosEmbedding: [Float]
public let textEmbedTable: [Float]
public let tokenizer: SentencePieceTokenizer
}
/// Pre-loaded voice conditioning data.
public struct PocketTtsVoiceData: Sendable {
/// Flattened audio prompt: [1, promptLength, 1024]
public let audioPrompt: [Float]
/// Number of voice conditioning tokens (typically 125).
public let promptLength: Int
}
/// Loads PocketTTS constants from raw `.bin` Float32 files on disk.
public enum PocketTtsConstantsLoader {
private static let logger = AppLogger(category: "PocketTtsConstantsLoader")
public enum LoadError: Error {
case fileNotFound(String)
case invalidSize(String, expected: Int, actual: Int)
case tokenizerLoadFailed(String)
}
/// Load all constants from the given directory.
public static func load(from directory: URL) throws -> PocketTtsConstantsBundle {
let constantsDir = directory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
let bosEmb = try loadFloatArray(
from: constantsDir.appendingPathComponent("bos_emb.bin"),
expectedCount: PocketTtsConstants.latentDim,
name: "bos_emb"
)
let embedTable = try loadFloatArray(
from: constantsDir.appendingPathComponent("text_embed_table.bin"),
expectedCount: PocketTtsConstants.vocabSize * PocketTtsConstants.embeddingDim,
name: "text_embed_table"
)
let tokenizerURL = constantsDir.appendingPathComponent("tokenizer.model")
guard FileManager.default.fileExists(atPath: tokenizerURL.path) else {
throw LoadError.fileNotFound("tokenizer.model")
}
let tokenizerData = try Data(contentsOf: tokenizerURL)
let tokenizer: SentencePieceTokenizer
do {
tokenizer = try SentencePieceTokenizer(modelData: tokenizerData)
} catch {
throw LoadError.tokenizerLoadFailed(error.localizedDescription)
}
logger.info("Loaded PocketTTS constants from \(directory.lastPathComponent)")
return PocketTtsConstantsBundle(
bosEmbedding: bosEmb,
textEmbedTable: embedTable,
tokenizer: tokenizer
)
}
/// Load voice conditioning data from the given directory.
///
/// HuggingFace layout: `constants_bin/<voice>_audio_prompt.bin`
public static func loadVoice(
_ voice: String, from directory: URL
) throws -> PocketTtsVoiceData {
// Sanitize voice name to prevent path traversal
let sanitized = voice.components(separatedBy: CharacterSet.alphanumerics.inverted).joined()
guard !sanitized.isEmpty else {
throw LoadError.fileNotFound("invalid voice name: \(voice)")
}
let constantsDir = directory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
let audioPrompt = try loadFloatArray(
from: constantsDir.appendingPathComponent("\(sanitized)_audio_prompt.bin"),
expectedCount: PocketTtsConstants.voicePromptLength * PocketTtsConstants.embeddingDim,
name: "\(sanitized)_audio_prompt"
)
logger.info("Loaded PocketTTS voice '\(sanitized)' conditioning data")
return PocketTtsVoiceData(
audioPrompt: audioPrompt,
promptLength: PocketTtsConstants.voicePromptLength
)
}
// MARK: - Private
/// Load a raw Float32 binary file into a [Float] array.
private static func loadFloatArray(
from url: URL, expectedCount: Int, name: String
) throws -> [Float] {
guard FileManager.default.fileExists(atPath: url.path) else {
throw LoadError.fileNotFound(name)
}
let data = try Data(contentsOf: url)
let actualCount = data.count / MemoryLayout<Float>.size
guard actualCount == expectedCount else {
throw LoadError.invalidSize(name, expected: expectedCount, actual: actualCount)
}
return data.withUnsafeBytes { rawBuffer in
let floatBuffer = rawBuffer.bindMemory(to: Float.self)
return Array(floatBuffer)
}
}
}
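The voice-name sanitization in `loadVoice` and the size check in `loadFloatArray` can be exercised in isolation. The sketch below reuses the same `CharacterSet` expression; `sanitizedVoiceName` and `expectedPromptBytes` are hypothetical helper names, and the default sizes (125 prompt tokens, 1024 dimensions) are taken from the doc comments in this file and are not read from `PocketTtsConstants`.

```swift
import Foundation

/// Drop every non-alphanumeric character, using the same expression as `loadVoice`.
func sanitizedVoiceName(_ voice: String) -> String {
    voice.components(separatedBy: CharacterSet.alphanumerics.inverted).joined()
}

/// Expected byte size of a voice prompt file stored as raw Float32 values.
/// Assumption: 125 prompt tokens x 1024 dimensions, per the doc comments.
func expectedPromptBytes(promptLength: Int = 125, embeddingDim: Int = 1024) -> Int {
    promptLength * embeddingDim * MemoryLayout<Float>.size
}
```

Path separators and dots are removed, so a traversal attempt collapses to a plain name or to an empty string. A side effect is that hyphens and underscores are stripped as well, so `my-voice_1` resolves to `myvoice1`.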
@@ -0,0 +1,72 @@
import FluidAudio
import Foundation
import OSLog
/// Downloads PocketTTS models and constants from HuggingFace.
public enum PocketTtsResourceDownloader {
private static let logger = AppLogger(category: "PocketTtsResourceDownloader")
/// Ensure all PocketTTS models are downloaded and return the cache directory.
public static func ensureModels() async throws -> URL {
let cacheDirectory = try cacheDirectory()
let modelsDirectory = cacheDirectory.appendingPathComponent(
PocketTtsConstants.defaultModelsSubdirectory)
let repoDir = modelsDirectory.appendingPathComponent(Repo.pocketTts.folderName)
// Check that all required directories exist (models + constants_bin)
let requiredModels = ModelNames.PocketTTS.requiredModels
let allPresent = requiredModels.allSatisfy { model in
FileManager.default.fileExists(
atPath: repoDir.appendingPathComponent(model).path)
}
if !allPresent {
logger.info("Downloading PocketTTS models from HuggingFace...")
try await DownloadUtils.downloadRepo(.pocketTts, to: modelsDirectory)
} else {
logger.info("PocketTTS models found in cache")
}
return repoDir
}
/// Ensure constants (binary blobs + tokenizer) are available.
public static func ensureConstants(repoDirectory: URL) throws -> PocketTtsConstantsBundle {
try PocketTtsConstantsLoader.load(from: repoDirectory)
}
/// Ensure voice conditioning data is available.
public static func ensureVoice(
_ voice: String, repoDirectory: URL
) throws -> PocketTtsVoiceData {
try PocketTtsConstantsLoader.loadVoice(voice, from: repoDirectory)
}
// MARK: - Private
private static func cacheDirectory() throws -> URL {
let baseDirectory: URL
#if os(macOS)
baseDirectory = FileManager.default.homeDirectoryForCurrentUser
.appendingPathComponent(".cache")
#else
guard
let first = FileManager.default.urls(
for: .cachesDirectory, in: .userDomainMask
).first
else {
throw TTSError.processingFailed("Failed to locate caches directory")
}
baseDirectory = first
#endif
let cacheDirectory = baseDirectory.appendingPathComponent("fluidaudio")
if !FileManager.default.fileExists(atPath: cacheDirectory.path) {
try FileManager.default.createDirectory(
at: cacheDirectory, withIntermediateDirectories: true)
}
return cacheDirectory
}
}
@@ -0,0 +1,133 @@
@preconcurrency import CoreML
import FluidAudio
import Foundation
import OSLog
/// Actor-based store for PocketTTS CoreML models and constants.
///
/// Manages loading and storing of the four CoreML models
/// (cond_step, flowlm_step, flow_decoder, mimi_decoder),
/// the binary constants bundle, and voice conditioning data.
public actor PocketTtsModelStore {
private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "PocketTtsModelStore")
private var condStepModel: MLModel?
private var flowlmStepModel: MLModel?
private var flowDecoderModel: MLModel?
private var mimiDecoderModel: MLModel?
private var constantsBundle: PocketTtsConstantsBundle?
private var voiceCache: [String: PocketTtsVoiceData] = [:]
private var repoDirectory: URL?
public init() {}
/// Load all four CoreML models and the constants bundle.
public func loadIfNeeded() async throws {
guard condStepModel == nil else { return }
let repoDir = try await PocketTtsResourceDownloader.ensureModels()
self.repoDirectory = repoDir
logger.info("Loading PocketTTS CoreML models...")
// Use CPU+GPU for all models to avoid ANE float16 precision loss.
// The ANE processes in native float16, which causes audible artifacts
// in the Mimi decoder's streaming state feedback loop and may degrade
// quality in the other models. CPU/GPU compute in float32 matches the
// Python reference implementation.
let config = MLModelConfiguration()
config.computeUnits = .cpuAndGPU
let loadStart = Date()
let modelFiles = [
ModelNames.PocketTTS.condStepFile,
ModelNames.PocketTTS.flowlmStepFile,
ModelNames.PocketTTS.flowDecoderFile,
ModelNames.PocketTTS.mimiDecoderFile,
]
var loadedModels: [MLModel] = []
for file in modelFiles {
let modelURL = repoDir.appendingPathComponent(file)
let model = try MLModel(contentsOf: modelURL, configuration: config)
loadedModels.append(model)
logger.info("Loaded \(file)")
}
condStepModel = loadedModels[0]
flowlmStepModel = loadedModels[1]
flowDecoderModel = loadedModels[2]
mimiDecoderModel = loadedModels[3]
let elapsed = Date().timeIntervalSince(loadStart)
logger.info("All PocketTTS models loaded in \(String(format: "%.2f", elapsed))s")
// Load constants
constantsBundle = try PocketTtsResourceDownloader.ensureConstants(
repoDirectory: repoDir)
logger.info("PocketTTS constants loaded")
}
/// The conditioning step model (KV cache prefill).
public func condStep() throws -> MLModel {
guard let model = condStepModel else {
throw TTSError.modelNotFound("PocketTTS cond_step model not loaded")
}
return model
}
/// The autoregressive generation step model.
public func flowlmStep() throws -> MLModel {
guard let model = flowlmStepModel else {
throw TTSError.modelNotFound("PocketTTS flowlm_step model not loaded")
}
return model
}
/// The LSD flow decoder model.
public func flowDecoder() throws -> MLModel {
guard let model = flowDecoderModel else {
throw TTSError.modelNotFound("PocketTTS flow_decoder model not loaded")
}
return model
}
/// The Mimi streaming audio decoder model.
public func mimiDecoder() throws -> MLModel {
guard let model = mimiDecoderModel else {
throw TTSError.modelNotFound("PocketTTS mimi_decoder model not loaded")
}
return model
}
/// The pre-loaded binary constants.
public func constants() throws -> PocketTtsConstantsBundle {
guard let bundle = constantsBundle else {
throw TTSError.modelNotFound("PocketTTS constants not loaded")
}
return bundle
}
/// The repository directory containing models and constants.
public func repoDir() throws -> URL {
guard let dir = repoDirectory else {
throw TTSError.modelNotFound("PocketTTS repository not loaded")
}
return dir
}
/// Load and cache voice conditioning data.
public func voiceData(for voice: String) throws -> PocketTtsVoiceData {
if let cached = voiceCache[voice] {
return cached
}
guard let repoDir = repoDirectory else {
throw TTSError.modelNotFound("PocketTTS repository not loaded")
}
let data = try PocketTtsResourceDownloader.ensureVoice(voice, repoDirectory: repoDir)
voiceCache[voice] = data
return data
}
}
@@ -0,0 +1,157 @@
@preconcurrency import CoreML
import FluidAudio
import Foundation
extension PocketTtsSynthesizer {
/// Run the flow decoder using Euler integration (LSD steps).
///
/// Converts transformer output to a 32-dimensional audio latent
/// via `numSteps` iterative denoising steps.
static func flowDecode(
transformerOut: MLMultiArray,
numSteps: Int,
temperature: Float,
model: MLModel,
rng: inout some RandomNumberGenerator
) async throws -> [Float] {
let latentDim = PocketTtsConstants.latentDim
let dt: Float = 1.0 / Float(numSteps)
// Initialize latent with scaled random noise: randn * sqrt(temperature)
var latent = [Float](repeating: 0, count: latentDim)
let scale = sqrtf(temperature)
for i in 0..<latentDim {
latent[i] = Float.gaussianRandom(using: &rng) * scale
}
// Flatten transformer_out from [1, 1, 1024] to [1, 1024]
let transformerFlat = try reshapeToFlat(transformerOut, dim: PocketTtsConstants.transformerDim)
// Euler integration: `numSteps` steps from t=0 to t=1
for step in 0..<numSteps {
let sValue = Float(step) * dt
let tValue = Float(step + 1) * dt
let velocity = try await runFlowDecoderStep(
transformerOut: transformerFlat,
latent: latent,
s: sValue,
t: tValue,
model: model
)
// Euler step: latent += velocity * dt
for i in 0..<latentDim {
latent[i] += velocity[i] * dt
}
}
return latent
}
// MARK: - Private
/// Run a single flow decoder step.
private static func runFlowDecoderStep(
transformerOut: MLMultiArray,
latent: [Float],
s: Float,
t: Float,
model: MLModel
) async throws -> [Float] {
let latentDim = PocketTtsConstants.latentDim
// Create latent MLMultiArray [1, 32]
let latentArray = try MLMultiArray(
shape: [1, NSNumber(value: latentDim)], dataType: .float32)
let latentPtr = latentArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
latent.withUnsafeBufferPointer { buffer in
guard let base = buffer.baseAddress else { return }
latentPtr.update(from: base, count: latentDim)
}
// Create s and t MLMultiArrays [1, 1]
let sArray = try MLMultiArray(shape: [1, 1], dataType: .float32)
sArray[0] = NSNumber(value: s)
let tArray = try MLMultiArray(shape: [1, 1], dataType: .float32)
tArray[0] = NSNumber(value: t)
let inputDict: [String: Any] = [
"transformer_out": transformerOut,
"latent": latentArray,
"s": sArray,
"t": tArray,
]
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
// Extract velocity: take the first (and likely only) output
// `featureNames` is a Set, so read its first element safely instead of indexing.
guard let velocityName = output.featureNames.first,
let velocityArray = output.featureValue(for: velocityName)?.multiArrayValue
else {
throw TTSError.processingFailed("Missing flow decoder velocity output")
}
let velocityPtr = velocityArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
return Array(UnsafeBufferPointer(start: velocityPtr, count: latentDim))
}
/// Reshape a [1, 1, D] MLMultiArray to [1, D].
private static func reshapeToFlat(_ array: MLMultiArray, dim: Int) throws -> MLMultiArray {
let flat = try MLMultiArray(shape: [1, NSNumber(value: dim)], dataType: .float32)
let srcPtr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
let dstPtr = flat.dataPointer.bindMemory(to: Float.self, capacity: dim)
dstPtr.update(from: srcPtr, count: dim)
return flat
}
}
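// Illustrative sketch (not part of the original commit): the Euler update used by
// `flowDecode`, isolated from CoreML so it can be exercised without a model.
// `exampleEulerIntegrate` and its `velocity` closure are hypothetical names; the
// closure stands in for the flow decoder call and is assumed to return one value
// per latent element.
func exampleEulerIntegrate(
    initial: [Float],
    numSteps: Int,
    velocity: (_ latent: [Float], _ s: Float, _ t: Float) -> [Float]
) -> [Float] {
    precondition(numSteps > 0, "numSteps must be positive")
    var latent = initial
    let dt: Float = 1.0 / Float(numSteps)
    for step in 0..<numSteps {
        // Each step covers the interval [s, t] of the unit time range.
        let s = Float(step) * dt
        let t = Float(step + 1) * dt
        let v = velocity(latent, s, t)
        // Euler step: latent += velocity * dt
        for i in latent.indices {
            latent[i] += v[i] * dt
        }
    }
    return latent
}
// With a constant velocity of 1, integrating from t=0 to t=1 shifts every
// element by exactly 1 regardless of the step count.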
// MARK: - Seeded Random
/// Simple seeded random number generator (xoshiro256**).
///
/// Produces a reproducible random sequence for a given seed. The generator
/// itself always requires a seed; callers that want non-reproducible output
/// pass a seed drawn from system entropy.
struct SeededRNG: RandomNumberGenerator {
private var state: (UInt64, UInt64, UInt64, UInt64)
init(seed: UInt64) {
// SplitMix64 to expand seed into 4-part state
var s = seed
func next() -> UInt64 {
s &+= 0x9E37_79B9_7F4A_7C15
var z = s
z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
return z ^ (z >> 31)
}
state = (next(), next(), next(), next())
}
mutating func next() -> UInt64 {
let result = rotl(state.1 &* 5, 7) &* 9
let t = state.1 << 17
state.2 ^= state.0
state.3 ^= state.1
state.1 ^= state.2
state.0 ^= state.3
state.2 ^= t
state.3 = rotl(state.3, 45)
return result
}
private func rotl(_ x: UInt64, _ k: Int) -> UInt64 {
(x << k) | (x >> (64 - k))
}
}
extension Float {
/// Generate a single sample from the standard normal distribution (Box-Muller transform).
static func gaussianRandom(using rng: inout some RandomNumberGenerator) -> Float {
let u1 = Float.random(in: Float.leastNonzeroMagnitude...1.0, using: &rng)
let u2 = Float.random(in: 0.0...1.0, using: &rng)
return sqrtf(-2.0 * logf(u1)) * cosf(2.0 * .pi * u2)
}
}
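// Illustrative sketch (not part of the original commit): drawing Gaussian noise
// from `SeededRNG` via `Float.gaussianRandom(using:)`. The helper name
// `exampleSeededGaussianSamples` is hypothetical. Two calls with the same seed
// should return identical samples, which is what makes seeded synthesis repeatable.
func exampleSeededGaussianSamples(seed: UInt64, count: Int) -> [Float] {
    var rng = SeededRNG(seed: seed)
    var samples: [Float] = []
    samples.reserveCapacity(count)
    for _ in 0..<count {
        // Each call advances the generator state, so order matters.
        samples.append(Float.gaussianRandom(using: &rng))
    }
    return samples
}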
@@ -0,0 +1,176 @@
@preconcurrency import CoreML
import FluidAudio
import Foundation
extension PocketTtsSynthesizer {
/// Mutable KV cache state passed through conditioning and generation steps.
struct KVCacheState {
/// 6 KV cache arrays, each [2, 1, 200, 16, 64].
var caches: [MLMultiArray]
/// 6 position counters, each [1].
var positions: [MLMultiArray]
}
/// Create an empty KV cache state (all zeros, positions at 0).
static func emptyKVCacheState() throws -> KVCacheState {
let layers = PocketTtsConstants.kvCacheLayers
let shape: [NSNumber] = [
2, 1, NSNumber(value: PocketTtsConstants.kvCacheMaxLen), 16, 64,
]
var caches: [MLMultiArray] = []
var positions: [MLMultiArray] = []
caches.reserveCapacity(layers)
positions.reserveCapacity(layers)
for _ in 0..<layers {
let cache = try MLMultiArray(shape: shape, dataType: .float32)
let cachePtr = cache.dataPointer.bindMemory(
to: Float.self, capacity: cache.count)
cachePtr.initialize(repeating: 0, count: cache.count)
caches.append(cache)
let pos = try MLMultiArray(shape: [1], dataType: .float32)
pos[0] = NSNumber(value: Float(0))
positions.append(pos)
}
return KVCacheState(caches: caches, positions: positions)
}
/// Run the conditioning step model for a single token, updating the KV cache in place.
static func runCondStep(
conditioning: MLMultiArray,
state: inout KVCacheState,
model: MLModel
) async throws {
var inputDict: [String: Any] = [
"conditioning": conditioning
]
for i in 0..<PocketTtsConstants.kvCacheLayers {
inputDict["cache\(i)"] = state.caches[i]
inputDict["position\(i)"] = state.positions[i]
}
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
for i in 0..<PocketTtsConstants.kvCacheLayers {
guard let newCache = output.featureValue(for: CondStepKeys.cacheKeys[i])?.multiArrayValue
else {
throw TTSError.processingFailed(
"Missing cond_step cache output: \(CondStepKeys.cacheKeys[i])")
}
guard let newPos = output.featureValue(for: CondStepKeys.positionKeys[i])?.multiArrayValue
else {
throw TTSError.processingFailed(
"Missing cond_step position output: \(CondStepKeys.positionKeys[i])")
}
state.caches[i] = newCache
state.positions[i] = newPos
}
}
/// Prefill the KV cache with voice and text conditioning tokens.
///
/// Processes voice tokens first, then text tokens (critical ordering).
static func prefillKVCache(
voiceData: PocketTtsVoiceData,
textEmbeddings: [[Float]],
model: MLModel
) async throws -> KVCacheState {
var state = try emptyKVCacheState()
let dim = PocketTtsConstants.embeddingDim
// Voice tokens first (positions 0..<promptLength, e.g. 0..124)
let voiceTokenCount = voiceData.promptLength
for tokenIdx in 0..<voiceTokenCount {
let token = try createConditioningToken(
from: voiceData.audioPrompt,
offset: tokenIdx * dim,
dim: dim
)
try await runCondStep(conditioning: token, state: &state, model: model)
}
// Text tokens next
for embedding in textEmbeddings {
let token = try createConditioningToken(from: embedding, offset: 0, dim: dim)
try await runCondStep(conditioning: token, state: &state, model: model)
}
let finalPos = state.positions[0][0].floatValue
logger.info("KV cache prefilled to position \(Int(finalPos))")
return state
}
/// Create a [1, 1, 1024] MLMultiArray from a float slice.
private static func createConditioningToken(
from source: [Float], offset: Int, dim: Int
) throws -> MLMultiArray {
let array = try MLMultiArray(
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
source.withUnsafeBufferPointer { buffer in
guard let base = buffer.baseAddress else { return }
ptr.update(from: base.advanced(by: offset), count: dim)
}
return array
}
/// Run the generation step model, returning transformer output and EOS logit.
static func runFlowLMStep(
sequence: MLMultiArray,
bosEmb: MLMultiArray,
state: inout KVCacheState,
model: MLModel
) async throws -> (transformerOut: MLMultiArray, eosLogit: Float) {
var inputDict: [String: Any] = [
"sequence": sequence,
"bos_emb": bosEmb,
]
for i in 0..<PocketTtsConstants.kvCacheLayers {
inputDict["cache\(i)"] = state.caches[i]
inputDict["position\(i)"] = state.positions[i]
}
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
// Extract transformer output
guard let transformerOut = output.featureValue(for: FlowLMStepKeys.transformerOut)?.multiArrayValue
else {
throw TTSError.processingFailed("Missing flowlm_step transformer output")
}
// Extract EOS logit
guard let eosArray = output.featureValue(for: FlowLMStepKeys.eosLogit)?.multiArrayValue
else {
throw TTSError.processingFailed("Missing flowlm_step EOS logit")
}
let eosLogit = eosArray[0].floatValue
// Update caches and positions
for i in 0..<PocketTtsConstants.kvCacheLayers {
guard
let newCache = output.featureValue(for: FlowLMStepKeys.cacheKeys[i])?.multiArrayValue
else {
throw TTSError.processingFailed(
"Missing flowlm_step cache output: \(FlowLMStepKeys.cacheKeys[i])")
}
guard let newPos = output.featureValue(for: FlowLMStepKeys.positionKeys[i])?.multiArrayValue
else {
throw TTSError.processingFailed(
"Missing flowlm_step position output: \(FlowLMStepKeys.positionKeys[i])")
}
state.caches[i] = newCache
state.positions[i] = newPos
}
return (transformerOut: transformerOut, eosLogit: eosLogit)
}
}
@@ -0,0 +1,159 @@
@preconcurrency import CoreML
import FluidAudio
import Foundation
extension PocketTtsSynthesizer {
/// Mutable streaming state for the Mimi audio decoder.
///
/// Contains 26 tensors that track convolutional history,
/// attention caches, and partial upsampling buffers.
struct MimiState {
var tensors: [String: MLMultiArray]
}
/// Create the initial Mimi decoder state from the constants directory.
///
/// Loads pre-computed initial state tensors from `.bin` files,
/// using `manifest.json` for shape metadata.
static func loadMimiInitialState(from repoDirectory: URL) throws -> MimiState {
let constantsDir = repoDirectory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
let stateDir = constantsDir.appendingPathComponent("mimi_init_state")
let manifestURL = constantsDir.appendingPathComponent("manifest.json")
// Parse manifest for mimi_init_state shapes
let manifestData = try Data(contentsOf: manifestURL)
guard let manifest = try JSONSerialization.jsonObject(with: manifestData) as? [String: Any],
let mimiManifest = manifest["mimi_init_state"] as? [String: Any]
else {
throw TTSError.processingFailed("Failed to parse mimi_init_state from manifest.json")
}
var tensors: [String: MLMultiArray] = [:]
for (name, info) in mimiManifest {
guard let infoDict = info as? [String: Any],
let shapeArray = infoDict["shape"] as? [Int],
let byteCount = infoDict["bytes"] as? Int
else {
continue
}
let shape = shapeArray.map { NSNumber(value: $0) }
let array = try MLMultiArray(shape: shape, dataType: .float32)
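if byteCount > 0 && !shapeArray.contains(0) {
let binURL = stateDir.appendingPathComponent("\(name).bin")
let data = try Data(contentsOf: binURL)
let floatCount = byteCount / MemoryLayout<Float>.size
// Guard against a truncated file or a manifest/shape mismatch before copying raw bytes.
guard data.count >= byteCount, floatCount <= array.count else {
throw TTSError.processingFailed("Mimi init state \(name) has unexpected size")
}
let dstPtr = array.dataPointer.bindMemory(to: Float.self, capacity: floatCount)
data.withUnsafeBytes { rawBuffer in
guard let srcBase = rawBuffer.bindMemory(to: Float.self).baseAddress else { return }
dstPtr.update(from: srcBase, count: floatCount)
}
}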
tensors[name] = array
}
// Ensure offset scalars exist
for key in ["attn0_offset", "attn0_end_offset", "attn1_offset", "attn1_end_offset"] {
if tensors[key] == nil {
let scalar = try MLMultiArray(shape: [1], dataType: .float32)
scalar[0] = NSNumber(value: Float(0))
tensors[key] = scalar
}
}
return MimiState(tensors: tensors)
}
/// Clone a Mimi state for independent use.
static func cloneMimiState(_ state: MimiState) throws -> MimiState {
var newTensors: [String: MLMultiArray] = [:]
for (key, array) in state.tensors {
let copy = try MLMultiArray(shape: array.shape, dataType: array.dataType)
let byteSize: Int
switch array.dataType {
case .float16:
byteSize = array.count * MemoryLayout<UInt16>.size
default:
byteSize = array.count * MemoryLayout<Float>.size
}
if byteSize > 0 {
copy.dataPointer.copyMemory(from: array.dataPointer, byteCount: byteSize)
}
newTensors[key] = copy
}
return MimiState(tensors: newTensors)
}
/// Run the Mimi decoder for a single latent frame.
///
/// The model internally denormalizes and quantizes the 32-dim latent
/// before decoding to audio.
///
/// - Parameters:
/// - latent: The raw latent vector, shape [32].
/// - state: The streaming state (26 tensors), modified in place.
/// - model: The Mimi CoreML model.
/// - Returns: Audio samples for this frame (1920 samples = 80ms at 24kHz).
static func runMimiDecoder(
latent: [Float],
state: inout MimiState,
model: MLModel
) async throws -> [Float] {
// Create latent input: [1, 32]
let latentDim = PocketTtsConstants.latentDim
let latentArray = try MLMultiArray(
shape: [1, NSNumber(value: latentDim)], dataType: .float32)
let latentPtr = latentArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
latent.withUnsafeBufferPointer { buffer in
guard let base = buffer.baseAddress else { return }
latentPtr.update(from: base, count: latentDim)
}
// Build input dictionary
var inputDict: [String: Any] = ["latent": latentArray]
for (key, array) in state.tensors {
inputDict[key] = array
}
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
// Extract audio output [1, 1, 1920]
guard let audioArray = output.featureValue(for: MimiKeys.audioOutput)?.multiArrayValue else {
throw TTSError.processingFailed("Missing Mimi audio output")
}
let sampleCount = PocketTtsConstants.samplesPerFrame
let samples = readFloatArray(from: audioArray, count: sampleCount)
// Update streaming state
for (inputName, outputName) in mimiStateMapping {
guard let updated = output.featureValue(for: outputName)?.multiArrayValue else {
throw TTSError.processingFailed(
"Missing Mimi state output: \(outputName) (for \(inputName))")
}
state.tensors[inputName] = updated
}
return samples
}
/// Read Float values from an MLMultiArray, handling both float32 and float16 data types.
///
/// The Mimi decoder CoreML model outputs float16 tensors. Using `dataPointer` with
/// `Float.self` binding on float16 data produces garbage values. This method
/// uses the type-safe subscript accessor which handles conversion automatically.
private static func readFloatArray(from array: MLMultiArray, count: Int) -> [Float] {
if array.dataType == .float16 {
// Use subscript for correct float16 → float32 conversion
return (0..<count).map { array[$0].floatValue }
}
// Fast path for float32: direct memory access
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: count)
return Array(UnsafeBufferPointer(start: ptr, count: count))
}
}
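// Illustrative sketch (not part of the original commit): the frame arithmetic
// quoted in the `runMimiDecoder` doc comment (1920 samples per frame at 24 kHz).
// Both helper names are hypothetical; the values are passed in by the caller
// rather than read from `PocketTtsConstants`.
func exampleFrameDurationMs(samplesPerFrame: Int, sampleRate: Int) -> Double {
    // Multiply before dividing so the common case stays exact in floating point.
    Double(samplesPerFrame) * 1000.0 / Double(sampleRate)
}
func exampleFramesPerSecond(samplesPerFrame: Int, sampleRate: Int) -> Double {
    Double(sampleRate) / Double(samplesPerFrame)
}
// For 1920 samples at 24000 Hz this gives 80 ms per frame and 12.5 frames per second.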
@@ -0,0 +1,88 @@
import Foundation
extension PocketTtsSynthesizer {
/// Result of a PocketTTS synthesis operation.
public struct SynthesisResult: Sendable {
/// WAV audio data (24kHz, 16-bit mono).
public let audio: Data
/// Raw Float32 audio samples.
public let samples: [Float]
/// Number of 80ms frames generated.
public let frameCount: Int
/// Generation step at which EOS was detected (nil if max length reached).
public let eosStep: Int?
}
/// CoreML output key names for the conditioning step model.
enum CondStepKeys {
static let cacheKeys: [String] = [
"new_cache_1_internal_tensor_assign_2",
"new_cache_3_internal_tensor_assign_2",
"new_cache_5_internal_tensor_assign_2",
"new_cache_7_internal_tensor_assign_2",
"new_cache_9_internal_tensor_assign_2",
"new_cache_internal_tensor_assign_2",
]
static let positionKeys: [String] = [
"var_445", "var_864", "var_1283", "var_1702", "var_2121", "var_2365",
]
}
/// CoreML output key names for the generation step model.
enum FlowLMStepKeys {
/// CoreML assigned this output the name "input" during model tracing;
/// it is the transformer hidden state output, not an input tensor.
static let transformerOut = "input"
static let eosLogit = "var_2582"
static let cacheKeys: [String] = [
"new_cache_1_internal_tensor_assign_2",
"new_cache_3_internal_tensor_assign_2",
"new_cache_5_internal_tensor_assign_2",
"new_cache_7_internal_tensor_assign_2",
"new_cache_9_internal_tensor_assign_2",
"new_cache_internal_tensor_assign_2",
]
static let positionKeys: [String] = [
"var_458", "var_877", "var_1296", "var_1715", "var_2134", "var_2553",
]
}
/// CoreML output key names for the Mimi decoder model.
enum MimiKeys {
static let audioOutput = "var_821"
}
/// Mimi decoder streaming state key mappings (input name → output name).
///
/// 26 state tensors including 3 zero-length tensors (res{0,1,2}_conv1_prev)
/// whose input and output names are identical (pass-throughs).
static let mimiStateMapping: [(input: String, output: String)] = [
("upsample_partial", "var_82"),
("attn0_cache", "var_262"),
("attn0_offset", "var_840"),
("attn0_end_offset", "new_end_offset_1"),
("attn1_cache", "var_479"),
("attn1_offset", "var_843"),
("attn1_end_offset", "new_end_offset"),
("conv0_prev", "var_607"),
("conv0_first", "conv0_first"),
("convtr0_partial", "var_634"),
("res0_conv0_prev", "var_660"),
("res0_conv0_first", "res0_conv0_first"),
("res0_conv1_prev", "res0_conv1_prev"),
("res0_conv1_first", "res0_conv1_first"),
("convtr1_partial", "var_700"),
("res1_conv0_prev", "var_726"),
("res1_conv0_first", "res1_conv0_first"),
("res1_conv1_prev", "res1_conv1_prev"),
("res1_conv1_first", "res1_conv1_first"),
("convtr2_partial", "var_766"),
("res2_conv0_prev", "var_792"),
("res2_conv0_first", "res2_conv0_first"),
("res2_conv1_prev", "res2_conv1_prev"),
("res2_conv1_first", "res2_conv1_first"),
("conv_final_prev", "var_824"),
("conv_final_first", "conv_final_first"),
]
}
@@ -0,0 +1,535 @@
@preconcurrency import CoreML
import FluidAudio
import Foundation
import OSLog
/// PocketTTS flow-matching language model synthesizer.
///
/// Generates audio autoregressively: each generation step produces
/// an 80ms audio frame (1920 samples at 24kHz).
///
/// Long text is split into sentence-based chunks (at most 50 tokens each)
/// to stay within the KV cache limit (200 positions).
///
/// Pipeline: text → chunk → [tokenize → embed → prefill KV → generate → flow decode → mimi decode] → WAV
public struct PocketTtsSynthesizer {
static let logger = AppLogger(category: "PocketTtsSynthesizer")
private enum Context {
@TaskLocal static var modelStore: PocketTtsModelStore?
}
static func withModelStore<T>(
_ store: PocketTtsModelStore,
operation: () async throws -> T
) async rethrows -> T {
try await Context.$modelStore.withValue(store) {
try await operation()
}
}
static func currentModelStore() throws -> PocketTtsModelStore {
guard let store = Context.modelStore else {
throw TTSError.processingFailed(
"PocketTtsSynthesizer requires a model store context.")
}
return store
}
// MARK: - Public API
/// Synthesize audio from text.
///
/// - Parameters:
/// - text: The text to synthesize.
/// - voice: Voice identifier (default: "alba").
/// - temperature: Generation temperature (default: 0.7).
/// - seed: Random seed for reproducibility (nil for random).
/// - deEss: Whether to apply de-essing post-processing.
/// - Returns: A synthesis result containing WAV audio data.
public static func synthesize(
text: String,
voice: String = PocketTtsConstants.defaultVoice,
temperature: Float = PocketTtsConstants.temperature,
seed: UInt64? = nil,
deEss: Bool = true
) async throws -> SynthesisResult {
let store = try currentModelStore()
logger.info("PocketTTS synthesizing: '\(text)'")
// 1. Load constants and voice
let constants = try await store.constants()
let voiceData = try await store.voiceData(for: voice)
// 2. Split text into chunks that fit within KV cache capacity
let chunks = chunkText(text, tokenizer: constants.tokenizer)
logger.info("Split into \(chunks.count) chunk(s)")
// 3. Set up random number generator (seeded or system entropy)
var rng = SeededRNG(seed: seed ?? UInt64.random(in: 0...UInt64.max))
// 4. Load models
let condModel = try await store.condStep()
let stepModel = try await store.flowlmStep()
let flowModel = try await store.flowDecoder()
let mimiModel = try await store.mimiDecoder()
// 5. Load Mimi initial state (continuous across chunks)
let repoDir = try await store.repoDir()
var mimiState = try loadMimiInitialState(from: repoDir)
// 6. Create BOS embedding
let bosEmb = try createBosEmbedding(constants.bosEmbedding)
// 7. Generate audio for each chunk
var audioChunks: [[Float]] = []
var lastEosStep: Int?
let genStart = Date()
for (chunkIdx, chunkText) in chunks.enumerated() {
let (normalizedChunk, framesAfterEos) = normalizeText(chunkText)
logger.info("Chunk \(chunkIdx + 1)/\(chunks.count): '\(normalizedChunk)'")
// Tokenize and embed this chunk
let tokenIds = constants.tokenizer.encode(normalizedChunk)
let textEmbeddings = embedTokens(tokenIds, constants: constants)
// Fresh KV cache per chunk
let prefillStart = Date()
var kvState = try await prefillKVCache(
voiceData: voiceData,
textEmbeddings: textEmbeddings,
model: condModel
)
let prefillElapsed = Date().timeIntervalSince(prefillStart)
logger.info(
"Chunk \(chunkIdx + 1) prefill: \(String(format: "%.2f", prefillElapsed))s (\(tokenIds.count) tokens)"
)
// Generation loop for this chunk
let maxGenLen = estimateMaxFrames(text: chunkText)
var eosStep: Int?
var sequence = try createNaNSequence()
let totalFramesAfterEos =
framesAfterEos + PocketTtsConstants.extraFramesAfterDetection
for step in 0..<maxGenLen {
let (transformerOut, eosLogit) = try await runFlowLMStep(
sequence: sequence,
bosEmb: bosEmb,
state: &kvState,
model: stepModel
)
if eosLogit > PocketTtsConstants.eosThreshold && eosStep == nil {
eosStep = step
logger.info("Chunk \(chunkIdx + 1) EOS at step \(step)")
}
if let eos = eosStep, step >= eos + totalFramesAfterEos {
break
}
let latent = try await flowDecode(
transformerOut: transformerOut,
numSteps: PocketTtsConstants.numLsdSteps,
temperature: temperature,
model: flowModel,
rng: &rng
)
// Mimi state is continuous across chunks
// (denormalize + quantize baked into mimi_decoder model)
let frameSamples = try await runMimiDecoder(
latent: latent,
state: &mimiState,
model: mimiModel
)
audioChunks.append(frameSamples)
sequence = try createSequenceFromLatent(latent)
if step % 20 == 0 {
logger.info("Chunk \(chunkIdx + 1) step \(step)...")
}
}
lastEosStep = eosStep
}
let genElapsed = Date().timeIntervalSince(genStart)
logger.info(
"Generated \(audioChunks.count) frames in \(String(format: "%.2f", genElapsed))s")
// 8. Concatenate audio (no peak normalization; preserve natural levels)
var allSamples = audioChunks.flatMap { $0 }
// De-essing
if deEss {
AudioPostProcessor.applyTtsPostProcessing(
&allSamples,
sampleRate: Float(PocketTtsConstants.audioSampleRate),
deEssAmount: -3.0,
smoothing: false
)
}
// 9. Encode WAV
let audioData = try AudioWAV.data(
from: allSamples,
sampleRate: Double(PocketTtsConstants.audioSampleRate)
)
let duration = Double(allSamples.count) / Double(PocketTtsConstants.audioSampleRate)
logger.info("Audio duration: \(String(format: "%.2f", duration))s")
return SynthesisResult(
audio: audioData,
samples: allSamples,
frameCount: audioChunks.count,
eosStep: lastEosStep
)
}
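// Usage sketch (not part of the original commit): one way to call `synthesize`
// from inside the module. `exampleSynthesizeToFile` and `outputURL` are
// hypothetical; the rest uses only the API shown in this diff. It assumes the
// models can be downloaded and loaded, and fixes the seed so repeated runs
// start from the same noise.
static func exampleSynthesizeToFile(
    text: String,
    store: PocketTtsModelStore,
    outputURL: URL
) async throws -> SynthesisResult {
    // Models and constants must be loaded before synthesis.
    try await store.loadIfNeeded()
    // `synthesize` reads the store from a task-local, so wrap the call.
    let result = try await withModelStore(store) {
        try await synthesize(text: text, seed: 42)
    }
    // `audio` already holds encoded WAV data.
    try result.audio.write(to: outputURL)
    return result
}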
// MARK: - Text Processing
/// Normalize a text chunk for PocketTTS (matching Python `prepare_text_prompt`).
static func normalizeText(_ text: String) -> (text: String, framesAfterEos: Int) {
var result = text.trimmingCharacters(in: .whitespacesAndNewlines)
// Collapse whitespace
result = result.replacingOccurrences(
of: "\\s+", with: " ", options: .regularExpression)
// Strip trailing clause punctuation (commas, semicolons, colons)
// before adding sentence-ending punctuation
while let last = result.last, ",;:".contains(last) {
result = String(result.dropLast())
}
result = result.trimmingCharacters(in: .whitespaces)
// Capitalize first letter
if let first = result.first, first.isLetter {
result = first.uppercased() + result.dropFirst()
}
// Add period if no terminal punctuation
if let last = result.last, !".!?".contains(last) {
result += "."
}
// Pad short texts for better prosody
let wordCount = result.split(separator: " ").count
let framesAfterEos: Int
if wordCount < PocketTtsConstants.shortTextWordThreshold {
result = String(repeating: " ", count: 8) + result
framesAfterEos = PocketTtsConstants.shortTextPadFrames
} else {
framesAfterEos = PocketTtsConstants.longTextExtraFrames
}
return (result, framesAfterEos)
}
/// Split text into chunks that fit within the KV cache token limit.
///
/// Splits at sentence boundaries (`.!?`) and groups sentences into chunks
/// where each chunk tokenizes to at most `maxTokensPerChunk` tokens.
/// Oversized single sentences are further split at word boundaries.
static func chunkText(
_ text: String,
tokenizer: SentencePieceTokenizer,
maxTokens: Int = PocketTtsConstants.maxTokensPerChunk
) -> [String] {
let normalized = text.trimmingCharacters(in: .whitespacesAndNewlines)
// If it fits in one chunk, return as-is
let tokenCount = tokenizer.encode(normalized).count
if tokenCount <= maxTokens {
return [normalized]
}
// Split into sentences at .!? boundaries
let sentences = splitSentences(normalized)
// Further split any oversized sentences at word boundaries
var pieces: [String] = []
for sentence in sentences {
let sentenceTokens = tokenizer.encode(sentence).count
if sentenceTokens <= maxTokens {
pieces.append(sentence)
} else {
pieces.append(contentsOf: splitOversizedSentence(sentence, tokenizer: tokenizer, maxTokens: maxTokens))
}
}
// Group pieces into chunks that fit the token limit
var chunks: [String] = []
var currentChunk = ""
for piece in pieces {
let candidate: String
if currentChunk.isEmpty {
candidate = piece
} else {
candidate = currentChunk + " " + piece
}
let candidateTokens = tokenizer.encode(candidate).count
if candidateTokens <= maxTokens {
currentChunk = candidate
} else {
if !currentChunk.isEmpty {
chunks.append(currentChunk)
}
currentChunk = piece
}
}
if !currentChunk.isEmpty {
chunks.append(currentChunk)
}
return chunks.isEmpty ? [normalized] : chunks
}
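// Illustrative sketch (not part of the original commit): the greedy grouping
// step of `chunkText`, with the tokenizer replaced by a caller-supplied
// `countTokens` closure so it can run without a SentencePiece model.
// `exampleGreedyGroup` and `countTokens` are hypothetical names.
static func exampleGreedyGroup(
    pieces: [String],
    maxTokens: Int,
    countTokens: (String) -> Int
) -> [String] {
    var chunks: [String] = []
    var current = ""
    for piece in pieces {
        let candidate = current.isEmpty ? piece : current + " " + piece
        if countTokens(candidate) <= maxTokens {
            // Still within budget: keep growing the current chunk.
            current = candidate
        } else {
            // Over budget: flush what we have and start a new chunk.
            if !current.isEmpty { chunks.append(current) }
            current = piece
        }
    }
    if !current.isEmpty { chunks.append(current) }
    return chunks
}
// Counting whitespace-separated words as a stand-in for tokens, the pieces
// ["a b", "c", "d e f"] with a limit of 3 group into ["a b c", "d e f"].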
/// Split an oversized sentence to fit within the token limit.
///
/// First tries splitting at clause boundaries (commas, semicolons, colons).
/// Falls back to word-boundary splitting for clauses that still exceed the limit.
private static func splitOversizedSentence(
_ text: String,
tokenizer: SentencePieceTokenizer,
maxTokens: Int
) -> [String] {
// First try: split at clause boundaries
let clauseParts = splitAtClauseBoundaries(text)
// Group clause parts into chunks that fit
var result: [String] = []
var currentPart = ""
for part in clauseParts {
let candidate = currentPart.isEmpty ? part : currentPart + " " + part
let candidateTokens = tokenizer.encode(candidate).count
if candidateTokens <= maxTokens {
currentPart = candidate
} else {
if !currentPart.isEmpty {
result.append(currentPart)
}
// If single clause part still exceeds limit, split at word boundaries
if tokenizer.encode(part).count > maxTokens {
result.append(contentsOf: splitAtWordBoundaries(part, tokenizer: tokenizer, maxTokens: maxTokens))
currentPart = ""
} else {
currentPart = part
}
}
}
if !currentPart.isEmpty {
result.append(currentPart)
}
return result.isEmpty ? [text] : result
}
/// Split text at clause punctuation (commas, semicolons, colons).
///
/// Does not split at commas within numbers (e.g., "3,500").
private static func splitAtClauseBoundaries(_ text: String) -> [String] {
let clauseBreaks: Set<Character> = [",", ";", ":"]
var parts: [String] = []
var current = ""
let chars = Array(text)
for (i, char) in chars.enumerated() {
current.append(char)
guard clauseBreaks.contains(char) else { continue }
// Don't split at commas between digits (e.g., "3,500")
if char == "," {
let prevIsDigit = i > 0 && chars[i - 1].isNumber
let nextIsDigit = i + 1 < chars.count && chars[i + 1].isNumber
if prevIsDigit && nextIsDigit {
continue
}
}
let trimmed = current.trimmingCharacters(in: .whitespaces)
if !trimmed.isEmpty {
parts.append(trimmed)
}
current = ""
}
let trimmed = current.trimmingCharacters(in: .whitespaces)
if !trimmed.isEmpty {
parts.append(trimmed)
}
return parts
}
/// Split text at word boundaries to fit within the token limit.
private static func splitAtWordBoundaries(
_ text: String,
tokenizer: SentencePieceTokenizer,
maxTokens: Int
) -> [String] {
let words = text.split(separator: " ").map(String.init)
guard words.count > 1 else { return [text] }
var chunks: [String] = []
var currentWords: [String] = []
for word in words {
let candidate = (currentWords + [word]).joined(separator: " ")
let tokens = tokenizer.encode(candidate).count
if tokens > maxTokens && !currentWords.isEmpty {
chunks.append(currentWords.joined(separator: " "))
currentWords = [word]
} else {
currentWords.append(word)
}
}
if !currentWords.isEmpty {
chunks.append(currentWords.joined(separator: " "))
}
return chunks
}
/// Common abbreviations that end with a period but don't end a sentence.
private static let abbreviations: Set<String> = [
"dr", "mr", "mrs", "ms", "prof", "sr", "jr", "st", "vs", "etc",
"inc", "ltd", "co", "corp", "dept", "univ", "govt", "approx",
"avg", "est", "gen", "gov", "hon", "sgt", "cpl", "pvt", "capt",
"lt", "col", "maj", "cmdr", "adm", "rev", "sen", "rep",
]
/// Split text into sentences at `.!?` boundaries.
///
/// Handles abbreviations (e.g., "Dr.", "Prof.") by not splitting after them.
private static func splitSentences(_ text: String) -> [String] {
var sentences: [String] = []
var current = ""
let chars = Array(text)
for (i, char) in chars.enumerated() {
current.append(char)
guard ".!?".contains(char) else { continue }
// For periods, check if this is an abbreviation
if char == "." {
let trimmed = current.trimmingCharacters(in: .whitespaces)
// Get the last word before the period
let withoutPeriod = String(trimmed.dropLast())
let lastWord = withoutPeriod.split(separator: " ").last.map(String.init) ?? withoutPeriod
// Skip if it's a known abbreviation
if abbreviations.contains(lastWord.lowercased()) {
continue
}
// Skip if it's a single uppercase letter (e.g., "J." in initials)
if lastWord.count == 1, lastWord.first?.isUppercase == true {
continue
}
// Skip if followed by a digit (e.g., "3.5")
if i + 1 < chars.count, chars[i + 1].isNumber {
continue
}
}
let trimmed = current.trimmingCharacters(in: .whitespaces)
if !trimmed.isEmpty {
sentences.append(trimmed)
}
current = ""
}
// Remaining text without terminal punctuation
let trimmed = current.trimmingCharacters(in: .whitespaces)
if !trimmed.isEmpty {
sentences.append(trimmed)
}
return sentences
}
// MARK: - Embedding
/// Look up text token embeddings from the embedding table.
static func embedTokens(
_ tokenIds: [Int], constants: PocketTtsConstantsBundle
) -> [[Float]] {
let dim = PocketTtsConstants.embeddingDim
let vocabSize = PocketTtsConstants.vocabSize
return tokenIds.map { id in
// Clamp out-of-range IDs to the valid range instead of failing the whole lookup.
let clampedId = min(max(id, 0), vocabSize - 1)
if clampedId != id {
logger.warning("Token ID \(id) out of range [0, \(vocabSize)), clamping")
}
let offset = clampedId * dim
return Array(constants.textEmbedTable[offset..<(offset + dim)])
}
}
// MARK: - Helpers
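/// Estimate the maximum number of generation frames from the word count.
///
/// Budgets one second of audio per word plus two seconds of headroom, then
/// converts seconds to frames at 12.5 frames per second
/// (24,000 samples per second / 1,920 samples per frame).
private static func estimateMaxFrames(text: String) -> Int {
let wordCount = text.split(separator: " ").count
let genLenSec = Double(wordCount) + 2.0
return Int(genLenSec * 12.5)
}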
/// Create the BOS embedding as an MLMultiArray [32].
private static func createBosEmbedding(_ bos: [Float]) throws -> MLMultiArray {
let dim = PocketTtsConstants.latentDim
let array = try MLMultiArray(shape: [NSNumber(value: dim)], dataType: .float32)
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
// Zero-fill first so a short or empty input never leaves uninitialized values.
ptr.update(repeating: 0, count: dim)
bos.withUnsafeBufferPointer { buffer in
guard let base = buffer.baseAddress else { return }
// Copy at most `dim` values to avoid reading past the end of `bos`.
ptr.update(from: base, count: min(dim, buffer.count))
}
return array
}
/// Create a NaN-filled sequence [1, 1, 32] (signals BOS to the model).
private static func createNaNSequence() throws -> MLMultiArray {
let dim = PocketTtsConstants.latentDim
let array = try MLMultiArray(
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
for i in 0..<dim {
ptr[i] = .nan
}
return array
}
/// Create a sequence [1, 1, 32] from a latent vector.
private static func createSequenceFromLatent(_ latent: [Float]) throws -> MLMultiArray {
let dim = PocketTtsConstants.latentDim
let array = try MLMultiArray(
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
// Zero-fill first so a short or empty input never leaves uninitialized values.
ptr.update(repeating: 0, count: dim)
latent.withUnsafeBufferPointer { buffer in
guard let base = buffer.baseAddress else { return }
// Copy at most `dim` values to avoid reading past the end of `latent`.
ptr.update(from: base, count: min(dim, buffer.count))
}
return array
}
}
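The clause splitter in this file breaks text at `,`, `;`, and `:` but keeps a comma that sits between two digits. The sketch below is a standalone copy of that rule for illustration; the function name `splitClausesSketch` and the sample sentence are hypothetical and not part of the PR.

```swift
import Foundation

/// Hypothetical standalone copy of the clause-splitting rule, for illustration only.
func splitClausesSketch(_ text: String) -> [String] {
    let clauseBreaks: Set<Character> = [",", ";", ":"]
    let chars = Array(text)
    var parts: [String] = []
    var current = ""

    for (i, char) in chars.enumerated() {
        current.append(char)
        guard clauseBreaks.contains(char) else { continue }

        // A comma between two digits is a thousands separator, not a clause break.
        if char == ",",
            i > 0, chars[i - 1].isNumber,
            i + 1 < chars.count, chars[i + 1].isNumber
        {
            continue
        }

        let trimmed = current.trimmingCharacters(in: .whitespaces)
        if !trimmed.isEmpty { parts.append(trimmed) }
        current = ""
    }

    // Whatever follows the last break becomes the final clause.
    let tail = current.trimmingCharacters(in: .whitespaces)
    if !tail.isEmpty { parts.append(tail) }
    return parts
}

let clauses = splitClausesSketch("Revenue was 3,500 dollars, up from last year; costs fell.")
print(clauses)
// ["Revenue was 3,500 dollars,", "up from last year;", "costs fell."]
```

Each clause keeps its trailing punctuation, so downstream chunking can still see where the break came from.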
@@ -0,0 +1,42 @@
import Foundation
/// Constants for the PocketTTS flow-matching language model TTS backend.
public enum PocketTtsConstants {
// MARK: - Audio
/// Output sample rate in Hz.
public static let audioSampleRate: Int = 24_000
/// Samples per generated frame (80 ms at 24 kHz).
public static let samplesPerFrame: Int = 1_920
// MARK: - Model dimensions
public static let latentDim: Int = 32
public static let transformerDim: Int = 1024
public static let vocabSize: Int = 4001
public static let embeddingDim: Int = 1024
// MARK: - Generation parameters
public static let numLsdSteps: Int = 8
public static let temperature: Float = 0.7
public static let eosThreshold: Float = -4.0
public static let shortTextPadFrames: Int = 3
public static let longTextExtraFrames: Int = 1
public static let extraFramesAfterDetection: Int = 2
public static let shortTextWordThreshold: Int = 5
public static let maxTokensPerChunk: Int = 50
// MARK: - KV cache
public static let kvCacheLayers: Int = 6
public static let kvCacheMaxLen: Int = 512
// MARK: - Voice
public static let defaultVoice: String = "alba"
public static let voicePromptLength: Int = 125
// MARK: - Repository
public static let defaultModelsSubdirectory: String = "Models"
}
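The audio constants fix the frame timing that the synthesizer's `estimateMaxFrames` relies on (its `12.5` multiplier). A small sketch of that arithmetic follows; the two constants are copied from `PocketTtsConstants`, while `maxFramesSketch` is a hypothetical helper that only mirrors the heuristic for illustration.

```swift
// Values copied from PocketTtsConstants.
let audioSampleRate = 24_000
let samplesPerFrame = 1_920

// 1,920 samples at 24,000 samples per second is one 80 ms frame.
let frameDurationMs = Double(samplesPerFrame) * 1000.0 / Double(audioSampleRate)
let framesPerSecond = Double(audioSampleRate) / Double(samplesPerFrame)

/// Hypothetical helper mirroring the frame budget heuristic:
/// one second of audio per word plus two seconds of headroom.
func maxFramesSketch(wordCount: Int) -> Int {
    let budgetSeconds = Double(wordCount) + 2.0
    return Int(budgetSeconds * framesPerSecond)
}

print(frameDurationMs)                 // 80.0
print(framesPerSecond)                 // 12.5
print(maxFramesSketch(wordCount: 4))   // 75
```

A four-word sentence is therefore capped at 75 frames, or 6 seconds of audio, unless end-of-speech is detected earlier.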
@@ -0,0 +1,129 @@
import FluidAudio
import Foundation
import OSLog
/// Manages text-to-speech synthesis using PocketTTS CoreML models.
///
/// PocketTTS uses a flow-matching language model architecture that generates
/// audio autoregressively at 24kHz. Each generation step produces an 80ms
/// audio frame (1920 samples).
///
/// Example usage:
/// ```swift
/// let manager = PocketTtsManager()
/// try await manager.initialize()
/// let audioData = try await manager.synthesize(text: "Hello, world!")
/// ```
public actor PocketTtsManager {
private let logger = AppLogger(category: "PocketTtsManager")
private let modelStore: PocketTtsModelStore
private var defaultVoice: String
private var isInitialized = false
/// Creates a new PocketTTS manager.
///
/// - Parameters:
/// - defaultVoice: Default voice identifier (default: "alba").
public init(defaultVoice: String = PocketTtsConstants.defaultVoice) {
self.modelStore = PocketTtsModelStore()
self.defaultVoice = defaultVoice
}
/// Whether the models are loaded and the manager is ready to synthesize.
public var isAvailable: Bool {
isInitialized
}
/// Initialize by downloading and loading all PocketTTS models.
public func initialize() async throws {
try await modelStore.loadIfNeeded()
isInitialized = true
logger.notice("PocketTtsManager initialized")
}
/// Synthesize text to WAV audio data.
///
/// - Parameters:
/// - text: The text to synthesize.
/// - voice: Voice identifier (default: uses the manager's default voice).
/// - temperature: Generation temperature (default: 0.7).
/// - deEss: Whether to apply de-essing post-processing (default: true).
/// - Returns: WAV audio data at 24kHz.
public func synthesize(
text: String,
voice: String? = nil,
temperature: Float = PocketTtsConstants.temperature,
deEss: Bool = true
) async throws -> Data {
guard isInitialized else {
throw TTSError.modelNotFound("PocketTTS model not initialized")
}
let selectedVoice = voice ?? defaultVoice
return try await PocketTtsSynthesizer.withModelStore(modelStore) {
let result = try await PocketTtsSynthesizer.synthesize(
text: text,
voice: selectedVoice,
temperature: temperature,
deEss: deEss
)
return result.audio
}
}
/// Synthesize text and return detailed results including frame count and EOS info.
public func synthesizeDetailed(
text: String,
voice: String? = nil,
temperature: Float = PocketTtsConstants.temperature,
deEss: Bool = true
) async throws -> PocketTtsSynthesizer.SynthesisResult {
guard isInitialized else {
throw TTSError.modelNotFound("PocketTTS model not initialized")
}
let selectedVoice = voice ?? defaultVoice
return try await PocketTtsSynthesizer.withModelStore(modelStore) {
try await PocketTtsSynthesizer.synthesize(
text: text,
voice: selectedVoice,
temperature: temperature,
deEss: deEss
)
}
}
/// Synthesize text and write the result directly to a file.
///
/// An existing file at `outputURL` is replaced only after synthesis succeeds.
public func synthesizeToFile(
text: String,
outputURL: URL,
voice: String? = nil,
temperature: Float = PocketTtsConstants.temperature,
deEss: Bool = true
) async throws {
let audioData = try await synthesize(
text: text,
voice: voice,
temperature: temperature,
deEss: deEss
)
// An atomic write replaces any existing file without leaving a partial one behind.
try audioData.write(to: outputURL, options: .atomic)
logger.notice("Saved synthesized audio to: \(outputURL.lastPathComponent)")
}
/// Update the default voice.
public func setDefaultVoice(_ voice: String) {
defaultVoice = voice
}
/// Mark the manager as uninitialized; synthesis calls throw until `initialize()` runs again.
public func cleanup() {
isInitialized = false
}
}
@@ -0,0 +1,152 @@
import Foundation
/// Minimal protobuf parser for SentencePiece `.model` files.
///
/// Extracts only the vocabulary pieces (string + score) from the
/// `ModelProto` message, ignoring trainer/normalizer specs.
///
/// Wire format reference:
/// - Tag = (field_number << 3) | wire_type
/// - Wire type 0 = varint, 1 = 64-bit fixed, 2 = length-delimited, 5 = 32-bit fixed
enum SentencePieceProto {
struct Piece: Sendable {
let piece: String
let score: Float
}
enum ParseError: Error {
case invalidData
case unexpectedEnd
case invalidUtf8
}
/// Parse a SentencePiece `.model` file and return the vocabulary pieces.
static func parse(_ data: Data) throws -> [Piece] {
var pieces: [Piece] = []
var offset = 0
let bytes = Array(data)
let count = bytes.count
while offset < count {
let (fieldNumber, wireType) = try readTag(bytes: bytes, count: count, offset: &offset)
switch wireType {
case 0:
// Varint: skip
_ = try readVarint(bytes: bytes, count: count, offset: &offset)
case 1:
// 64-bit fixed: skip 8 bytes
offset += 8
guard offset <= count else { throw ParseError.unexpectedEnd }
case 2:
// Length-delimited
let length = try readVarint(bytes: bytes, count: count, offset: &offset)
// Validate before converting so an oversized length cannot trap or overflow.
guard length <= UInt64(count - offset) else { throw ParseError.unexpectedEnd }
let end = offset + Int(length)
if fieldNumber == 1 {
// Top-level field 1 = repeated SentencePiece message
let piece = try parsePiece(bytes: bytes, start: offset, end: end)
pieces.append(piece)
}
// Skip to end of this field regardless
offset = end
case 5:
// 32-bit fixed: skip 4 bytes
offset += 4
guard offset <= count else { throw ParseError.unexpectedEnd }
default:
throw ParseError.invalidData
}
}
return pieces
}
// MARK: - Private
private static func parsePiece(bytes: [UInt8], start: Int, end: Int) throws -> Piece {
var offset = start
var piece: String?
var score: Float = 0
while offset < end {
let (fieldNumber, wireType) = try readTag(bytes: bytes, count: end, offset: &offset)
switch wireType {
case 0:
_ = try readVarint(bytes: bytes, count: end, offset: &offset)
case 1:
offset += 8
guard offset <= end else { throw ParseError.unexpectedEnd }
case 2:
let length = try readVarint(bytes: bytes, count: end, offset: &offset)
// Validate before converting so an oversized length cannot trap or overflow.
guard length <= UInt64(end - offset) else { throw ParseError.unexpectedEnd }
let fieldEnd = offset + Int(length)
if fieldNumber == 1 {
// SentencePiece.piece (string)
let slice = bytes[offset..<fieldEnd]
guard let str = String(bytes: slice, encoding: .utf8) else {
throw ParseError.invalidUtf8
}
piece = str
}
offset = fieldEnd
case 5:
if fieldNumber == 2 {
// SentencePiece.score (float)
guard offset + 4 <= end else { throw ParseError.unexpectedEnd }
score = readFloat32(bytes: bytes, offset: offset)
}
offset += 4
guard offset <= end else { throw ParseError.unexpectedEnd }
default:
throw ParseError.invalidData
}
}
return Piece(piece: piece ?? "", score: score)
}
private static func readTag(
bytes: [UInt8], count: Int, offset: inout Int
) throws -> (fieldNumber: Int, wireType: Int) {
let tag = try readVarint(bytes: bytes, count: count, offset: &offset)
let wireType = Int(tag & 0x07)
let fieldNumber = Int(tag >> 3)
return (fieldNumber, wireType)
}
private static func readVarint(
bytes: [UInt8], count: Int, offset: inout Int
) throws -> UInt64 {
var result: UInt64 = 0
var shift: UInt64 = 0
while offset < count {
let byte = bytes[offset]
offset += 1
result |= UInt64(byte & 0x7F) << shift
if byte & 0x80 == 0 {
return result
}
shift += 7
if shift >= 64 { throw ParseError.invalidData }
}
throw ParseError.unexpectedEnd
}
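/// Read a little-endian IEEE 754 float (protobuf fixed32 encoding).
private static func readFloat32(bytes: [UInt8], offset: Int) -> Float {
// Assemble the bits explicitly so the result does not depend on host byte order.
let bits =
UInt32(bytes[offset])
| UInt32(bytes[offset + 1]) << 8
| UInt32(bytes[offset + 2]) << 16
| UInt32(bytes[offset + 3]) << 24
return Float(bitPattern: bits)
}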
}
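The parser above rests on two protobuf primitives: base-128 varints and the tag layout `(field_number << 3) | wire_type`. The following standalone sketch decodes a three-byte sample to show both; `VarintError`, `decodeVarint`, and the sample bytes are hypothetical names and data chosen for illustration, not part of the PR.

```swift
enum VarintError: Error { case truncated, tooLong }

/// Decode one base-128 varint starting at `offset`, advancing `offset` past it.
func decodeVarint(_ bytes: [UInt8], offset: inout Int) throws -> UInt64 {
    var result: UInt64 = 0
    var shift: UInt64 = 0
    while offset < bytes.count {
        let byte = bytes[offset]
        offset += 1
        result |= UInt64(byte & 0x7F) << shift  // low 7 bits carry the payload
        if byte & 0x80 == 0 { return result }   // high bit clear marks the last byte
        shift += 7
        if shift >= 64 { throw VarintError.tooLong }
    }
    throw VarintError.truncated
}

// 0x0A is a tag (field 1, wire type 2); 0xAC 0x02 is the varint 300.
let sample: [UInt8] = [0x0A, 0xAC, 0x02]
var cursor = 0
let tag = try decodeVarint(sample, offset: &cursor)
let fieldNumber = Int(tag >> 3)
let wireType = Int(tag & 0x07)
let value = try decodeVarint(sample, offset: &cursor)
print(fieldNumber, wireType, value)  // 1 2 300
```

In the real file, a tag with field 1 and wire type 2 at the top level introduces one vocabulary piece, and the varint that follows is the byte length of that nested message.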
@@ -0,0 +1,120 @@
import Foundation
/// Minimal SentencePiece unigram tokenizer for PocketTTS.
///
/// Parses a `.model` protobuf to extract the vocabulary, then uses
/// Viterbi decoding to segment text into subword tokens.
public struct SentencePieceTokenizer: Sendable {
/// Vocabulary pieces with their log-probability scores.
private let pieces: [SentencePieceProto.Piece]
/// Lookup from piece string to token ID.
private let pieceToId: [String: Int]
/// Maximum piece length in Unicode scalars for early termination.
private let maxPieceLength: Int
/// The space replacement character used by SentencePiece.
private static let spaceMarker: Character = "\u{2581}"
public init(modelData: Data) throws {
let parsed = try SentencePieceProto.parse(modelData)
self.pieces = parsed
var lookup: [String: Int] = [:]
lookup.reserveCapacity(parsed.count)
var maxLen = 0
for (index, entry) in parsed.enumerated() {
lookup[entry.piece] = index
maxLen = max(maxLen, entry.piece.unicodeScalars.count)
}
self.pieceToId = lookup
self.maxPieceLength = maxLen
}
/// Tokenize text into token IDs using Viterbi unigram decoding.
///
/// Applies a simplified SentencePiece normalization: replaces spaces
/// with `\u{2581}` and prepends `\u{2581}` to the input. No Unicode
/// (NFKC) normalization is performed.
public func encode(_ text: String) -> [Int] {
guard !text.isEmpty else { return [] }
// Normalize: prepend space marker, replace spaces with marker
let normalized =
String(Self.spaceMarker)
+ text.replacingOccurrences(
of: " ", with: String(Self.spaceMarker))
return viterbiDecode(normalized)
}
// MARK: - Viterbi Decoding
/// Run Viterbi algorithm to find the highest-score segmentation.
///
/// For each position in the string, finds the best-scoring
/// vocabulary piece ending at that position.
private func viterbiDecode(_ text: String) -> [Int] {
let scalars = Array(text.unicodeScalars)
let n = scalars.count
guard n > 0 else { return [] }
// bestScore[i] = best log-probability score for text[0..<i]
// bestPiece[i] = (pieceId, startPosition) for the piece ending at i
let negInf: Float = -.infinity
var bestScore = [Float](repeating: negInf, count: n + 1)
var bestPiece = [(pieceId: Int, start: Int)](repeating: (0, 0), count: n + 1)
bestScore[0] = 0
// Work in Unicode scalar offsets so piece lengths match `maxPieceLength`.
for i in 0..<n {
guard bestScore[i] > negInf else { continue }
let maxLen = min(maxPieceLength, n - i)
// An empty vocabulary gives maxLen == 0, and the range `1...0` would trap.
guard maxLen >= 1 else { continue }
for length in 1...maxLen {
let end = i + length
// Build candidate substring from scalars
let candidate = String(String.UnicodeScalarView(scalars[i..<end]))
guard let pieceId = pieceToId[candidate] else { continue }
let piece = pieces[pieceId]
let newScore = bestScore[i] + piece.score
if newScore > bestScore[end] {
bestScore[end] = newScore
bestPiece[end] = (pieceId: pieceId, start: i)
}
}
}
// Backtrack to collect token IDs
guard bestScore[n] > negInf else {
// Fallback: encode as individual characters
return fallbackEncode(scalars)
}
var ids: [Int] = []
var pos = n
while pos > 0 {
let (pieceId, start) = bestPiece[pos]
ids.append(pieceId)
pos = start
}
ids.reverse()
return ids
}
/// Fallback: encode each Unicode scalar as a separate token.
private func fallbackEncode(_ scalars: [Unicode.Scalar]) -> [Int] {
var ids: [Int] = []
for scalar in scalars {
let char = String(scalar)
if let id = pieceToId[char] {
ids.append(id)
}
// Unknown characters are silently dropped
}
return ids
}
}
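To make the Viterbi step concrete, here is a toy unigram segmentation over a six-piece vocabulary. It is a sketch only: the vocabulary, scores, and the names `toyVocab` and `segment` are invented for illustration, and unlike the tokenizer above it returns an empty array instead of falling back to per-scalar encoding.

```swift
/// Toy vocabulary: piece -> (token ID, log-probability score). All values are made up.
let toyVocab: [String: (id: Int, score: Float)] = [
    "\u{2581}": (0, -3.0), "h": (1, -4.0), "i": (2, -4.0),
    "\u{2581}h": (3, -3.5), "hi": (4, -2.5), "\u{2581}hi": (5, -1.0),
]

/// Return the token IDs of the highest-scoring segmentation of `text`.
func segment(_ text: String) -> [Int] {
    let scalars = Array(text.unicodeScalars)
    let n = scalars.count
    guard n > 0 else { return [] }
    let maxPieceLength = toyVocab.keys.map { $0.unicodeScalars.count }.max() ?? 0

    // best[i] is the best score for the first i scalars; back[i] records the piece ending at i.
    var best = [Float](repeating: -.infinity, count: n + 1)
    var back = [(id: Int, start: Int)](repeating: (0, 0), count: n + 1)
    best[0] = 0

    for i in 0..<n where best[i] > -.infinity {
        let maxLen = min(maxPieceLength, n - i)
        guard maxLen >= 1 else { continue }
        for length in 1...maxLen {
            let end = i + length
            let candidate = String(String.UnicodeScalarView(scalars[i..<end]))
            guard let entry = toyVocab[candidate] else { continue }
            let score = best[i] + entry.score
            if score > best[end] {
                best[end] = score
                back[end] = (entry.id, i)
            }
        }
    }

    // No segmentation covers the whole input.
    guard best[n] > -.infinity else { return [] }

    // Walk the back-pointers from the end, then reverse into reading order.
    var ids: [Int] = []
    var pos = n
    while pos > 0 {
        ids.append(back[pos].id)
        pos = back[pos].start
    }
    return ids.reversed()
}

print(segment("\u{2581}hi"))   // [5]
print(segment("\u{2581}hih"))  // [5, 1]
```

The single piece `▁hi` (score -1.0) beats `▁` + `hi` (-5.5) and `▁h` + `i` (-7.5), which is why the first call yields one token rather than two.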
@@ -0,0 +1,9 @@
import Foundation
/// Available TTS synthesis backends.
public enum TtsBackend: Sendable {
/// Kokoro 82M: phoneme-based, multi-voice, chunk-oriented synthesis.
case kokoro
/// PocketTTS: flow-matching language model with autoregressive streaming synthesis.
case pocketTts
}
@@ -13,8 +13,9 @@ final class FrameworkLinkTests: XCTestCase {
// Simply importing and using FluidAudio tests that ESpeakNG is properly linked
// If the framework structure is broken (binary name wrong, symlinks broken, etc),
// this would fail with dyld errors
-let manager = TtSManager()
-XCTAssertNotNil(manager, "TtSManager should be instantiable, meaning ESpeakNG framework is properly linked")
+let manager = KokoroTtsManager()
+XCTAssertNotNil(
+manager, "KokoroTtsManager should be instantiable, meaning ESpeakNG framework is properly linked")
}
/// Test that the binary can actually be found by dyld at runtime.
@@ -23,7 +24,7 @@ final class FrameworkLinkTests: XCTestCase {
/// - Symlink chain is broken
/// - Framework not embedded correctly
func testFrameworkBinaryResolution() async throws {
-let manager = TtSManager()
+let manager = KokoroTtsManager()
XCTExpectFailure("Framework usage may fail in test environment without models", strict: false)
@@ -45,7 +46,7 @@ final class FrameworkLinkTests: XCTestCase {
/// Test that TTS functionality is accessible (requires ESpeakNG framework).
/// This ensures the framework is not just linked but properly functional.
func testTTSFrameworkFunctionality() async throws {
-let manager = TtSManager()
+let manager = KokoroTtsManager()
XCTExpectFailure("TTS may fail in CI without models", strict: false)
+4 -4
@@ -4,13 +4,13 @@ import XCTest
import FluidAudioTTS
@testable import FluidAudio
-final class TtSManagerTests: XCTestCase {
+final class KokoroTtsManagerTests: XCTestCase {
-var manager: TtSManager!
+var manager: KokoroTtsManager!
override func setUp() {
super.setUp()
-manager = TtSManager()
+manager = KokoroTtsManager()
}
override func tearDown() {
@@ -133,7 +133,7 @@ final class TtSManagerTests: XCTestCase {
"Third text",
]
-// Swift 6: TtSManager is not Sendable, so use sequential synthesis
+// Swift 6: KokoroTtsManager is not Sendable, so use sequential synthesis
var results: [Data] = []
for text in texts {
let result = try await manager.synthesize(text: text)