mirror of
https://github.com/FluidInference/FluidAudio.git
synced 2026-05-12 20:20:36 +00:00
feat: add PocketTTS backend for lightweight text-to-speech (#273)
## Summary - Add PocketTTS as a new TTS backend — flow-matching language model with autoregressive streaming synthesis - Pure Swift implementation using 4 CoreML models (cond_step, flowlm_step, flow_decoder, mimi_decoder) - iOS 17 compatible — no `scaled_dot_product_attention` ops (avoids BNNS crash) - Add audio post-processor with de-esser for reducing sibilant harshness ## Test plan - [x] Short sentence: WER 0, 3.44s audio - [x] Long sentence: WER 0, 6.64s audio - [x] Fresh HuggingFace download works end-to-end - [x] iOS build succeeds (`xcodebuild -destination 'generic/platform=iOS'`) - [x] macOS build succeeds (`swift build -c release`)
This commit is contained in:
@@ -43,7 +43,8 @@ TDT models process audio in chunks (~15s with overlap) as batch operations. Fast
|
||||
|
||||
| Model | Description | Context |
|
||||
|-------|-------------|---------|
|
||||
| **Kokoro TTS** | Text-to-speech synthesis (82M params), 48 voices, minimal RAM usage on iOS. | First TTS backend added. |
|
||||
| **Kokoro TTS** | Text-to-speech synthesis (82M params), 48 voices, minimal RAM usage on iOS. Generates all frames at once via flow matching over mel spectrograms + Vocos vocoder. Requires espeak for phonemization. | First TTS backend added. |
|
||||
| **PocketTTS** | Second TTS backend (~155M params). Upgrade over Kokoro with much better dynamic audio chunking. No espeak dependency. | |
|
||||
|
||||
## Model Sources
|
||||
|
||||
@@ -58,3 +59,4 @@ TDT models process audio in chunks (~15s with overlap) as batch operations. Fast
|
||||
| Diarization (Pyannote) | [FluidInference/speaker-diarization-coreml](https://huggingface.co/FluidInference/speaker-diarization-coreml) |
|
||||
| Sortformer | [FluidInference/diar-streaming-sortformer-coreml](https://huggingface.co/FluidInference/diar-streaming-sortformer-coreml) |
|
||||
| Kokoro TTS | [FluidInference/kokoro-82m-coreml](https://huggingface.co/FluidInference/kokoro-82m-coreml) |
|
||||
| PocketTTS | [FluidInference/pocket-tts-coreml](https://huggingface.co/FluidInference/pocket-tts-coreml) |
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
# Kokoro: High-Quality Text-to-Speech
|
||||
|
||||
## Overview
|
||||
|
||||
Kokoro is a high-quality, English-only TTS backend. It generates the entire audio representation in one pass (all frames at once) using flow matching over mel spectrograms, then converts to audio with the Vocos vocoder.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
swift run fluidaudio tts "Welcome to FluidAudio text to speech" \
|
||||
--output ~/Desktop/demo.wav \
|
||||
--voice af_heart
|
||||
```
|
||||
|
||||
The first invocation downloads Kokoro models, phoneme dictionaries, and voice embeddings; later runs reuse the cached assets.
|
||||
|
||||
### Swift
|
||||
|
||||
```swift
|
||||
import FluidAudioTTS
|
||||
|
||||
let manager = TtSManager()
|
||||
try await manager.initialize()
|
||||
|
||||
let audioData = try await manager.synthesize(text: "Hello from FluidAudio!")
|
||||
|
||||
let outputURL = URL(fileURLWithPath: "/tmp/demo.wav")
|
||||
try audioData.write(to: outputURL)
|
||||
```
|
||||
|
||||
Swap in `manager.initialize(models:)` when you want to preload only the long-form `.fifteenSecond` variant.
|
||||
|
||||
## Inspecting Chunk Metadata
|
||||
|
||||
```swift
|
||||
let manager = TtSManager()
|
||||
try await manager.initialize()
|
||||
|
||||
let detailed = try await manager.synthesizeDetailed(
|
||||
text: "FluidAudio can report chunk splits for you.",
|
||||
variantPreference: .fifteenSecond
|
||||
)
|
||||
|
||||
for chunk in detailed.chunks {
|
||||
print("Chunk #\(chunk.index) -> variant: \(chunk.variant), tokens: \(chunk.tokenCount)")
|
||||
print(" text: \(chunk.text)")
|
||||
}
|
||||
```
|
||||
|
||||
`KokoroSynthesizer.SynthesisResult` also exposes `diagnostics` for per-run variant and audio footprint totals.
|
||||
|
||||
## SSML Support
|
||||
|
||||
Kokoro supports a subset of SSML tags for controlling pronunciation. See [SSML.md](SSML.md) for details.
|
||||
|
||||
## How It Differs From PocketTTS
|
||||
|
||||
| | Kokoro | PocketTTS |
|
||||
|---|---|---|
|
||||
| Text input | Phonemes (IPA via espeak) | Raw text (SentencePiece) |
|
||||
| Voice conditioning | Style embedding vector | 125 audio prompt tokens |
|
||||
| Generation | All frames at once | Frame-by-frame autoregressive |
|
||||
| Flow matching target | Mel spectrogram | 32-dim latent per frame |
|
||||
| Audio synthesis | Vocos vocoder | Mimi streaming codec |
|
||||
| Latency to first audio | Must wait for full generation | ~80ms after prefill |
|
||||
|
||||
Kokoro parallelizes across time (fast total, but must wait for everything). PocketTTS is sequential across time (slower total, but audio starts immediately).
|
||||
|
||||
## Enable TTS in Your Project
|
||||
|
||||
### App/Library Development (Xcode & SwiftPM)
|
||||
|
||||
When adding FluidAudio to your Xcode project or Package.swift, select the **`FluidAudioWithTTS`** product:
|
||||
|
||||
**Xcode:**
|
||||
1. File > Add Package Dependencies
|
||||
2. Enter FluidAudio repository URL
|
||||
3. Choose **`FluidAudioWithTTS`**
|
||||
4. Add it to your app target
|
||||
|
||||
**Package.swift:**
|
||||
```swift
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/FluidInference/FluidAudio.git", from: "0.7.7"),
|
||||
],
|
||||
targets: [
|
||||
.target(
|
||||
name: "YourTarget",
|
||||
dependencies: [
|
||||
.product(name: "FluidAudioWithTTS", package: "FluidAudio")
|
||||
]
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
**Import in your code:**
|
||||
```swift
|
||||
import FluidAudio // Core functionality (ASR, diarization, VAD)
|
||||
import FluidAudioTTS // TTS features
|
||||
```
|
||||
|
||||
### CLI Development
|
||||
|
||||
TTS support is enabled by default in the CLI:
|
||||
|
||||
```bash
|
||||
swift run fluidaudio tts "Welcome to FluidAudio" --output ~/Desktop/demo.wav
|
||||
```
|
||||
@@ -0,0 +1,112 @@
|
||||
# PocketTTS Swift Inference
|
||||
|
||||
How the Swift code generates speech from text.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Role |
|
||||
|------|------|
|
||||
| `PocketTtsManager.swift` | Public API — `initialize()`, `synthesize()`, `synthesizeToFile()` |
|
||||
| `PocketTtsModelStore.swift` | Loads and stores the 4 CoreML models + constants + voice data |
|
||||
| `PocketTtsSynthesizer.swift` | Main synthesis loop — chunking, prefill, generation, output |
|
||||
| `PocketTtsSynthesizer+KVCache.swift` | KV cache state, `prefillKVCache()`, `runCondStep()`, `runFlowLMStep()` |
|
||||
| `PocketTtsSynthesizer+Flow.swift` | Flow decoder loop, `denormalize()`, `quantize()`, SeededRNG |
|
||||
| `PocketTtsSynthesizer+Mimi.swift` | Mimi decoder state, `runMimiDecoder()`, `loadMimiInitialState()` |
|
||||
| `PocketTtsConstantsLoader.swift` | Loads binary constants (embeddings, tokenizer, quantizer weights) |
|
||||
| `PocketTtsConstants.swift` | All numeric constants (dimensions, thresholds, etc.) |
|
||||
|
||||
## Call Flow
|
||||
|
||||
```
|
||||
PocketTtsManager.synthesize(text:)
|
||||
|
|
||||
v
|
||||
PocketTtsSynthesizer.synthesize(text:voice:temperature:)
|
||||
|
|
||||
|-- chunkText() split text into <=50 token chunks
|
||||
|-- loadMimiInitialState() load 23 streaming state tensors from disk
|
||||
|
|
||||
|-- FOR EACH CHUNK:
|
||||
| |
|
||||
| |-- tokenizer.encode() SentencePiece text → token IDs
|
||||
| |-- embedTokens() table lookup: token ID → [1024] vector
|
||||
| |-- prefillKVCache() feed 125 voice + N text tokens through cond_step
|
||||
| | |
|
||||
| | |-- emptyKVCacheState() fresh cache (6 layers × [2,1,512,16,64])
|
||||
| | |-- runCondStep() × ~141 one token per call, updates cache
|
||||
| |
|
||||
| |-- GENERATE LOOP (until EOS or max frames):
|
||||
| | |
|
||||
| | |-- runFlowLMStep() → transformer_out [1,1024] + eos_logit
|
||||
| | |-- flowDecode() → 32-dim latent
|
||||
| | | |-- randn(32) * sqrt(temperature)
|
||||
| | | |-- runFlowDecoderStep() × 8 Euler steps
|
||||
| | | |-- latent += velocity * dt each step
|
||||
| | |
|
||||
| | |-- denormalize() latent * std + mean
|
||||
| | |-- quantize() matmul [32] × [32,512] → [512]
|
||||
| | |-- runMimiDecoder() [512] → 1920 audio samples
|
||||
| | | updates 23 streaming state tensors
|
||||
| | |
|
||||
| | |-- createSequenceFromLatent() feed latent back for next frame
|
||||
|
|
||||
|-- concatenate all frames
|
||||
|-- applyTtsPostProcessing() (optional de-essing)
|
||||
|-- AudioWAV.data() wrap in WAV header (24kHz mono)
|
||||
```
|
||||
|
||||
## Key State
|
||||
|
||||
### KV Cache (`KVCacheState`)
|
||||
- 6 cache tensors `[2, 1, 512, 16, 64]` + 6 position counters
|
||||
- Written during prefill (voice + text tokens)
|
||||
- Read and extended during generation (one position per frame)
|
||||
- **Reset per chunk** — each chunk gets a fresh cache
|
||||
|
||||
### Mimi State (`MimiState`)
|
||||
- 23 tensors: convolution history, attention caches, overlap-add buffers
|
||||
- Loaded once from `mimi_init_state/*.bin` files via `manifest.json`
|
||||
- Updated after every `runMimiDecoder()` call — outputs feed back as next input
|
||||
- **Continuous across chunks** — never reset, keeps audio seamless
|
||||
|
||||
## Text Chunking
|
||||
|
||||
Long text is split into chunks of <=50 tokens to fit the KV cache (512 positions, minus ~125 voice + ~25 overhead).
|
||||
|
||||
Splitting priority:
|
||||
1. Sentence boundaries (`.!?`)
|
||||
2. Clause boundaries (`,;:`)
|
||||
3. Word boundaries (fallback)
|
||||
|
||||
`normalizeText()` also capitalizes, adds terminal punctuation, and pads short text with leading spaces for better prosody.
|
||||
|
||||
## EOS Detection
|
||||
|
||||
`runFlowLMStep()` returns an `eos_logit`. When it exceeds `-4.0`, the code generates a few extra frames (3 for short text, 1 for long) then stops.
|
||||
|
||||
## CoreML Details
|
||||
|
||||
- All 4 models loaded with `.cpuAndGPU` compute units (ANE float16 causes artifacts in Mimi state feedback)
|
||||
- Models compiled from `.mlpackage` → `.mlmodelc` on first load, cached on disk
|
||||
- `PocketTtsModelStore` is an actor — thread-safe access to loaded models
|
||||
- Voice data cached per voice name to avoid reloading
|
||||
|
||||
## Usage
|
||||
|
||||
```swift
|
||||
import FluidAudioTTS
|
||||
|
||||
let manager = PocketTtsManager()
|
||||
try await manager.initialize()
|
||||
|
||||
let audioData = try await manager.synthesize(text: "Hello, world!")
|
||||
|
||||
try await manager.synthesizeToFile(
|
||||
text: "Hello, world!",
|
||||
outputURL: URL(fileURLWithPath: "/tmp/output.wav")
|
||||
)
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
CC-BY-4.0, inherited from [kyutai/pocket-tts](https://huggingface.co/kyutai/pocket-tts).
|
||||
@@ -1,108 +0,0 @@
|
||||
# Text-To-Speech (TTS) Code Examples
|
||||
|
||||
> **⚠️ Beta:** The TTS system is currently in beta and only supports American English. Additional language support is planned for future releases.
|
||||
|
||||
Quick recipes for running the Kokoro synthesis stack.
|
||||
|
||||
## Enable TTS in Your Project
|
||||
|
||||
### For App/Library Development (Xcode & SwiftPM)
|
||||
|
||||
When adding FluidAudio to your Xcode project or Package.swift, select the **`FluidAudioWithTTS`** product to include text-to-speech capabilities:
|
||||
|
||||
**Xcode:**
|
||||
1. File → Add Package Dependencies
|
||||
2. Enter FluidAudio repository URL
|
||||
3. In the package product selection dialog, choose **`FluidAudioWithTTS`**
|
||||
4. Add it to your app target
|
||||
|
||||
**Package.swift:**
|
||||
```swift
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/FluidInference/FluidAudio.git", from: "0.7.7"),
|
||||
],
|
||||
targets: [
|
||||
.target(
|
||||
name: "YourTarget",
|
||||
dependencies: [
|
||||
.product(name: "FluidAudioWithTTS", package: "FluidAudio")
|
||||
]
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
**Import in your code:**
|
||||
```swift
|
||||
import FluidAudio // Core functionality (ASR, diarization, VAD)
|
||||
import FluidAudioTTS // TTS features
|
||||
```
|
||||
|
||||
### For CLI Development
|
||||
|
||||
When developing or running the FluidAudio CLI, TTS support is enabled by default.
|
||||
|
||||
**Terminal:**
|
||||
```bash
|
||||
swift run fluidaudio tts "Welcome to FluidAudio" --output ~/Desktop/demo.wav
|
||||
|
||||
# Or explicitly build/test the CLI with TTS
|
||||
swift build
|
||||
swift test
|
||||
```
|
||||
|
||||
## CLI quick start
|
||||
|
||||
```bash
|
||||
swift run fluidaudio tts "Welcome to FluidAudio text to speech" \
|
||||
--output ~/Desktop/demo.wav \
|
||||
--voice af_heart
|
||||
```
|
||||
|
||||
The first invocation downloads Kokoro models, phoneme dictionaries, and voice embeddings; later runs reuse the
|
||||
cached assets.
|
||||
|
||||
## Swift async usage
|
||||
|
||||
```swift
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
@main
|
||||
struct DemoTTS {
|
||||
static func main() async {
|
||||
let manager = TtSManager()
|
||||
|
||||
do {
|
||||
try await manager.initialize()
|
||||
let audioData = try await manager.synthesize(text: "Hello from FluidAudio!")
|
||||
|
||||
let outputURL = URL(fileURLWithPath: "/tmp/fluidaudio-demo.wav")
|
||||
try audioData.write(to: outputURL)
|
||||
print("Saved synthesized audio to: \(outputURL.path)")
|
||||
} catch {
|
||||
print("Synthesis failed: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Swap in `manager.initialize(models:)` when you want to preload only the long-form `.fifteenSecond` variant.
|
||||
|
||||
## Inspecting chunk metadata
|
||||
|
||||
```swift
|
||||
let manager = TtSManager()
|
||||
try await manager.initialize()
|
||||
|
||||
let detailed = try await manager.synthesizeDetailed(
|
||||
text: "FluidAudio can report chunk splits for you.",
|
||||
variantPreference: .fifteenSecond
|
||||
)
|
||||
|
||||
for chunk in detailed.chunks {
|
||||
print("Chunk #\(chunk.index) -> variant: \(chunk.variant), tokens: \(chunk.tokenCount)")
|
||||
print(" text: \(chunk.text)")
|
||||
}
|
||||
```
|
||||
|
||||
`KokoroSynthesizer.SynthesisResult` also exposes `diagnostics` for per-run variant and audio footprint totals.
|
||||
@@ -201,7 +201,7 @@ public class DownloadUtils {
|
||||
}
|
||||
|
||||
// Get all files recursively using HuggingFace API
|
||||
var filesToDownload: [String] = []
|
||||
var filesToDownload: [(path: String, size: Int)] = []
|
||||
|
||||
func listDirectory(path: String) async throws {
|
||||
let apiPath = path.isEmpty ? "tree/main" : "tree/main/\(path)"
|
||||
@@ -256,7 +256,8 @@ public class DownloadUtils {
|
||||
|| itemPath.hasSuffix(".json") || itemPath.hasSuffix(".txt")
|
||||
}
|
||||
if shouldInclude {
|
||||
filesToDownload.append(itemPath)
|
||||
let fileSize = item["size"] as? Int ?? -1
|
||||
filesToDownload.append((path: itemPath, size: fileSize))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -267,11 +268,11 @@ public class DownloadUtils {
|
||||
logger.info("Found \(filesToDownload.count) files to download")
|
||||
|
||||
// Download each file
|
||||
for (index, filePath) in filesToDownload.enumerated() {
|
||||
for (index, file) in filesToDownload.enumerated() {
|
||||
// Strip subPath prefix when saving locally
|
||||
var localPath = filePath
|
||||
if let sub = subPath, filePath.hasPrefix("\(sub)/") {
|
||||
localPath = String(filePath.dropFirst(sub.count + 1))
|
||||
var localPath = file.path
|
||||
if let sub = subPath, file.path.hasPrefix("\(sub)/") {
|
||||
localPath = String(file.path.dropFirst(sub.count + 1))
|
||||
}
|
||||
let destPath = repoPath.appendingPathComponent(localPath)
|
||||
|
||||
@@ -286,8 +287,15 @@ public class DownloadUtils {
|
||||
withIntermediateDirectories: true
|
||||
)
|
||||
|
||||
// HuggingFace returns 500 for 0-byte files — create empty file locally
|
||||
if file.size == 0 {
|
||||
FileManager.default.createFile(atPath: destPath.path, contents: Data())
|
||||
continue
|
||||
}
|
||||
|
||||
// Download file (use original path for HuggingFace URL)
|
||||
let encodedFilePath = filePath.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? filePath
|
||||
let encodedFilePath =
|
||||
file.path.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? file.path
|
||||
let fileURL = try ModelRegistry.resolveModel(repo.remotePath, encodedFilePath)
|
||||
let request = authorizedRequest(url: fileURL)
|
||||
|
||||
@@ -300,12 +308,12 @@ public class DownloadUtils {
|
||||
if httpResponse.statusCode == 429 || httpResponse.statusCode == 503 {
|
||||
throw HuggingFaceDownloadError.rateLimited(
|
||||
statusCode: httpResponse.statusCode,
|
||||
message: "Rate limited while downloading \(filePath)")
|
||||
message: "Rate limited while downloading \(file.path)")
|
||||
}
|
||||
|
||||
guard (200..<300).contains(httpResponse.statusCode) else {
|
||||
throw HuggingFaceDownloadError.downloadFailed(
|
||||
path: filePath,
|
||||
path: file.path,
|
||||
underlying: NSError(domain: "HTTP", code: httpResponse.statusCode)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ public enum Repo: String, CaseIterable {
|
||||
case diarizer = "FluidInference/speaker-diarization-coreml"
|
||||
case kokoro = "FluidInference/kokoro-82m-coreml"
|
||||
case sortformer = "FluidInference/diar-streaming-sortformer-coreml"
|
||||
case pocketTts = "FluidInference/pocket-tts-coreml"
|
||||
|
||||
/// Repository slug (without owner)
|
||||
public var name: String {
|
||||
@@ -36,6 +37,8 @@ public enum Repo: String, CaseIterable {
|
||||
return "kokoro-82m-coreml"
|
||||
case .sortformer:
|
||||
return "diar-streaming-sortformer-coreml"
|
||||
case .pocketTts:
|
||||
return "pocket-tts-coreml"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +81,8 @@ public enum Repo: String, CaseIterable {
|
||||
return "parakeet-eou-streaming/320ms"
|
||||
case .sortformer:
|
||||
return "sortformer"
|
||||
case .pocketTts:
|
||||
return "pocket-tts"
|
||||
default:
|
||||
return name
|
||||
}
|
||||
@@ -263,6 +268,30 @@ public enum ModelNames {
|
||||
}
|
||||
}
|
||||
|
||||
/// PocketTTS model names (flow-matching language model TTS)
|
||||
public enum PocketTTS {
|
||||
public static let condStep = "cond_step"
|
||||
public static let flowlmStep = "flowlm_step"
|
||||
public static let flowDecoder = "flow_decoder"
|
||||
public static let mimiDecoder = "mimi_decoder_v2"
|
||||
|
||||
public static let condStepFile = condStep + ".mlmodelc"
|
||||
public static let flowlmStepFile = flowlmStep + ".mlmodelc"
|
||||
public static let flowDecoderFile = flowDecoder + ".mlmodelc"
|
||||
public static let mimiDecoderFile = mimiDecoder + ".mlmodelc"
|
||||
|
||||
/// Directory containing binary constants, tokenizer, and voice data.
|
||||
public static let constantsBinDir = "constants_bin"
|
||||
|
||||
public static let requiredModels: Set<String> = [
|
||||
condStepFile,
|
||||
flowlmStepFile,
|
||||
flowDecoderFile,
|
||||
mimiDecoderFile,
|
||||
constantsBinDir,
|
||||
]
|
||||
}
|
||||
|
||||
/// TTS model names
|
||||
public enum TTS {
|
||||
|
||||
@@ -328,6 +357,8 @@ public enum ModelNames {
|
||||
return ModelNames.Diarizer.requiredModels
|
||||
case .kokoro:
|
||||
return ModelNames.TTS.requiredModels
|
||||
case .pocketTts:
|
||||
return ModelNames.PocketTTS.requiredModels
|
||||
case .sortformer:
|
||||
return ModelNames.Sortformer.requiredModels
|
||||
}
|
||||
|
||||
@@ -134,6 +134,7 @@ public struct TTS {
|
||||
var text: String? = nil
|
||||
var benchmarkMode = false
|
||||
var deEss = true
|
||||
var backend: TtsBackend = .kokoro
|
||||
|
||||
var i = 0
|
||||
while i < arguments.count {
|
||||
@@ -180,6 +181,19 @@ public struct TTS {
|
||||
lexiconPath = arguments[i + 1]
|
||||
i += 1
|
||||
}
|
||||
case "--backend":
|
||||
if i + 1 < arguments.count {
|
||||
let value = arguments[i + 1].lowercased()
|
||||
switch value {
|
||||
case "kokoro":
|
||||
backend = .kokoro
|
||||
case "pocket", "pockettts":
|
||||
backend = .pocketTts
|
||||
default:
|
||||
logger.warning("Unknown backend '\(arguments[i + 1])'; using kokoro")
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
case "--auto-download":
|
||||
// No-op: downloads are always ensured by the CLI
|
||||
()
|
||||
@@ -214,12 +228,19 @@ public struct TTS {
|
||||
return
|
||||
}
|
||||
|
||||
if backend == .pocketTts {
|
||||
await runPocketTts(
|
||||
text: text, output: output, voice: voice, deEss: deEss,
|
||||
metricsPath: metricsPath)
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
// Timing buckets
|
||||
let tStart = Date()
|
||||
|
||||
let customLexicon = try loadCustomLexicon(from: lexiconPath)
|
||||
let manager = TtSManager(customLexicon: customLexicon)
|
||||
let manager = KokoroTtsManager(customLexicon: customLexicon)
|
||||
let requestedVoice = voice.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let voiceOverride = requestedVoice.isEmpty ? nil : requestedVoice
|
||||
let preloadVoices = voiceOverride.map { Set([$0]) }
|
||||
@@ -449,6 +470,121 @@ public struct TTS {
|
||||
}
|
||||
}
|
||||
|
||||
private static func runPocketTts(
|
||||
text: String, output: String, voice: String, deEss: Bool,
|
||||
metricsPath: String?
|
||||
) async {
|
||||
do {
|
||||
let tStart = Date()
|
||||
let pocketVoice =
|
||||
voice == TtsConstants.recommendedVoice
|
||||
? PocketTtsConstants.defaultVoice : voice
|
||||
let manager = PocketTtsManager(defaultVoice: pocketVoice)
|
||||
|
||||
let tLoad0 = Date()
|
||||
try await manager.initialize()
|
||||
let tLoad1 = Date()
|
||||
|
||||
let tSynth0 = Date()
|
||||
let wav = try await manager.synthesize(
|
||||
text: text, voice: pocketVoice, deEss: deEss)
|
||||
let tSynth1 = Date()
|
||||
|
||||
let outURL = {
|
||||
let expanded = (output as NSString).expandingTildeInPath
|
||||
if expanded.hasPrefix("/") {
|
||||
return URL(fileURLWithPath: expanded)
|
||||
}
|
||||
let cwd = URL(
|
||||
fileURLWithPath: FileManager.default.currentDirectoryPath,
|
||||
isDirectory: true)
|
||||
return cwd.appendingPathComponent(expanded)
|
||||
}()
|
||||
try FileManager.default.createDirectory(
|
||||
at: outURL.deletingLastPathComponent(),
|
||||
withIntermediateDirectories: true)
|
||||
try wav.write(to: outURL)
|
||||
|
||||
let loadS = tLoad1.timeIntervalSince(tLoad0)
|
||||
let synthS = tSynth1.timeIntervalSince(tSynth0)
|
||||
let totalS = tSynth1.timeIntervalSince(tStart)
|
||||
let sampleRate = Double(PocketTtsConstants.audioSampleRate)
|
||||
let payload = max(0, wav.count - 44)
|
||||
let audioSecs = Double(payload) / (sampleRate * 2.0)
|
||||
let rtfx = synthS > 0 ? audioSecs / synthS : 0
|
||||
|
||||
logger.info("PocketTTS synthesis complete")
|
||||
logger.info(" Load: \(String(format: "%.3f", loadS))s")
|
||||
logger.info(" Synthesis: \(String(format: "%.3f", synthS))s")
|
||||
logger.info(" Audio: \(String(format: "%.3f", audioSecs))s")
|
||||
logger.info(" RTFx: \(String(format: "%.2f", rtfx))x")
|
||||
logger.info(" Total: \(String(format: "%.3f", totalS))s")
|
||||
logger.info(" Output: \(outURL.path)")
|
||||
|
||||
// ASR round-trip evaluation
|
||||
if metricsPath != nil {
|
||||
logger.info("--- Running ASR for TTS→STT evaluation ---")
|
||||
var asrHypothesis: String? = nil
|
||||
var werValue: Double? = nil
|
||||
|
||||
do {
|
||||
let asrModels = try await AsrModels.downloadAndLoad()
|
||||
let asr = AsrManager()
|
||||
try await asr.initialize(models: asrModels)
|
||||
|
||||
let transcription = try await asr.transcribe(outURL)
|
||||
asrHypothesis = transcription.text
|
||||
|
||||
let werMetrics = WERCalculator.calculateWERMetrics(
|
||||
hypothesis: transcription.text, reference: text)
|
||||
werValue = werMetrics.wer
|
||||
|
||||
logger.info("Reference: \(text)")
|
||||
logger.info("Hypothesis: \(transcription.text)")
|
||||
logger.info(String(format: "WER: %.1f%%", werValue! * 100))
|
||||
|
||||
asr.cleanup()
|
||||
} catch {
|
||||
logger.warning("ASR evaluation failed: \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
if let metricsPath {
|
||||
var metricsDict: [String: Any] = [
|
||||
"backend": "pockettts",
|
||||
"text": text,
|
||||
"voice": pocketVoice,
|
||||
"output": outURL.path,
|
||||
"model_load_time_s": loadS,
|
||||
"inference_time_s": synthS,
|
||||
"audio_duration_s": audioSecs,
|
||||
"realtime_speed": rtfx,
|
||||
"total_time_s": totalS,
|
||||
]
|
||||
if let asrHypothesis {
|
||||
metricsDict["asr_hypothesis"] = asrHypothesis
|
||||
}
|
||||
if let werValue {
|
||||
metricsDict["wer"] = werValue
|
||||
}
|
||||
|
||||
let artifactsRoot = try ensureArtifactsRoot()
|
||||
let mURL = resolveOutputURL(
|
||||
metricsPath, artifactsRoot: artifactsRoot, expectsDirectory: false)
|
||||
try FileManager.default.createDirectory(
|
||||
at: mURL.deletingLastPathComponent(), withIntermediateDirectories: true)
|
||||
let json = try JSONSerialization.data(
|
||||
withJSONObject: metricsDict, options: [.prettyPrinted])
|
||||
try json.write(to: mURL)
|
||||
logger.info("Metrics saved: \(mURL.path)")
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
logger.error("PocketTTS Error: \(error)")
|
||||
print("PocketTTS failed: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
private static func printUsage() {
|
||||
print(
|
||||
"""
|
||||
@@ -456,8 +592,9 @@ public struct TTS {
|
||||
|
||||
Options:
|
||||
--output, -o Output WAV path (default: output.wav)
|
||||
--voice, -v Voice name (default: af_heart)
|
||||
--lexicon, -l Custom pronunciation lexicon file (word=phonemes format)
|
||||
--voice, -v Voice name (default: af_heart for Kokoro, alba for PocketTTS)
|
||||
--backend TTS backend: kokoro (default) or pocket
|
||||
--lexicon, -l Custom pronunciation lexicon file (word=phonemes format, Kokoro only)
|
||||
--benchmark Run a predefined benchmarking suite with multiple sentences
|
||||
--variant Force Kokoro 5s or 15s model (values: 5s,15s)
|
||||
--metrics Write timing metrics to a JSON file (also runs ASR for evaluation)
|
||||
@@ -496,7 +633,7 @@ extension TTS {
|
||||
) async {
|
||||
do {
|
||||
let customLexicon = try loadCustomLexicon(from: lexiconPath)
|
||||
let manager = TtSManager(customLexicon: customLexicon)
|
||||
let manager = KokoroTtsManager(customLexicon: customLexicon)
|
||||
let requestedVoice = voice.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
let normalizedVoice = requestedVoice.isEmpty ? nil : requestedVoice
|
||||
let preloadVoices = normalizedVoice.map { Set([$0]) }
|
||||
|
||||
+2
-2
@@ -55,7 +55,7 @@ public struct KokoroSynthesizer {
|
||||
static func currentModelCache() throws -> KokoroModelCache {
|
||||
guard let cache = Context.modelCache else {
|
||||
throw TTSError.processingFailed(
|
||||
"KokoroSynthesizer requires a model cache context. Use TtSManager or withModelCache(_:operation:)."
|
||||
"KokoroSynthesizer requires a model cache context. Use KokoroTtsManager or withModelCache(_:operation:)."
|
||||
)
|
||||
}
|
||||
return cache
|
||||
@@ -64,7 +64,7 @@ public struct KokoroSynthesizer {
|
||||
static func currentLexiconAssets() throws -> LexiconAssetManager {
|
||||
guard let assets = Context.lexiconAssets else {
|
||||
throw TTSError.processingFailed(
|
||||
"KokoroSynthesizer requires lexicon assets context. Use TtSManager or withLexiconAssets(_:operation:)."
|
||||
"KokoroSynthesizer requires lexicon assets context. Use KokoroTtsManager or withLexiconAssets(_:operation:)."
|
||||
)
|
||||
}
|
||||
return assets
|
||||
|
||||
+5
-5
@@ -2,20 +2,20 @@ import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// Manages text-to-speech synthesis using the Kokoro CoreML model.
|
||||
/// Manages text-to-speech synthesis using Kokoro CoreML models.
|
||||
///
|
||||
/// - Note: **Beta:** The TTS system is currently in beta and only supports American English.
|
||||
/// Additional language support is planned for future releases.
|
||||
///
|
||||
/// Example usage:
|
||||
/// ```swift
|
||||
/// let manager = TtSManager()
|
||||
/// let manager = KokoroTtsManager()
|
||||
/// try await manager.initialize()
|
||||
/// let audioData = try await manager.synthesize(text: "Hello, world!")
|
||||
/// ```
|
||||
public final class TtSManager {
|
||||
public final class KokoroTtsManager {
|
||||
|
||||
private let logger = AppLogger(category: "TtSManager")
|
||||
private let logger = AppLogger(category: "KokoroTtsManager")
|
||||
private let modelCache: KokoroModelCache
|
||||
private let lexiconAssets: LexiconAssetManager
|
||||
|
||||
@@ -80,7 +80,7 @@ public final class TtSManager {
|
||||
try await KokoroSynthesizer.loadSimplePhonemeDictionary()
|
||||
try await modelCache.loadModelsIfNeeded(variants: models.availableVariants)
|
||||
isInitialized = true
|
||||
logger.notice("TtSManager initialized with provided models")
|
||||
logger.notice("KokoroTtsManager initialized with provided models")
|
||||
}
|
||||
|
||||
public func initialize(preloadVoices: Set<String>? = nil) async throws {
|
||||
@@ -0,0 +1,117 @@
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// Pre-loaded binary constants for PocketTTS inference.
|
||||
public struct PocketTtsConstantsBundle: Sendable {
|
||||
public let bosEmbedding: [Float]
|
||||
public let textEmbedTable: [Float]
|
||||
public let tokenizer: SentencePieceTokenizer
|
||||
}
|
||||
|
||||
/// Pre-loaded voice conditioning data.
|
||||
public struct PocketTtsVoiceData: Sendable {
|
||||
/// Flattened audio prompt: [1, promptLength, 1024]
|
||||
public let audioPrompt: [Float]
|
||||
/// Number of voice conditioning tokens (typically 125).
|
||||
public let promptLength: Int
|
||||
}
|
||||
|
||||
/// Loads PocketTTS constants from raw `.bin` Float32 files on disk.
|
||||
public enum PocketTtsConstantsLoader {
|
||||
|
||||
private static let logger = AppLogger(category: "PocketTtsConstantsLoader")
|
||||
|
||||
public enum LoadError: Error {
|
||||
case fileNotFound(String)
|
||||
case invalidSize(String, expected: Int, actual: Int)
|
||||
case tokenizerLoadFailed(String)
|
||||
}
|
||||
|
||||
/// Load all constants from the given directory.
|
||||
public static func load(from directory: URL) throws -> PocketTtsConstantsBundle {
|
||||
let constantsDir = directory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
|
||||
|
||||
let bosEmb = try loadFloatArray(
|
||||
from: constantsDir.appendingPathComponent("bos_emb.bin"),
|
||||
expectedCount: PocketTtsConstants.latentDim,
|
||||
name: "bos_emb"
|
||||
)
|
||||
let embedTable = try loadFloatArray(
|
||||
from: constantsDir.appendingPathComponent("text_embed_table.bin"),
|
||||
expectedCount: PocketTtsConstants.vocabSize * PocketTtsConstants.embeddingDim,
|
||||
name: "text_embed_table"
|
||||
)
|
||||
|
||||
let tokenizerURL = constantsDir.appendingPathComponent("tokenizer.model")
|
||||
guard FileManager.default.fileExists(atPath: tokenizerURL.path) else {
|
||||
throw LoadError.fileNotFound("tokenizer.model")
|
||||
}
|
||||
let tokenizerData = try Data(contentsOf: tokenizerURL)
|
||||
let tokenizer: SentencePieceTokenizer
|
||||
do {
|
||||
tokenizer = try SentencePieceTokenizer(modelData: tokenizerData)
|
||||
} catch {
|
||||
throw LoadError.tokenizerLoadFailed(error.localizedDescription)
|
||||
}
|
||||
|
||||
logger.info("Loaded PocketTTS constants from \(directory.lastPathComponent)")
|
||||
|
||||
return PocketTtsConstantsBundle(
|
||||
bosEmbedding: bosEmb,
|
||||
textEmbedTable: embedTable,
|
||||
tokenizer: tokenizer
|
||||
)
|
||||
}
|
||||
|
||||
/// Load voice conditioning data from the given directory.
|
||||
///
|
||||
/// HuggingFace layout: `constants_bin/<voice>_audio_prompt.bin`
|
||||
public static func loadVoice(
|
||||
_ voice: String, from directory: URL
|
||||
) throws -> PocketTtsVoiceData {
|
||||
// Sanitize voice name to prevent path traversal
|
||||
let sanitized = voice.components(separatedBy: CharacterSet.alphanumerics.inverted).joined()
|
||||
guard !sanitized.isEmpty else {
|
||||
throw LoadError.fileNotFound("invalid voice name: \(voice)")
|
||||
}
|
||||
|
||||
let constantsDir = directory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
|
||||
|
||||
let audioPrompt = try loadFloatArray(
|
||||
from: constantsDir.appendingPathComponent("\(sanitized)_audio_prompt.bin"),
|
||||
expectedCount: PocketTtsConstants.voicePromptLength * PocketTtsConstants.embeddingDim,
|
||||
name: "\(sanitized)_audio_prompt"
|
||||
)
|
||||
|
||||
logger.info("Loaded PocketTTS voice '\(sanitized)' conditioning data")
|
||||
|
||||
return PocketTtsVoiceData(
|
||||
audioPrompt: audioPrompt,
|
||||
promptLength: PocketTtsConstants.voicePromptLength
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
/// Load a raw Float32 binary file into a [Float] array.
|
||||
private static func loadFloatArray(
|
||||
from url: URL, expectedCount: Int, name: String
|
||||
) throws -> [Float] {
|
||||
guard FileManager.default.fileExists(atPath: url.path) else {
|
||||
throw LoadError.fileNotFound(name)
|
||||
}
|
||||
|
||||
let data = try Data(contentsOf: url)
|
||||
let actualCount = data.count / MemoryLayout<Float>.size
|
||||
|
||||
guard actualCount == expectedCount else {
|
||||
throw LoadError.invalidSize(name, expected: expectedCount, actual: actualCount)
|
||||
}
|
||||
|
||||
return data.withUnsafeBytes { rawBuffer in
|
||||
let floatBuffer = rawBuffer.bindMemory(to: Float.self)
|
||||
return Array(floatBuffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// Downloads PocketTTS models and constants from HuggingFace.
|
||||
public enum PocketTtsResourceDownloader {
|
||||
|
||||
private static let logger = AppLogger(category: "PocketTtsResourceDownloader")
|
||||
|
||||
/// Ensure all PocketTTS models are downloaded and return the cache directory.
|
||||
public static func ensureModels() async throws -> URL {
|
||||
let cacheDirectory = try cacheDirectory()
|
||||
let modelsDirectory = cacheDirectory.appendingPathComponent(
|
||||
PocketTtsConstants.defaultModelsSubdirectory)
|
||||
|
||||
let repoDir = modelsDirectory.appendingPathComponent(Repo.pocketTts.folderName)
|
||||
|
||||
// Check that all required directories exist (models + constants_bin)
|
||||
let requiredModels = ModelNames.PocketTTS.requiredModels
|
||||
let allPresent = requiredModels.allSatisfy { model in
|
||||
FileManager.default.fileExists(
|
||||
atPath: repoDir.appendingPathComponent(model).path)
|
||||
}
|
||||
|
||||
if !allPresent {
|
||||
logger.info("Downloading PocketTTS models from HuggingFace...")
|
||||
try await DownloadUtils.downloadRepo(.pocketTts, to: modelsDirectory)
|
||||
} else {
|
||||
logger.info("PocketTTS models found in cache")
|
||||
}
|
||||
|
||||
return repoDir
|
||||
}
|
||||
|
||||
/// Ensure constants (binary blobs + tokenizer) are available.
|
||||
public static func ensureConstants(repoDirectory: URL) throws -> PocketTtsConstantsBundle {
|
||||
try PocketTtsConstantsLoader.load(from: repoDirectory)
|
||||
}
|
||||
|
||||
/// Ensure voice conditioning data is available.
|
||||
public static func ensureVoice(
|
||||
_ voice: String, repoDirectory: URL
|
||||
) throws -> PocketTtsVoiceData {
|
||||
try PocketTtsConstantsLoader.loadVoice(voice, from: repoDirectory)
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private static func cacheDirectory() throws -> URL {
|
||||
let baseDirectory: URL
|
||||
#if os(macOS)
|
||||
baseDirectory = FileManager.default.homeDirectoryForCurrentUser
|
||||
.appendingPathComponent(".cache")
|
||||
#else
|
||||
guard
|
||||
let first = FileManager.default.urls(
|
||||
for: .cachesDirectory, in: .userDomainMask
|
||||
).first
|
||||
else {
|
||||
throw TTSError.processingFailed("Failed to locate caches directory")
|
||||
}
|
||||
baseDirectory = first
|
||||
#endif
|
||||
|
||||
let cacheDirectory = baseDirectory.appendingPathComponent("fluidaudio")
|
||||
if !FileManager.default.fileExists(atPath: cacheDirectory.path) {
|
||||
try FileManager.default.createDirectory(
|
||||
at: cacheDirectory, withIntermediateDirectories: true)
|
||||
}
|
||||
return cacheDirectory
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
@preconcurrency import CoreML
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// Actor-based store for PocketTTS CoreML models and constants.
|
||||
///
|
||||
/// Manages loading and storing of the four CoreML models
|
||||
/// (cond_step, flowlm_step, flow_decoder, mimi_decoder),
|
||||
/// the binary constants bundle, and voice conditioning data.
|
||||
public actor PocketTtsModelStore {
|
||||
|
||||
private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "PocketTtsModelStore")
|
||||
|
||||
private var condStepModel: MLModel?
|
||||
private var flowlmStepModel: MLModel?
|
||||
private var flowDecoderModel: MLModel?
|
||||
private var mimiDecoderModel: MLModel?
|
||||
private var constantsBundle: PocketTtsConstantsBundle?
|
||||
private var voiceCache: [String: PocketTtsVoiceData] = [:]
|
||||
private var repoDirectory: URL?
|
||||
|
||||
public init() {}
|
||||
|
||||
/// Load all four CoreML models and the constants bundle.
|
||||
public func loadIfNeeded() async throws {
|
||||
guard condStepModel == nil else { return }
|
||||
|
||||
let repoDir = try await PocketTtsResourceDownloader.ensureModels()
|
||||
self.repoDirectory = repoDir
|
||||
|
||||
logger.info("Loading PocketTTS CoreML models...")
|
||||
|
||||
// Use CPU+GPU for all models to avoid ANE float16 precision loss.
|
||||
// The ANE processes in native float16, which causes audible artifacts
|
||||
// in the Mimi decoder's streaming state feedback loop and may degrade
|
||||
// quality in the other models. CPU/GPU compute in float32 matches the
|
||||
// Python reference implementation.
|
||||
let config = MLModelConfiguration()
|
||||
config.computeUnits = .cpuAndGPU
|
||||
|
||||
let loadStart = Date()
|
||||
|
||||
let modelFiles = [
|
||||
ModelNames.PocketTTS.condStepFile,
|
||||
ModelNames.PocketTTS.flowlmStepFile,
|
||||
ModelNames.PocketTTS.flowDecoderFile,
|
||||
ModelNames.PocketTTS.mimiDecoderFile,
|
||||
]
|
||||
|
||||
var loadedModels: [MLModel] = []
|
||||
for file in modelFiles {
|
||||
let modelURL = repoDir.appendingPathComponent(file)
|
||||
let model = try MLModel(contentsOf: modelURL, configuration: config)
|
||||
loadedModels.append(model)
|
||||
logger.info("Loaded \(file)")
|
||||
}
|
||||
|
||||
condStepModel = loadedModels[0]
|
||||
flowlmStepModel = loadedModels[1]
|
||||
flowDecoderModel = loadedModels[2]
|
||||
mimiDecoderModel = loadedModels[3]
|
||||
|
||||
let elapsed = Date().timeIntervalSince(loadStart)
|
||||
logger.info("All PocketTTS models loaded in \(String(format: "%.2f", elapsed))s")
|
||||
|
||||
// Load constants
|
||||
constantsBundle = try PocketTtsResourceDownloader.ensureConstants(
|
||||
repoDirectory: repoDir)
|
||||
logger.info("PocketTTS constants loaded")
|
||||
}
|
||||
|
||||
/// The conditioning step model (KV cache prefill).
|
||||
public func condStep() throws -> MLModel {
|
||||
guard let model = condStepModel else {
|
||||
throw TTSError.modelNotFound("PocketTTS cond_step model not loaded")
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
/// The autoregressive generation step model.
|
||||
public func flowlmStep() throws -> MLModel {
|
||||
guard let model = flowlmStepModel else {
|
||||
throw TTSError.modelNotFound("PocketTTS flowlm_step model not loaded")
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
/// The LSD flow decoder model.
|
||||
public func flowDecoder() throws -> MLModel {
|
||||
guard let model = flowDecoderModel else {
|
||||
throw TTSError.modelNotFound("PocketTTS flow_decoder model not loaded")
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
/// The Mimi streaming audio decoder model.
|
||||
public func mimiDecoder() throws -> MLModel {
|
||||
guard let model = mimiDecoderModel else {
|
||||
throw TTSError.modelNotFound("PocketTTS mimi_decoder model not loaded")
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
/// The pre-loaded binary constants.
|
||||
public func constants() throws -> PocketTtsConstantsBundle {
|
||||
guard let bundle = constantsBundle else {
|
||||
throw TTSError.modelNotFound("PocketTTS constants not loaded")
|
||||
}
|
||||
return bundle
|
||||
}
|
||||
|
||||
/// The repository directory containing models and constants.
|
||||
public func repoDir() throws -> URL {
|
||||
guard let dir = repoDirectory else {
|
||||
throw TTSError.modelNotFound("PocketTTS repository not loaded")
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
/// Load and cache voice conditioning data.
|
||||
public func voiceData(for voice: String) throws -> PocketTtsVoiceData {
|
||||
if let cached = voiceCache[voice] {
|
||||
return cached
|
||||
}
|
||||
guard let repoDir = repoDirectory else {
|
||||
throw TTSError.modelNotFound("PocketTTS repository not loaded")
|
||||
}
|
||||
let data = try PocketTtsResourceDownloader.ensureVoice(voice, repoDirectory: repoDir)
|
||||
voiceCache[voice] = data
|
||||
return data
|
||||
}
|
||||
}
|
||||
+157
@@ -0,0 +1,157 @@
|
||||
@preconcurrency import CoreML
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
extension PocketTtsSynthesizer {
|
||||
|
||||
/// Run the flow decoder using Euler integration (LSD steps).
|
||||
///
|
||||
/// Converts transformer output to a 32-dimensional audio latent
|
||||
/// via `numSteps` iterative denoising steps.
|
||||
static func flowDecode(
|
||||
transformerOut: MLMultiArray,
|
||||
numSteps: Int,
|
||||
temperature: Float,
|
||||
model: MLModel,
|
||||
rng: inout some RandomNumberGenerator
|
||||
) async throws -> [Float] {
|
||||
let latentDim = PocketTtsConstants.latentDim
|
||||
let dt: Float = 1.0 / Float(numSteps)
|
||||
|
||||
// Initialize latent with scaled random noise: randn * sqrt(temperature)
|
||||
var latent = [Float](repeating: 0, count: latentDim)
|
||||
let scale = sqrtf(temperature)
|
||||
for i in 0..<latentDim {
|
||||
latent[i] = Float.gaussianRandom(using: &rng) * scale
|
||||
}
|
||||
|
||||
// Flatten transformer_out from [1, 1, 1024] to [1, 1024]
|
||||
let transformerFlat = try reshapeToFlat(transformerOut, dim: PocketTtsConstants.transformerDim)
|
||||
|
||||
// Euler integration: 8 steps from t=0 to t=1
|
||||
for step in 0..<numSteps {
|
||||
let sValue = Float(step) * dt
|
||||
let tValue = Float(step + 1) * dt
|
||||
|
||||
let velocity = try await runFlowDecoderStep(
|
||||
transformerOut: transformerFlat,
|
||||
latent: latent,
|
||||
s: sValue,
|
||||
t: tValue,
|
||||
model: model
|
||||
)
|
||||
|
||||
// Euler step: latent += velocity * dt
|
||||
for i in 0..<latentDim {
|
||||
latent[i] += velocity[i] * dt
|
||||
}
|
||||
}
|
||||
|
||||
return latent
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
/// Run a single flow decoder step.
|
||||
private static func runFlowDecoderStep(
|
||||
transformerOut: MLMultiArray,
|
||||
latent: [Float],
|
||||
s: Float,
|
||||
t: Float,
|
||||
model: MLModel
|
||||
) async throws -> [Float] {
|
||||
let latentDim = PocketTtsConstants.latentDim
|
||||
|
||||
// Create latent MLMultiArray [1, 32]
|
||||
let latentArray = try MLMultiArray(
|
||||
shape: [1, NSNumber(value: latentDim)], dataType: .float32)
|
||||
let latentPtr = latentArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
|
||||
latent.withUnsafeBufferPointer { buffer in
|
||||
guard let base = buffer.baseAddress else { return }
|
||||
latentPtr.update(from: base, count: latentDim)
|
||||
}
|
||||
|
||||
// Create s and t MLMultiArrays [1, 1]
|
||||
let sArray = try MLMultiArray(shape: [1, 1], dataType: .float32)
|
||||
sArray[0] = NSNumber(value: s)
|
||||
|
||||
let tArray = try MLMultiArray(shape: [1, 1], dataType: .float32)
|
||||
tArray[0] = NSNumber(value: t)
|
||||
|
||||
let inputDict: [String: Any] = [
|
||||
"transformer_out": transformerOut,
|
||||
"latent": latentArray,
|
||||
"s": sArray,
|
||||
"t": tArray,
|
||||
]
|
||||
|
||||
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
|
||||
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
|
||||
|
||||
// Extract velocity — take the first (and likely only) output
|
||||
let outputNames = Array(output.featureNames)
|
||||
guard let velocityArray = output.featureValue(for: outputNames[0])?.multiArrayValue else {
|
||||
throw TTSError.processingFailed("Missing flow decoder velocity output")
|
||||
}
|
||||
|
||||
let velocityPtr = velocityArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
|
||||
return Array(UnsafeBufferPointer(start: velocityPtr, count: latentDim))
|
||||
}
|
||||
|
||||
/// Reshape a [1, 1, D] MLMultiArray to [1, D].
|
||||
private static func reshapeToFlat(_ array: MLMultiArray, dim: Int) throws -> MLMultiArray {
|
||||
let flat = try MLMultiArray(shape: [1, NSNumber(value: dim)], dataType: .float32)
|
||||
let srcPtr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
let dstPtr = flat.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
dstPtr.update(from: srcPtr, count: dim)
|
||||
return flat
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Seeded Random
|
||||
|
||||
/// Simple seeded random number generator (xoshiro256**).
|
||||
///
|
||||
/// Provides reproducible random sequences when a seed is set,
|
||||
/// and falls back to system entropy when unseeded.
|
||||
struct SeededRNG: RandomNumberGenerator {
|
||||
private var state: (UInt64, UInt64, UInt64, UInt64)
|
||||
|
||||
init(seed: UInt64) {
|
||||
// SplitMix64 to expand seed into 4-part state
|
||||
var s = seed
|
||||
func next() -> UInt64 {
|
||||
s &+= 0x9E37_79B9_7F4A_7C15
|
||||
var z = s
|
||||
z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
|
||||
z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
|
||||
return z ^ (z >> 31)
|
||||
}
|
||||
state = (next(), next(), next(), next())
|
||||
}
|
||||
|
||||
mutating func next() -> UInt64 {
|
||||
let result = rotl(state.1 &* 5, 7) &* 9
|
||||
let t = state.1 << 17
|
||||
state.2 ^= state.0
|
||||
state.3 ^= state.1
|
||||
state.1 ^= state.2
|
||||
state.0 ^= state.3
|
||||
state.2 ^= t
|
||||
state.3 = rotl(state.3, 45)
|
||||
return result
|
||||
}
|
||||
|
||||
private func rotl(_ x: UInt64, _ k: Int) -> UInt64 {
|
||||
(x << k) | (x >> (64 - k))
|
||||
}
|
||||
}
|
||||
|
||||
extension Float {
|
||||
/// Generate a single sample from the standard normal distribution (Box-Muller transform).
|
||||
static func gaussianRandom(using rng: inout some RandomNumberGenerator) -> Float {
|
||||
let u1 = Float.random(in: Float.leastNonzeroMagnitude...1.0, using: &rng)
|
||||
let u2 = Float.random(in: 0.0...1.0, using: &rng)
|
||||
return sqrtf(-2.0 * logf(u1)) * cosf(2.0 * .pi * u2)
|
||||
}
|
||||
}
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
@preconcurrency import CoreML
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
extension PocketTtsSynthesizer {
|
||||
|
||||
/// Mutable KV cache state passed through conditioning and generation steps.
|
||||
struct KVCacheState {
|
||||
/// 6 KV cache arrays, each [2, 1, 200, 16, 64].
|
||||
var caches: [MLMultiArray]
|
||||
/// 6 position counters, each [1].
|
||||
var positions: [MLMultiArray]
|
||||
}
|
||||
|
||||
/// Create an empty KV cache state (all zeros, positions at 0).
|
||||
static func emptyKVCacheState() throws -> KVCacheState {
|
||||
let layers = PocketTtsConstants.kvCacheLayers
|
||||
let shape: [NSNumber] = [
|
||||
2, 1, NSNumber(value: PocketTtsConstants.kvCacheMaxLen), 16, 64,
|
||||
]
|
||||
|
||||
var caches: [MLMultiArray] = []
|
||||
var positions: [MLMultiArray] = []
|
||||
caches.reserveCapacity(layers)
|
||||
positions.reserveCapacity(layers)
|
||||
|
||||
for _ in 0..<layers {
|
||||
let cache = try MLMultiArray(shape: shape, dataType: .float32)
|
||||
let cachePtr = cache.dataPointer.bindMemory(
|
||||
to: Float.self, capacity: cache.count)
|
||||
cachePtr.initialize(repeating: 0, count: cache.count)
|
||||
caches.append(cache)
|
||||
|
||||
let pos = try MLMultiArray(shape: [1], dataType: .float32)
|
||||
pos[0] = NSNumber(value: Float(0))
|
||||
positions.append(pos)
|
||||
}
|
||||
|
||||
return KVCacheState(caches: caches, positions: positions)
|
||||
}
|
||||
|
||||
/// Run the conditioning step model for a single token, updating the KV cache in place.
|
||||
static func runCondStep(
|
||||
conditioning: MLMultiArray,
|
||||
state: inout KVCacheState,
|
||||
model: MLModel
|
||||
) async throws {
|
||||
var inputDict: [String: Any] = [
|
||||
"conditioning": conditioning
|
||||
]
|
||||
|
||||
for i in 0..<PocketTtsConstants.kvCacheLayers {
|
||||
inputDict["cache\(i)"] = state.caches[i]
|
||||
inputDict["position\(i)"] = state.positions[i]
|
||||
}
|
||||
|
||||
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
|
||||
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
|
||||
|
||||
for i in 0..<PocketTtsConstants.kvCacheLayers {
|
||||
guard let newCache = output.featureValue(for: CondStepKeys.cacheKeys[i])?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed(
|
||||
"Missing cond_step cache output: \(CondStepKeys.cacheKeys[i])")
|
||||
}
|
||||
guard let newPos = output.featureValue(for: CondStepKeys.positionKeys[i])?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed(
|
||||
"Missing cond_step position output: \(CondStepKeys.positionKeys[i])")
|
||||
}
|
||||
state.caches[i] = newCache
|
||||
state.positions[i] = newPos
|
||||
}
|
||||
}
|
||||
|
||||
/// Prefill the KV cache with voice and text conditioning tokens.
|
||||
///
|
||||
/// Processes voice tokens first, then text tokens (critical ordering).
|
||||
static func prefillKVCache(
|
||||
voiceData: PocketTtsVoiceData,
|
||||
textEmbeddings: [[Float]],
|
||||
model: MLModel
|
||||
) async throws -> KVCacheState {
|
||||
var state = try emptyKVCacheState()
|
||||
let dim = PocketTtsConstants.embeddingDim
|
||||
|
||||
// Voice tokens first (positions 0..124)
|
||||
let voiceTokenCount = voiceData.promptLength
|
||||
for tokenIdx in 0..<voiceTokenCount {
|
||||
let token = try createConditioningToken(
|
||||
from: voiceData.audioPrompt,
|
||||
offset: tokenIdx * dim,
|
||||
dim: dim
|
||||
)
|
||||
try await runCondStep(conditioning: token, state: &state, model: model)
|
||||
}
|
||||
|
||||
// Text tokens next
|
||||
for embedding in textEmbeddings {
|
||||
let token = try createConditioningToken(from: embedding, offset: 0, dim: dim)
|
||||
try await runCondStep(conditioning: token, state: &state, model: model)
|
||||
}
|
||||
|
||||
let finalPos = state.positions[0][0].floatValue
|
||||
logger.info("KV cache prefilled to position \(Int(finalPos))")
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
/// Create a [1, 1, 1024] MLMultiArray from a float slice.
|
||||
private static func createConditioningToken(
|
||||
from source: [Float], offset: Int, dim: Int
|
||||
) throws -> MLMultiArray {
|
||||
let array = try MLMultiArray(
|
||||
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
|
||||
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
source.withUnsafeBufferPointer { buffer in
|
||||
guard let base = buffer.baseAddress else { return }
|
||||
ptr.update(from: base.advanced(by: offset), count: dim)
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
/// Run the generation step model, returning transformer output and EOS logit.
|
||||
static func runFlowLMStep(
|
||||
sequence: MLMultiArray,
|
||||
bosEmb: MLMultiArray,
|
||||
state: inout KVCacheState,
|
||||
model: MLModel
|
||||
) async throws -> (transformerOut: MLMultiArray, eosLogit: Float) {
|
||||
var inputDict: [String: Any] = [
|
||||
"sequence": sequence,
|
||||
"bos_emb": bosEmb,
|
||||
]
|
||||
|
||||
for i in 0..<PocketTtsConstants.kvCacheLayers {
|
||||
inputDict["cache\(i)"] = state.caches[i]
|
||||
inputDict["position\(i)"] = state.positions[i]
|
||||
}
|
||||
|
||||
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
|
||||
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
|
||||
|
||||
// Extract transformer output
|
||||
guard let transformerOut = output.featureValue(for: FlowLMStepKeys.transformerOut)?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed("Missing flowlm_step transformer output")
|
||||
}
|
||||
|
||||
// Extract EOS logit
|
||||
guard let eosArray = output.featureValue(for: FlowLMStepKeys.eosLogit)?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed("Missing flowlm_step EOS logit")
|
||||
}
|
||||
let eosLogit = eosArray[0].floatValue
|
||||
|
||||
// Update caches and positions
|
||||
for i in 0..<PocketTtsConstants.kvCacheLayers {
|
||||
guard
|
||||
let newCache = output.featureValue(for: FlowLMStepKeys.cacheKeys[i])?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed(
|
||||
"Missing flowlm_step cache output: \(FlowLMStepKeys.cacheKeys[i])")
|
||||
}
|
||||
guard let newPos = output.featureValue(for: FlowLMStepKeys.positionKeys[i])?.multiArrayValue
|
||||
else {
|
||||
throw TTSError.processingFailed(
|
||||
"Missing flowlm_step position output: \(FlowLMStepKeys.positionKeys[i])")
|
||||
}
|
||||
state.caches[i] = newCache
|
||||
state.positions[i] = newPos
|
||||
}
|
||||
|
||||
return (transformerOut: transformerOut, eosLogit: eosLogit)
|
||||
}
|
||||
}
|
||||
+159
@@ -0,0 +1,159 @@
|
||||
@preconcurrency import CoreML
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
extension PocketTtsSynthesizer {
|
||||
|
||||
/// Mutable streaming state for the Mimi audio decoder.
|
||||
///
|
||||
/// Contains 26 tensors that track convolutional history,
|
||||
/// attention caches, and partial upsampling buffers.
|
||||
struct MimiState {
|
||||
var tensors: [String: MLMultiArray]
|
||||
}
|
||||
|
||||
/// Create the initial Mimi decoder state from the constants directory.
|
||||
///
|
||||
/// Loads pre-computed initial state tensors from `.bin` files,
|
||||
/// using `manifest.json` for shape metadata.
|
||||
static func loadMimiInitialState(from repoDirectory: URL) throws -> MimiState {
|
||||
let constantsDir = repoDirectory.appendingPathComponent(ModelNames.PocketTTS.constantsBinDir)
|
||||
let stateDir = constantsDir.appendingPathComponent("mimi_init_state")
|
||||
let manifestURL = constantsDir.appendingPathComponent("manifest.json")
|
||||
|
||||
// Parse manifest for mimi_init_state shapes
|
||||
let manifestData = try Data(contentsOf: manifestURL)
|
||||
guard let manifest = try JSONSerialization.jsonObject(with: manifestData) as? [String: Any],
|
||||
let mimiManifest = manifest["mimi_init_state"] as? [String: Any]
|
||||
else {
|
||||
throw TTSError.processingFailed("Failed to parse mimi_init_state from manifest.json")
|
||||
}
|
||||
|
||||
var tensors: [String: MLMultiArray] = [:]
|
||||
|
||||
for (name, info) in mimiManifest {
|
||||
guard let infoDict = info as? [String: Any],
|
||||
let shapeArray = infoDict["shape"] as? [Int],
|
||||
let byteCount = infoDict["bytes"] as? Int
|
||||
else {
|
||||
continue
|
||||
}
|
||||
|
||||
let shape = shapeArray.map { NSNumber(value: $0) }
|
||||
let array = try MLMultiArray(shape: shape, dataType: .float32)
|
||||
|
||||
if byteCount > 0 && !shapeArray.contains(0) {
|
||||
let binURL = stateDir.appendingPathComponent("\(name).bin")
|
||||
let data = try Data(contentsOf: binURL)
|
||||
let floatCount = byteCount / MemoryLayout<Float>.size
|
||||
let dstPtr = array.dataPointer.bindMemory(to: Float.self, capacity: floatCount)
|
||||
data.withUnsafeBytes { rawBuffer in
|
||||
let srcPtr = rawBuffer.bindMemory(to: Float.self)
|
||||
dstPtr.update(from: srcPtr.baseAddress!, count: floatCount)
|
||||
}
|
||||
}
|
||||
|
||||
tensors[name] = array
|
||||
}
|
||||
|
||||
// Ensure offset scalars exist
|
||||
for key in ["attn0_offset", "attn0_end_offset", "attn1_offset", "attn1_end_offset"] {
|
||||
if tensors[key] == nil {
|
||||
let scalar = try MLMultiArray(shape: [1], dataType: .float32)
|
||||
scalar[0] = NSNumber(value: Float(0))
|
||||
tensors[key] = scalar
|
||||
}
|
||||
}
|
||||
|
||||
return MimiState(tensors: tensors)
|
||||
}
|
||||
|
||||
/// Clone a Mimi state for independent use.
|
||||
static func cloneMimiState(_ state: MimiState) throws -> MimiState {
|
||||
var newTensors: [String: MLMultiArray] = [:]
|
||||
for (key, array) in state.tensors {
|
||||
let copy = try MLMultiArray(shape: array.shape, dataType: array.dataType)
|
||||
let byteSize: Int
|
||||
switch array.dataType {
|
||||
case .float16:
|
||||
byteSize = array.count * MemoryLayout<UInt16>.size
|
||||
default:
|
||||
byteSize = array.count * MemoryLayout<Float>.size
|
||||
}
|
||||
if byteSize > 0 {
|
||||
copy.dataPointer.copyMemory(from: array.dataPointer, byteCount: byteSize)
|
||||
}
|
||||
newTensors[key] = copy
|
||||
}
|
||||
return MimiState(tensors: newTensors)
|
||||
}
|
||||
|
||||
/// Run the Mimi decoder for a single latent frame.
|
||||
///
|
||||
/// The model internally denormalizes and quantizes the 32-dim latent
|
||||
/// before decoding to audio.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - latent: The raw latent vector, shape [32].
|
||||
/// - state: The streaming state (26 tensors), modified in place.
|
||||
/// - model: The Mimi CoreML model.
|
||||
/// - Returns: Audio samples for this frame (1920 samples = 80ms at 24kHz).
|
||||
static func runMimiDecoder(
|
||||
latent: [Float],
|
||||
state: inout MimiState,
|
||||
model: MLModel
|
||||
) async throws -> [Float] {
|
||||
// Create latent input: [1, 32]
|
||||
let latentDim = PocketTtsConstants.latentDim
|
||||
let latentArray = try MLMultiArray(
|
||||
shape: [1, NSNumber(value: latentDim)], dataType: .float32)
|
||||
let latentPtr = latentArray.dataPointer.bindMemory(to: Float.self, capacity: latentDim)
|
||||
latent.withUnsafeBufferPointer { buffer in
|
||||
guard let base = buffer.baseAddress else { return }
|
||||
latentPtr.update(from: base, count: latentDim)
|
||||
}
|
||||
|
||||
// Build input dictionary
|
||||
var inputDict: [String: Any] = ["latent": latentArray]
|
||||
for (key, array) in state.tensors {
|
||||
inputDict[key] = array
|
||||
}
|
||||
|
||||
let input = try MLDictionaryFeatureProvider(dictionary: inputDict)
|
||||
let output = try await model.compatPrediction(from: input, options: MLPredictionOptions())
|
||||
|
||||
// Extract audio output [1, 1, 1920]
|
||||
guard let audioArray = output.featureValue(for: MimiKeys.audioOutput)?.multiArrayValue else {
|
||||
throw TTSError.processingFailed("Missing Mimi audio output")
|
||||
}
|
||||
|
||||
let sampleCount = PocketTtsConstants.samplesPerFrame
|
||||
let samples = readFloatArray(from: audioArray, count: sampleCount)
|
||||
|
||||
// Update streaming state
|
||||
for (inputName, outputName) in mimiStateMapping {
|
||||
guard let updated = output.featureValue(for: outputName)?.multiArrayValue else {
|
||||
throw TTSError.processingFailed(
|
||||
"Missing Mimi state output: \(outputName) (for \(inputName))")
|
||||
}
|
||||
state.tensors[inputName] = updated
|
||||
}
|
||||
|
||||
return samples
|
||||
}
|
||||
|
||||
/// Read Float values from an MLMultiArray, handling both float32 and float16 data types.
|
||||
///
|
||||
/// The Mimi decoder CoreML model outputs float16 tensors. Using `dataPointer` with
|
||||
/// `Float.self` binding on float16 data produces garbage values. This method
|
||||
/// uses the type-safe subscript accessor which handles conversion automatically.
|
||||
private static func readFloatArray(from array: MLMultiArray, count: Int) -> [Float] {
|
||||
if array.dataType == .float16 {
|
||||
// Use subscript for correct float16 → float32 conversion
|
||||
return (0..<count).map { array[$0].floatValue }
|
||||
}
|
||||
// Fast path for float32: direct memory access
|
||||
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: count)
|
||||
return Array(UnsafeBufferPointer(start: ptr, count: count))
|
||||
}
|
||||
}
|
||||
+88
@@ -0,0 +1,88 @@
|
||||
import Foundation
|
||||
|
||||
extension PocketTtsSynthesizer {
|
||||
|
||||
/// Result of a PocketTTS synthesis operation.
|
||||
public struct SynthesisResult: Sendable {
|
||||
/// WAV audio data (24kHz, 16-bit mono).
|
||||
public let audio: Data
|
||||
/// Raw Float32 audio samples.
|
||||
public let samples: [Float]
|
||||
/// Number of 80ms frames generated.
|
||||
public let frameCount: Int
|
||||
/// Generation step at which EOS was detected (nil if max length reached).
|
||||
public let eosStep: Int?
|
||||
}
|
||||
|
||||
/// CoreML output key names for the conditioning step model.
|
||||
enum CondStepKeys {
|
||||
static let cacheKeys: [String] = [
|
||||
"new_cache_1_internal_tensor_assign_2",
|
||||
"new_cache_3_internal_tensor_assign_2",
|
||||
"new_cache_5_internal_tensor_assign_2",
|
||||
"new_cache_7_internal_tensor_assign_2",
|
||||
"new_cache_9_internal_tensor_assign_2",
|
||||
"new_cache_internal_tensor_assign_2",
|
||||
]
|
||||
static let positionKeys: [String] = [
|
||||
"var_445", "var_864", "var_1283", "var_1702", "var_2121", "var_2365",
|
||||
]
|
||||
}
|
||||
|
||||
/// CoreML output key names for the generation step model.
|
||||
enum FlowLMStepKeys {
|
||||
/// CoreML assigned this output the name "input" during model tracing —
|
||||
/// it is the transformer hidden state output, not an input tensor.
|
||||
static let transformerOut = "input"
|
||||
static let eosLogit = "var_2582"
|
||||
static let cacheKeys: [String] = [
|
||||
"new_cache_1_internal_tensor_assign_2",
|
||||
"new_cache_3_internal_tensor_assign_2",
|
||||
"new_cache_5_internal_tensor_assign_2",
|
||||
"new_cache_7_internal_tensor_assign_2",
|
||||
"new_cache_9_internal_tensor_assign_2",
|
||||
"new_cache_internal_tensor_assign_2",
|
||||
]
|
||||
static let positionKeys: [String] = [
|
||||
"var_458", "var_877", "var_1296", "var_1715", "var_2134", "var_2553",
|
||||
]
|
||||
}
|
||||
|
||||
/// CoreML output key names for the Mimi decoder model.
|
||||
enum MimiKeys {
|
||||
static let audioOutput = "var_821"
|
||||
}
|
||||
|
||||
/// Mimi decoder streaming state key mappings (input name → output name).
|
||||
///
|
||||
/// 26 state tensors including 3 zero-length tensors (res{0,1,2}_conv1_prev)
|
||||
/// whose input and output names are identical pass-throughs.
|
||||
static let mimiStateMapping: [(input: String, output: String)] = [
|
||||
("upsample_partial", "var_82"),
|
||||
("attn0_cache", "var_262"),
|
||||
("attn0_offset", "var_840"),
|
||||
("attn0_end_offset", "new_end_offset_1"),
|
||||
("attn1_cache", "var_479"),
|
||||
("attn1_offset", "var_843"),
|
||||
("attn1_end_offset", "new_end_offset"),
|
||||
("conv0_prev", "var_607"),
|
||||
("conv0_first", "conv0_first"),
|
||||
("convtr0_partial", "var_634"),
|
||||
("res0_conv0_prev", "var_660"),
|
||||
("res0_conv0_first", "res0_conv0_first"),
|
||||
("res0_conv1_prev", "res0_conv1_prev"),
|
||||
("res0_conv1_first", "res0_conv1_first"),
|
||||
("convtr1_partial", "var_700"),
|
||||
("res1_conv0_prev", "var_726"),
|
||||
("res1_conv0_first", "res1_conv0_first"),
|
||||
("res1_conv1_prev", "res1_conv1_prev"),
|
||||
("res1_conv1_first", "res1_conv1_first"),
|
||||
("convtr2_partial", "var_766"),
|
||||
("res2_conv0_prev", "var_792"),
|
||||
("res2_conv0_first", "res2_conv0_first"),
|
||||
("res2_conv1_prev", "res2_conv1_prev"),
|
||||
("res2_conv1_first", "res2_conv1_first"),
|
||||
("conv_final_prev", "var_824"),
|
||||
("conv_final_first", "conv_final_first"),
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,535 @@
|
||||
@preconcurrency import CoreML
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// PocketTTS flow-matching language model synthesizer.
|
||||
///
|
||||
/// Generates audio autoregressively: each generation step produces
|
||||
/// an 80ms audio frame (1920 samples at 24kHz).
|
||||
///
|
||||
/// Long text is split into sentence-based chunks (≤50 tokens each)
|
||||
/// to stay within the KV cache limit (200 positions).
|
||||
///
|
||||
/// Pipeline: text → chunk → [tokenize → embed → prefill KV → generate → flow decode → mimi decode] → WAV
|
||||
public struct PocketTtsSynthesizer {
|
||||
|
||||
static let logger = AppLogger(category: "PocketTtsSynthesizer")
|
||||
|
||||
private enum Context {
|
||||
@TaskLocal static var modelStore: PocketTtsModelStore?
|
||||
}
|
||||
|
||||
static func withModelStore<T>(
|
||||
_ store: PocketTtsModelStore,
|
||||
operation: () async throws -> T
|
||||
) async rethrows -> T {
|
||||
try await Context.$modelStore.withValue(store) {
|
||||
try await operation()
|
||||
}
|
||||
}
|
||||
|
||||
static func currentModelStore() throws -> PocketTtsModelStore {
|
||||
guard let store = Context.modelStore else {
|
||||
throw TTSError.processingFailed(
|
||||
"PocketTtsSynthesizer requires a model store context.")
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// MARK: - Public API
|
||||
|
||||
/// Synthesize audio from text.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - text: The text to synthesize.
|
||||
/// - voice: Voice identifier (default: "alba").
|
||||
/// - temperature: Generation temperature (default: 0.7).
|
||||
/// - seed: Random seed for reproducibility (nil for random).
|
||||
/// - deEss: Whether to apply de-essing post-processing.
|
||||
/// - Returns: A synthesis result containing WAV audio data.
|
||||
public static func synthesize(
|
||||
text: String,
|
||||
voice: String = PocketTtsConstants.defaultVoice,
|
||||
temperature: Float = PocketTtsConstants.temperature,
|
||||
seed: UInt64? = nil,
|
||||
deEss: Bool = true
|
||||
) async throws -> SynthesisResult {
|
||||
let store = try currentModelStore()
|
||||
|
||||
logger.info("PocketTTS synthesizing: '\(text)'")
|
||||
|
||||
// 1. Load constants and voice
|
||||
let constants = try await store.constants()
|
||||
let voiceData = try await store.voiceData(for: voice)
|
||||
|
||||
// 2. Split text into chunks that fit within KV cache capacity
|
||||
let chunks = chunkText(text, tokenizer: constants.tokenizer)
|
||||
logger.info("Split into \(chunks.count) chunk(s)")
|
||||
|
||||
// 3. Set up random number generator (seeded or system entropy)
|
||||
var rng = SeededRNG(seed: seed ?? UInt64.random(in: 0...UInt64.max))
|
||||
|
||||
// 4. Load models
|
||||
let condModel = try await store.condStep()
|
||||
let stepModel = try await store.flowlmStep()
|
||||
let flowModel = try await store.flowDecoder()
|
||||
let mimiModel = try await store.mimiDecoder()
|
||||
|
||||
// 5. Load Mimi initial state (continuous across chunks)
|
||||
let repoDir = try await store.repoDir()
|
||||
var mimiState = try loadMimiInitialState(from: repoDir)
|
||||
|
||||
// 6. Create BOS embedding
|
||||
let bosEmb = try createBosEmbedding(constants.bosEmbedding)
|
||||
|
||||
// 7. Generate audio for each chunk
|
||||
var audioChunks: [[Float]] = []
|
||||
var lastEosStep: Int?
|
||||
|
||||
let genStart = Date()
|
||||
|
||||
for (chunkIdx, chunkText) in chunks.enumerated() {
|
||||
let (normalizedChunk, framesAfterEos) = normalizeText(chunkText)
|
||||
logger.info("Chunk \(chunkIdx + 1)/\(chunks.count): '\(normalizedChunk)'")
|
||||
|
||||
// Tokenize and embed this chunk
|
||||
let tokenIds = constants.tokenizer.encode(normalizedChunk)
|
||||
let textEmbeddings = embedTokens(tokenIds, constants: constants)
|
||||
|
||||
// Fresh KV cache per chunk
|
||||
let prefillStart = Date()
|
||||
var kvState = try await prefillKVCache(
|
||||
voiceData: voiceData,
|
||||
textEmbeddings: textEmbeddings,
|
||||
model: condModel
|
||||
)
|
||||
let prefillElapsed = Date().timeIntervalSince(prefillStart)
|
||||
logger.info(
|
||||
"Chunk \(chunkIdx + 1) prefill: \(String(format: "%.2f", prefillElapsed))s (\(tokenIds.count) tokens)"
|
||||
)
|
||||
|
||||
// Generation loop for this chunk
|
||||
let maxGenLen = estimateMaxFrames(text: chunkText)
|
||||
var eosStep: Int?
|
||||
var sequence = try createNaNSequence()
|
||||
let totalFramesAfterEos =
|
||||
framesAfterEos + PocketTtsConstants.extraFramesAfterDetection
|
||||
|
||||
for step in 0..<maxGenLen {
|
||||
let (transformerOut, eosLogit) = try await runFlowLMStep(
|
||||
sequence: sequence,
|
||||
bosEmb: bosEmb,
|
||||
state: &kvState,
|
||||
model: stepModel
|
||||
)
|
||||
|
||||
if eosLogit > PocketTtsConstants.eosThreshold && eosStep == nil {
|
||||
eosStep = step
|
||||
logger.info("Chunk \(chunkIdx + 1) EOS at step \(step)")
|
||||
}
|
||||
if let eos = eosStep, step >= eos + totalFramesAfterEos {
|
||||
break
|
||||
}
|
||||
|
||||
let latent = try await flowDecode(
|
||||
transformerOut: transformerOut,
|
||||
numSteps: PocketTtsConstants.numLsdSteps,
|
||||
temperature: temperature,
|
||||
model: flowModel,
|
||||
rng: &rng
|
||||
)
|
||||
|
||||
// Mimi state is continuous across chunks
|
||||
// (denormalize + quantize baked into mimi_decoder model)
|
||||
let frameSamples = try await runMimiDecoder(
|
||||
latent: latent,
|
||||
state: &mimiState,
|
||||
model: mimiModel
|
||||
)
|
||||
audioChunks.append(frameSamples)
|
||||
|
||||
sequence = try createSequenceFromLatent(latent)
|
||||
|
||||
if step % 20 == 0 {
|
||||
logger.info("Chunk \(chunkIdx + 1) step \(step)...")
|
||||
}
|
||||
}
|
||||
|
||||
lastEosStep = eosStep
|
||||
}
|
||||
|
||||
let genElapsed = Date().timeIntervalSince(genStart)
|
||||
logger.info(
|
||||
"Generated \(audioChunks.count) frames in \(String(format: "%.2f", genElapsed))s")
|
||||
|
||||
// 8. Concatenate audio (no peak normalization — preserve natural levels)
|
||||
var allSamples = audioChunks.flatMap { $0 }
|
||||
|
||||
// De-essing
|
||||
if deEss {
|
||||
AudioPostProcessor.applyTtsPostProcessing(
|
||||
&allSamples,
|
||||
sampleRate: Float(PocketTtsConstants.audioSampleRate),
|
||||
deEssAmount: -3.0,
|
||||
smoothing: false
|
||||
)
|
||||
}
|
||||
|
||||
// 9. Encode WAV
|
||||
let audioData = try AudioWAV.data(
|
||||
from: allSamples,
|
||||
sampleRate: Double(PocketTtsConstants.audioSampleRate)
|
||||
)
|
||||
|
||||
let duration = Double(allSamples.count) / Double(PocketTtsConstants.audioSampleRate)
|
||||
logger.info("Audio duration: \(String(format: "%.2f", duration))s")
|
||||
|
||||
return SynthesisResult(
|
||||
audio: audioData,
|
||||
samples: allSamples,
|
||||
frameCount: audioChunks.count,
|
||||
eosStep: lastEosStep
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Text Processing
|
||||
|
||||
/// Normalize a text chunk for PocketTTS (matching Python `prepare_text_prompt`).
|
||||
static func normalizeText(_ text: String) -> (text: String, framesAfterEos: Int) {
|
||||
var result = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
// Collapse whitespace
|
||||
result = result.replacingOccurrences(
|
||||
of: "\\s+", with: " ", options: .regularExpression)
|
||||
|
||||
// Strip trailing clause punctuation (commas, semicolons, colons)
|
||||
// before adding sentence-ending punctuation
|
||||
while let last = result.last, ",;:".contains(last) {
|
||||
result = String(result.dropLast())
|
||||
}
|
||||
result = result.trimmingCharacters(in: .whitespaces)
|
||||
|
||||
// Capitalize first letter
|
||||
if let first = result.first, first.isLetter {
|
||||
result = first.uppercased() + result.dropFirst()
|
||||
}
|
||||
|
||||
// Add period if no terminal punctuation
|
||||
if let last = result.last, !".!?".contains(last) {
|
||||
result += "."
|
||||
}
|
||||
|
||||
// Pad short texts for better prosody
|
||||
let wordCount = result.split(separator: " ").count
|
||||
let framesAfterEos: Int
|
||||
if wordCount < PocketTtsConstants.shortTextWordThreshold {
|
||||
result = String(repeating: " ", count: 8) + result
|
||||
framesAfterEos = PocketTtsConstants.shortTextPadFrames
|
||||
} else {
|
||||
framesAfterEos = PocketTtsConstants.longTextExtraFrames
|
||||
}
|
||||
|
||||
return (result, framesAfterEos)
|
||||
}
|
||||
|
||||
/// Split text into chunks that fit within the KV cache token limit.
|
||||
///
|
||||
/// Splits at sentence boundaries (`.!?`) and groups sentences into chunks
|
||||
/// where each chunk tokenizes to ≤ `maxTokensPerChunk` tokens.
|
||||
/// Oversized single sentences are further split at word boundaries.
|
||||
static func chunkText(
|
||||
_ text: String,
|
||||
tokenizer: SentencePieceTokenizer,
|
||||
maxTokens: Int = PocketTtsConstants.maxTokensPerChunk
|
||||
) -> [String] {
|
||||
let normalized = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
// If it fits in one chunk, return as-is
|
||||
let tokenCount = tokenizer.encode(normalized).count
|
||||
if tokenCount <= maxTokens {
|
||||
return [normalized]
|
||||
}
|
||||
|
||||
// Split into sentences at .!? boundaries
|
||||
let sentences = splitSentences(normalized)
|
||||
|
||||
// Further split any oversized sentences at word boundaries
|
||||
var pieces: [String] = []
|
||||
for sentence in sentences {
|
||||
let sentenceTokens = tokenizer.encode(sentence).count
|
||||
if sentenceTokens <= maxTokens {
|
||||
pieces.append(sentence)
|
||||
} else {
|
||||
pieces.append(contentsOf: splitOversizedSentence(sentence, tokenizer: tokenizer, maxTokens: maxTokens))
|
||||
}
|
||||
}
|
||||
|
||||
// Group pieces into chunks that fit the token limit
|
||||
var chunks: [String] = []
|
||||
var currentChunk = ""
|
||||
|
||||
for piece in pieces {
|
||||
let candidate: String
|
||||
if currentChunk.isEmpty {
|
||||
candidate = piece
|
||||
} else {
|
||||
candidate = currentChunk + " " + piece
|
||||
}
|
||||
|
||||
let candidateTokens = tokenizer.encode(candidate).count
|
||||
if candidateTokens <= maxTokens {
|
||||
currentChunk = candidate
|
||||
} else {
|
||||
if !currentChunk.isEmpty {
|
||||
chunks.append(currentChunk)
|
||||
}
|
||||
currentChunk = piece
|
||||
}
|
||||
}
|
||||
|
||||
if !currentChunk.isEmpty {
|
||||
chunks.append(currentChunk)
|
||||
}
|
||||
|
||||
return chunks.isEmpty ? [normalized] : chunks
|
||||
}
|
||||
|
||||
/// Split an oversized sentence to fit within the token limit.
|
||||
///
|
||||
/// First tries splitting at clause boundaries (commas, semicolons, colons).
|
||||
/// Falls back to word-boundary splitting for clauses that still exceed the limit.
|
||||
private static func splitOversizedSentence(
|
||||
_ text: String,
|
||||
tokenizer: SentencePieceTokenizer,
|
||||
maxTokens: Int
|
||||
) -> [String] {
|
||||
// First try: split at clause boundaries
|
||||
let clauseParts = splitAtClauseBoundaries(text)
|
||||
|
||||
// Group clause parts into chunks that fit
|
||||
var result: [String] = []
|
||||
var currentPart = ""
|
||||
|
||||
for part in clauseParts {
|
||||
let candidate = currentPart.isEmpty ? part : currentPart + " " + part
|
||||
let candidateTokens = tokenizer.encode(candidate).count
|
||||
|
||||
if candidateTokens <= maxTokens {
|
||||
currentPart = candidate
|
||||
} else {
|
||||
if !currentPart.isEmpty {
|
||||
result.append(currentPart)
|
||||
}
|
||||
// If single clause part still exceeds limit, split at word boundaries
|
||||
if tokenizer.encode(part).count > maxTokens {
|
||||
result.append(contentsOf: splitAtWordBoundaries(part, tokenizer: tokenizer, maxTokens: maxTokens))
|
||||
currentPart = ""
|
||||
} else {
|
||||
currentPart = part
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !currentPart.isEmpty {
|
||||
result.append(currentPart)
|
||||
}
|
||||
|
||||
return result.isEmpty ? [text] : result
|
||||
}
|
||||
|
||||
/// Split text at clause punctuation (commas, semicolons, colons).
|
||||
///
|
||||
/// Does not split at commas within numbers (e.g., "3,500").
|
||||
private static func splitAtClauseBoundaries(_ text: String) -> [String] {
|
||||
let clauseBreaks: Set<Character> = [",", ";", ":"]
|
||||
var parts: [String] = []
|
||||
var current = ""
|
||||
let chars = Array(text)
|
||||
|
||||
for (i, char) in chars.enumerated() {
|
||||
current.append(char)
|
||||
|
||||
guard clauseBreaks.contains(char) else { continue }
|
||||
|
||||
// Don't split at commas between digits (e.g., "3,500")
|
||||
if char == "," {
|
||||
let prevIsDigit = i > 0 && chars[i - 1].isNumber
|
||||
let nextIsDigit = i + 1 < chars.count && chars[i + 1].isNumber
|
||||
if prevIsDigit && nextIsDigit {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
let trimmed = current.trimmingCharacters(in: .whitespaces)
|
||||
if !trimmed.isEmpty {
|
||||
parts.append(trimmed)
|
||||
}
|
||||
current = ""
|
||||
}
|
||||
|
||||
let trimmed = current.trimmingCharacters(in: .whitespaces)
|
||||
if !trimmed.isEmpty {
|
||||
parts.append(trimmed)
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
/// Split text at word boundaries to fit within the token limit.
|
||||
private static func splitAtWordBoundaries(
|
||||
_ text: String,
|
||||
tokenizer: SentencePieceTokenizer,
|
||||
maxTokens: Int
|
||||
) -> [String] {
|
||||
let words = text.split(separator: " ").map(String.init)
|
||||
guard words.count > 1 else { return [text] }
|
||||
|
||||
var chunks: [String] = []
|
||||
var currentWords: [String] = []
|
||||
|
||||
for word in words {
|
||||
let candidate = (currentWords + [word]).joined(separator: " ")
|
||||
let tokens = tokenizer.encode(candidate).count
|
||||
|
||||
if tokens > maxTokens && !currentWords.isEmpty {
|
||||
chunks.append(currentWords.joined(separator: " "))
|
||||
currentWords = [word]
|
||||
} else {
|
||||
currentWords.append(word)
|
||||
}
|
||||
}
|
||||
|
||||
if !currentWords.isEmpty {
|
||||
chunks.append(currentWords.joined(separator: " "))
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
/// Common abbreviations that end with a period but don't end a sentence.
|
||||
private static let abbreviations: Set<String> = [
|
||||
"dr", "mr", "mrs", "ms", "prof", "sr", "jr", "st", "vs", "etc",
|
||||
"inc", "ltd", "co", "corp", "dept", "univ", "govt", "approx",
|
||||
"avg", "est", "gen", "gov", "hon", "sgt", "cpl", "pvt", "capt",
|
||||
"lt", "col", "maj", "cmdr", "adm", "rev", "sen", "rep",
|
||||
]
|
||||
|
||||
/// Split text into sentences at `.!?` boundaries.
|
||||
///
|
||||
/// Handles abbreviations (e.g., "Dr.", "Prof.") by not splitting after them.
|
||||
private static func splitSentences(_ text: String) -> [String] {
|
||||
var sentences: [String] = []
|
||||
var current = ""
|
||||
let chars = Array(text)
|
||||
|
||||
for (i, char) in chars.enumerated() {
|
||||
current.append(char)
|
||||
|
||||
guard ".!?".contains(char) else { continue }
|
||||
|
||||
// For periods, check if this is an abbreviation
|
||||
if char == "." {
|
||||
let trimmed = current.trimmingCharacters(in: .whitespaces)
|
||||
// Get the last word before the period
|
||||
let withoutPeriod = String(trimmed.dropLast())
|
||||
let lastWord = withoutPeriod.split(separator: " ").last.map(String.init) ?? withoutPeriod
|
||||
|
||||
// Skip if it's a known abbreviation
|
||||
if abbreviations.contains(lastWord.lowercased()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if it's a single uppercase letter (e.g., "J." in initials)
|
||||
if lastWord.count == 1, lastWord.first?.isUppercase == true {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if followed by a digit (e.g., "3.5")
|
||||
if i + 1 < chars.count, chars[i + 1].isNumber {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
let trimmed = current.trimmingCharacters(in: .whitespaces)
|
||||
if !trimmed.isEmpty {
|
||||
sentences.append(trimmed)
|
||||
}
|
||||
current = ""
|
||||
}
|
||||
|
||||
// Remaining text without terminal punctuation
|
||||
let trimmed = current.trimmingCharacters(in: .whitespaces)
|
||||
if !trimmed.isEmpty {
|
||||
sentences.append(trimmed)
|
||||
}
|
||||
|
||||
return sentences
|
||||
}
|
||||
|
||||
// MARK: - Embedding
|
||||
|
||||
/// Look up text token embeddings from the embedding table.
|
||||
static func embedTokens(
|
||||
_ tokenIds: [Int], constants: PocketTtsConstantsBundle
|
||||
) -> [[Float]] {
|
||||
let dim = PocketTtsConstants.embeddingDim
|
||||
let vocabSize = PocketTtsConstants.vocabSize
|
||||
return tokenIds.map { id in
|
||||
guard id >= 0, id < vocabSize else {
|
||||
logger.warning("Token ID \(id) out of range [0, \(vocabSize)), clamping")
|
||||
let clampedId = min(max(id, 0), vocabSize - 1)
|
||||
let offset = clampedId * dim
|
||||
return Array(constants.textEmbedTable[offset..<(offset + dim)])
|
||||
}
|
||||
let offset = id * dim
|
||||
return Array(constants.textEmbedTable[offset..<(offset + dim)])
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
/// Estimate maximum generation frames based on text length.
|
||||
private static func estimateMaxFrames(text: String) -> Int {
|
||||
let wordCount = text.split(separator: " ").count
|
||||
let genLenSec = Double(wordCount) + 2.0
|
||||
return Int(genLenSec * 12.5)
|
||||
}
|
||||
|
||||
/// Create the BOS embedding as an MLMultiArray [32].
|
||||
private static func createBosEmbedding(_ bos: [Float]) throws -> MLMultiArray {
|
||||
let dim = PocketTtsConstants.latentDim
|
||||
let array = try MLMultiArray(shape: [NSNumber(value: dim)], dataType: .float32)
|
||||
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
bos.withUnsafeBufferPointer { buffer in
|
||||
guard let base = buffer.baseAddress else { return }
|
||||
ptr.update(from: base, count: dim)
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
/// Create a NaN-filled sequence [1, 1, 32] (signals BOS to the model).
|
||||
private static func createNaNSequence() throws -> MLMultiArray {
|
||||
let dim = PocketTtsConstants.latentDim
|
||||
let array = try MLMultiArray(
|
||||
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
|
||||
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
for i in 0..<dim {
|
||||
ptr[i] = .nan
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
/// Create a sequence [1, 1, 32] from a latent vector.
|
||||
private static func createSequenceFromLatent(_ latent: [Float]) throws -> MLMultiArray {
|
||||
let dim = PocketTtsConstants.latentDim
|
||||
let array = try MLMultiArray(
|
||||
shape: [1, 1, NSNumber(value: dim)], dataType: .float32)
|
||||
let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: dim)
|
||||
latent.withUnsafeBufferPointer { buffer in
|
||||
guard let base = buffer.baseAddress else { return }
|
||||
ptr.update(from: base, count: dim)
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
import Foundation
|
||||
|
||||
/// Constants for the PocketTTS flow-matching language model TTS backend.
|
||||
public enum PocketTtsConstants {
|
||||
|
||||
// MARK: - Audio
|
||||
|
||||
public static let audioSampleRate: Int = 24_000
|
||||
public static let samplesPerFrame: Int = 1_920
|
||||
|
||||
// MARK: - Model dimensions
|
||||
|
||||
public static let latentDim: Int = 32
|
||||
public static let transformerDim: Int = 1024
|
||||
public static let vocabSize: Int = 4001
|
||||
public static let embeddingDim: Int = 1024
|
||||
|
||||
// MARK: - Generation parameters
|
||||
|
||||
public static let numLsdSteps: Int = 8
|
||||
public static let temperature: Float = 0.7
|
||||
public static let eosThreshold: Float = -4.0
|
||||
public static let shortTextPadFrames: Int = 3
|
||||
public static let longTextExtraFrames: Int = 1
|
||||
public static let extraFramesAfterDetection: Int = 2
|
||||
public static let shortTextWordThreshold: Int = 5
|
||||
public static let maxTokensPerChunk: Int = 50
|
||||
|
||||
// MARK: - KV cache
|
||||
|
||||
public static let kvCacheLayers: Int = 6
|
||||
public static let kvCacheMaxLen: Int = 512
|
||||
|
||||
// MARK: - Voice
|
||||
|
||||
public static let defaultVoice: String = "alba"
|
||||
public static let voicePromptLength: Int = 125
|
||||
|
||||
// MARK: - Repository
|
||||
|
||||
public static let defaultModelsSubdirectory: String = "Models"
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
import OSLog
|
||||
|
||||
/// Manages text-to-speech synthesis using PocketTTS CoreML models.
|
||||
///
|
||||
/// PocketTTS uses a flow-matching language model architecture that generates
|
||||
/// audio autoregressively at 24kHz. Each generation step produces an 80ms
|
||||
/// audio frame (1920 samples).
|
||||
///
|
||||
/// Example usage:
|
||||
/// ```swift
|
||||
/// let manager = PocketTtsManager()
|
||||
/// try await manager.initialize()
|
||||
/// let audioData = try await manager.synthesize(text: "Hello, world!")
|
||||
/// ```
|
||||
public actor PocketTtsManager {
|
||||
|
||||
private let logger = AppLogger(category: "PocketTtsManager")
|
||||
private let modelStore: PocketTtsModelStore
|
||||
private var defaultVoice: String
|
||||
private var isInitialized = false
|
||||
|
||||
/// Creates a new PocketTTS manager.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - defaultVoice: Default voice identifier (default: "alba").
|
||||
public init(defaultVoice: String = PocketTtsConstants.defaultVoice) {
|
||||
self.modelStore = PocketTtsModelStore()
|
||||
self.defaultVoice = defaultVoice
|
||||
}
|
||||
|
||||
public var isAvailable: Bool {
|
||||
isInitialized
|
||||
}
|
||||
|
||||
/// Initialize by downloading and loading all PocketTTS models.
|
||||
public func initialize() async throws {
|
||||
try await modelStore.loadIfNeeded()
|
||||
isInitialized = true
|
||||
logger.notice("PocketTtsManager initialized")
|
||||
}
|
||||
|
||||
/// Synthesize text to WAV audio data.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - text: The text to synthesize.
|
||||
/// - voice: Voice identifier (default: uses the manager's default voice).
|
||||
/// - temperature: Generation temperature (default: 0.7).
|
||||
/// - deEss: Whether to apply de-essing post-processing (default: true).
|
||||
/// - Returns: WAV audio data at 24kHz.
|
||||
public func synthesize(
|
||||
text: String,
|
||||
voice: String? = nil,
|
||||
temperature: Float = PocketTtsConstants.temperature,
|
||||
deEss: Bool = true
|
||||
) async throws -> Data {
|
||||
guard isInitialized else {
|
||||
throw TTSError.modelNotFound("PocketTTS model not initialized")
|
||||
}
|
||||
|
||||
let selectedVoice = voice ?? defaultVoice
|
||||
|
||||
return try await PocketTtsSynthesizer.withModelStore(modelStore) {
|
||||
let result = try await PocketTtsSynthesizer.synthesize(
|
||||
text: text,
|
||||
voice: selectedVoice,
|
||||
temperature: temperature,
|
||||
deEss: deEss
|
||||
)
|
||||
return result.audio
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize text and return detailed results including frame count and EOS info.
|
||||
public func synthesizeDetailed(
|
||||
text: String,
|
||||
voice: String? = nil,
|
||||
temperature: Float = PocketTtsConstants.temperature,
|
||||
deEss: Bool = true
|
||||
) async throws -> PocketTtsSynthesizer.SynthesisResult {
|
||||
guard isInitialized else {
|
||||
throw TTSError.modelNotFound("PocketTTS model not initialized")
|
||||
}
|
||||
|
||||
let selectedVoice = voice ?? defaultVoice
|
||||
|
||||
return try await PocketTtsSynthesizer.withModelStore(modelStore) {
|
||||
try await PocketTtsSynthesizer.synthesize(
|
||||
text: text,
|
||||
voice: selectedVoice,
|
||||
temperature: temperature,
|
||||
deEss: deEss
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize text and write the result directly to a file.
|
||||
public func synthesizeToFile(
|
||||
text: String,
|
||||
outputURL: URL,
|
||||
voice: String? = nil,
|
||||
temperature: Float = PocketTtsConstants.temperature,
|
||||
deEss: Bool = true
|
||||
) async throws {
|
||||
if FileManager.default.fileExists(atPath: outputURL.path) {
|
||||
try FileManager.default.removeItem(at: outputURL)
|
||||
}
|
||||
|
||||
let audioData = try await synthesize(
|
||||
text: text,
|
||||
voice: voice,
|
||||
temperature: temperature,
|
||||
deEss: deEss
|
||||
)
|
||||
|
||||
try audioData.write(to: outputURL)
|
||||
logger.notice("Saved synthesized audio to: \(outputURL.lastPathComponent)")
|
||||
}
|
||||
|
||||
/// Update the default voice.
|
||||
public func setDefaultVoice(_ voice: String) {
|
||||
defaultVoice = voice
|
||||
}
|
||||
|
||||
public func cleanup() {
|
||||
isInitialized = false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
import Foundation
|
||||
|
||||
/// Minimal protobuf parser for SentencePiece `.model` files.
|
||||
///
|
||||
/// Extracts only the vocabulary pieces (string + score) from the
|
||||
/// `ModelProto` message, ignoring trainer/normalizer specs.
|
||||
///
|
||||
/// Wire format reference:
|
||||
/// - Tag = (field_number << 3) | wire_type
|
||||
/// - Wire type 0 = varint, 2 = length-delimited, 5 = 32-bit fixed
|
||||
enum SentencePieceProto {
|
||||
|
||||
struct Piece: Sendable {
|
||||
let piece: String
|
||||
let score: Float
|
||||
}
|
||||
|
||||
enum ParseError: Error {
|
||||
case invalidData
|
||||
case unexpectedEnd
|
||||
case invalidUtf8
|
||||
}
|
||||
|
||||
/// Parse a SentencePiece `.model` file and return the vocabulary pieces.
|
||||
static func parse(_ data: Data) throws -> [Piece] {
|
||||
var pieces: [Piece] = []
|
||||
var offset = 0
|
||||
let bytes = Array(data)
|
||||
let count = bytes.count
|
||||
|
||||
while offset < count {
|
||||
let (fieldNumber, wireType) = try readTag(bytes: bytes, count: count, offset: &offset)
|
||||
|
||||
switch wireType {
|
||||
case 0:
|
||||
// Varint — skip
|
||||
_ = try readVarint(bytes: bytes, count: count, offset: &offset)
|
||||
case 1:
|
||||
// 64-bit fixed — skip 8 bytes
|
||||
offset += 8
|
||||
guard offset <= count else { throw ParseError.unexpectedEnd }
|
||||
case 2:
|
||||
// Length-delimited
|
||||
let length = try readVarint(bytes: bytes, count: count, offset: &offset)
|
||||
let end = offset + Int(length)
|
||||
guard end <= count else { throw ParseError.unexpectedEnd }
|
||||
|
||||
if fieldNumber == 1 {
|
||||
// Top-level field 1 = repeated SentencePiece message
|
||||
let piece = try parsePiece(bytes: bytes, start: offset, end: end)
|
||||
pieces.append(piece)
|
||||
}
|
||||
// Skip to end of this field regardless
|
||||
offset = end
|
||||
case 5:
|
||||
// 32-bit fixed — skip 4 bytes
|
||||
offset += 4
|
||||
guard offset <= count else { throw ParseError.unexpectedEnd }
|
||||
default:
|
||||
throw ParseError.invalidData
|
||||
}
|
||||
}
|
||||
|
||||
return pieces
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
private static func parsePiece(bytes: [UInt8], start: Int, end: Int) throws -> Piece {
|
||||
var offset = start
|
||||
var piece: String?
|
||||
var score: Float = 0
|
||||
|
||||
while offset < end {
|
||||
let (fieldNumber, wireType) = try readTag(bytes: bytes, count: end, offset: &offset)
|
||||
|
||||
switch wireType {
|
||||
case 0:
|
||||
_ = try readVarint(bytes: bytes, count: end, offset: &offset)
|
||||
case 1:
|
||||
offset += 8
|
||||
guard offset <= end else { throw ParseError.unexpectedEnd }
|
||||
case 2:
|
||||
let length = try readVarint(bytes: bytes, count: end, offset: &offset)
|
||||
let fieldEnd = offset + Int(length)
|
||||
guard fieldEnd <= end else { throw ParseError.unexpectedEnd }
|
||||
|
||||
if fieldNumber == 1 {
|
||||
// SentencePiece.piece (string)
|
||||
let slice = bytes[offset..<fieldEnd]
|
||||
guard let str = String(bytes: slice, encoding: .utf8) else {
|
||||
throw ParseError.invalidUtf8
|
||||
}
|
||||
piece = str
|
||||
}
|
||||
offset = fieldEnd
|
||||
case 5:
|
||||
if fieldNumber == 2 {
|
||||
// SentencePiece.score (float)
|
||||
guard offset + 4 <= end else { throw ParseError.unexpectedEnd }
|
||||
score = readFloat32(bytes: bytes, offset: offset)
|
||||
}
|
||||
offset += 4
|
||||
guard offset <= end else { throw ParseError.unexpectedEnd }
|
||||
default:
|
||||
throw ParseError.invalidData
|
||||
}
|
||||
}
|
||||
|
||||
return Piece(piece: piece ?? "", score: score)
|
||||
}
|
||||
|
||||
private static func readTag(
|
||||
bytes: [UInt8], count: Int, offset: inout Int
|
||||
) throws -> (fieldNumber: Int, wireType: Int) {
|
||||
let tag = try readVarint(bytes: bytes, count: count, offset: &offset)
|
||||
let wireType = Int(tag & 0x07)
|
||||
let fieldNumber = Int(tag >> 3)
|
||||
return (fieldNumber, wireType)
|
||||
}
|
||||
|
||||
private static func readVarint(
|
||||
bytes: [UInt8], count: Int, offset: inout Int
|
||||
) throws -> UInt64 {
|
||||
var result: UInt64 = 0
|
||||
var shift: UInt64 = 0
|
||||
|
||||
while offset < count {
|
||||
let byte = bytes[offset]
|
||||
offset += 1
|
||||
result |= UInt64(byte & 0x7F) << shift
|
||||
if byte & 0x80 == 0 {
|
||||
return result
|
||||
}
|
||||
shift += 7
|
||||
if shift >= 64 { throw ParseError.invalidData }
|
||||
}
|
||||
|
||||
throw ParseError.unexpectedEnd
|
||||
}
|
||||
|
||||
private static func readFloat32(bytes: [UInt8], offset: Int) -> Float {
|
||||
var value: Float = 0
|
||||
withUnsafeMutableBytes(of: &value) { ptr in
|
||||
ptr[0] = bytes[offset]
|
||||
ptr[1] = bytes[offset + 1]
|
||||
ptr[2] = bytes[offset + 2]
|
||||
ptr[3] = bytes[offset + 3]
|
||||
}
|
||||
return value
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
import Foundation
|
||||
|
||||
/// Minimal SentencePiece unigram tokenizer for PocketTTS.
|
||||
///
|
||||
/// Parses a `.model` protobuf to extract the vocabulary, then uses
|
||||
/// Viterbi decoding to segment text into subword tokens.
|
||||
public struct SentencePieceTokenizer: Sendable {
|
||||
|
||||
/// Vocabulary pieces with their log-probability scores.
|
||||
private let pieces: [SentencePieceProto.Piece]
|
||||
/// Lookup from piece string to token ID.
|
||||
private let pieceToId: [String: Int]
|
||||
/// Maximum piece length in UTF-8 scalars for early termination.
|
||||
private let maxPieceLength: Int
|
||||
|
||||
/// The space replacement character used by SentencePiece.
|
||||
private static let spaceMarker: Character = "\u{2581}"
|
||||
|
||||
public init(modelData: Data) throws {
|
||||
let parsed = try SentencePieceProto.parse(modelData)
|
||||
self.pieces = parsed
|
||||
|
||||
var lookup: [String: Int] = [:]
|
||||
lookup.reserveCapacity(parsed.count)
|
||||
var maxLen = 0
|
||||
for (index, entry) in parsed.enumerated() {
|
||||
lookup[entry.piece] = index
|
||||
maxLen = max(maxLen, entry.piece.unicodeScalars.count)
|
||||
}
|
||||
self.pieceToId = lookup
|
||||
self.maxPieceLength = maxLen
|
||||
}
|
||||
|
||||
/// Tokenize text into token IDs using Viterbi unigram decoding.
|
||||
///
|
||||
/// Applies the standard SentencePiece normalization: replaces spaces
|
||||
/// with `\u{2581}` and prepends `\u{2581}` to the input.
|
||||
public func encode(_ text: String) -> [Int] {
|
||||
guard !text.isEmpty else { return [] }
|
||||
|
||||
// Normalize: prepend space marker, replace spaces with marker
|
||||
let normalized =
|
||||
String(Self.spaceMarker)
|
||||
+ text.replacingOccurrences(
|
||||
of: " ", with: String(Self.spaceMarker))
|
||||
|
||||
return viterbiDecode(normalized)
|
||||
}
|
||||
|
||||
// MARK: - Viterbi Decoding
|
||||
|
||||
/// Run Viterbi algorithm to find the highest-score segmentation.
|
||||
///
|
||||
/// For each position in the string, finds the best-scoring
|
||||
/// vocabulary piece ending at that position.
|
||||
private func viterbiDecode(_ text: String) -> [Int] {
|
||||
let scalars = Array(text.unicodeScalars)
|
||||
let n = scalars.count
|
||||
guard n > 0 else { return [] }
|
||||
|
||||
// bestScore[i] = best log-probability score for text[0..<i]
|
||||
// bestPiece[i] = (pieceId, startPosition) for the piece ending at i
|
||||
let negInf: Float = -.infinity
|
||||
var bestScore = [Float](repeating: negInf, count: n + 1)
|
||||
var bestPiece = [(pieceId: Int, start: Int)](repeating: (0, 0), count: n + 1)
|
||||
bestScore[0] = 0
|
||||
|
||||
// Build a string from scalars for substring matching
|
||||
// We work with Unicode scalar offsets for correctness
|
||||
for i in 0..<n {
|
||||
guard bestScore[i] > negInf else { continue }
|
||||
|
||||
let maxLen = min(maxPieceLength, n - i)
|
||||
for length in 1...maxLen {
|
||||
let end = i + length
|
||||
// Build candidate substring from scalars
|
||||
let candidate = String(String.UnicodeScalarView(scalars[i..<end]))
|
||||
|
||||
guard let pieceId = pieceToId[candidate] else { continue }
|
||||
let piece = pieces[pieceId]
|
||||
|
||||
let newScore = bestScore[i] + piece.score
|
||||
if newScore > bestScore[end] {
|
||||
bestScore[end] = newScore
|
||||
bestPiece[end] = (pieceId: pieceId, start: i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backtrack to collect token IDs
|
||||
guard bestScore[n] > negInf else {
|
||||
// Fallback: encode as individual characters
|
||||
return fallbackEncode(scalars)
|
||||
}
|
||||
|
||||
var ids: [Int] = []
|
||||
var pos = n
|
||||
while pos > 0 {
|
||||
let (pieceId, start) = bestPiece[pos]
|
||||
ids.append(pieceId)
|
||||
pos = start
|
||||
}
|
||||
|
||||
ids.reverse()
|
||||
return ids
|
||||
}
|
||||
|
||||
/// Fallback: encode each character as a separate token.
|
||||
private func fallbackEncode(_ scalars: [Unicode.Scalar]) -> [Int] {
|
||||
var ids: [Int] = []
|
||||
for scalar in scalars {
|
||||
let char = String(scalar)
|
||||
if let id = pieceToId[char] {
|
||||
ids.append(id)
|
||||
}
|
||||
// Unknown characters are silently dropped
|
||||
}
|
||||
return ids
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
import Foundation
|
||||
|
||||
/// Available TTS synthesis backends.
|
||||
public enum TtsBackend: Sendable {
|
||||
/// Kokoro 82M — phoneme-based, multi-voice, chunk-oriented synthesis.
|
||||
case kokoro
|
||||
/// PocketTTS — flow-matching language model, autoregressive streaming synthesis.
|
||||
case pocketTts
|
||||
}
|
||||
@@ -13,8 +13,9 @@ final class FrameworkLinkTests: XCTestCase {
|
||||
// Simply importing and using FluidAudio tests that ESpeakNG is properly linked
|
||||
// If the framework structure is broken (binary name wrong, symlinks broken, etc),
|
||||
// this would fail with dyld errors
|
||||
let manager = TtSManager()
|
||||
XCTAssertNotNil(manager, "TtSManager should be instantiable, meaning ESpeakNG framework is properly linked")
|
||||
let manager = KokoroTtsManager()
|
||||
XCTAssertNotNil(
|
||||
manager, "KokoroTtsManager should be instantiable, meaning ESpeakNG framework is properly linked")
|
||||
}
|
||||
|
||||
/// Test that the binary can actually be found by dyld at runtime.
|
||||
@@ -23,7 +24,7 @@ final class FrameworkLinkTests: XCTestCase {
|
||||
/// - Symlink chain is broken
|
||||
/// - Framework not embedded correctly
|
||||
func testFrameworkBinaryResolution() async throws {
|
||||
let manager = TtSManager()
|
||||
let manager = KokoroTtsManager()
|
||||
|
||||
XCTExpectFailure("Framework usage may fail in test environment without models", strict: false)
|
||||
|
||||
@@ -45,7 +46,7 @@ final class FrameworkLinkTests: XCTestCase {
|
||||
/// Test that TTS functionality is accessible (requires ESpeakNG framework).
|
||||
/// This ensures the framework is not just linked but properly functional.
|
||||
func testTTSFrameworkFunctionality() async throws {
|
||||
let manager = TtSManager()
|
||||
let manager = KokoroTtsManager()
|
||||
|
||||
XCTExpectFailure("TTS may fail in CI without models", strict: false)
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import XCTest
|
||||
import FluidAudioTTS
|
||||
@testable import FluidAudio
|
||||
|
||||
final class TtSManagerTests: XCTestCase {
|
||||
final class KokoroTtsManagerTests: XCTestCase {
|
||||
|
||||
var manager: TtSManager!
|
||||
var manager: KokoroTtsManager!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
manager = TtSManager()
|
||||
manager = KokoroTtsManager()
|
||||
}
|
||||
|
||||
override func tearDown() {
|
||||
@@ -133,7 +133,7 @@ final class TtSManagerTests: XCTestCase {
|
||||
"Third text",
|
||||
]
|
||||
|
||||
// Swift 6: TtSManager is not Sendable, so use sequential synthesis
|
||||
// Swift 6: KokoroTtsManager is not Sendable, so use sequential synthesis
|
||||
var results: [Data] = []
|
||||
for text in texts {
|
||||
let result = try await manager.synthesize(text: text)
|
||||
|
||||
Reference in New Issue
Block a user