mirror of
https://github.com/cmusphinx/sphinxtrain.git
synced 2026-05-17 13:10:52 +00:00
203 lines
5.7 KiB
Python
203 lines
5.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
|
|
|
|
import numpy
|
|
import sys
|
|
import os
|
|
import struct
|
|
from cmusphinx import s3mixw, sendump
|
|
import sphinxbase
|
|
|
|
|
|
def mixw_kmeans_iter(lmw, cb):
    """Run one k-means iteration over scalar log mixture weights.

    Every value in `lmw` is assigned to the codeword in `cb` with the
    smallest squared distance.  Each codeword that received at least one
    value is then replaced, in place, by the mean of its values.

    A codeword that attracts no values keeps its previous position.
    Dividing its zero accumulator by its zero count would store NaN in
    the codebook, and that NaN would then take part in every later
    distance computation.

    :param lmw: iterable of scalar log mixture weights.
    :param cb: 1-D numpy array of codewords; updated in place.
    :return: sum of the squared distances between each value and the
        codeword it was assigned to (measured before the update).
    """
    cbacc = numpy.zeros(len(cb))
    cbcnt = numpy.zeros(len(cb))
    tdist = 0
    for m in lmw:
        dist = (cb - m)
        dist *= dist
        cw = dist.argmin()
        tdist += dist[cw]
        cbacc[cw] += m
        cbcnt[cw] += 1
    # Only move codewords that actually own samples.
    occupied = cbcnt > 0
    cb[occupied] = cbacc[occupied] / cbcnt[occupied]
    return tdist
|
|
|
|
|
|
def map_mixw_cb(mixw, cb, zero=0.0):
    """Map every mixture weight to the index of its nearest codeword.

    Distances are measured between the natural log of the weight and the
    codewords in `cb`.  Weights that are less than or equal to `zero`
    are not compared at all; they receive the extra index `len(cb)`.

    :param mixw: 3-D numpy array of mixture weights.
    :param cb: 1-D numpy array of codewords in the log domain.
    :param zero: weights at or below this value get index `len(cb)`.
    :return: uint8 array with the same shape as `mixw`.
    """
    n_sen, n_feat, n_gau = mixw.shape
    log_weights = numpy.log(mixw)
    floor_index = len(cb)
    codes = numpy.zeros(mixw.shape, 'uint8')
    # ndindex walks the indices in the same order as three nested loops.
    for idx in numpy.ndindex(n_sen, n_feat, n_gau):
        if mixw[idx] <= zero:
            codes[idx] = floor_index
            continue
        delta = cb - log_weights[idx]
        codes[idx] = (delta * delta).argmin()
    return codes
|
|
|
|
|
|
def mixw_freq(mixwmap):
    """Count how often each codebook index occurs in a weight map.

    The histogram length is derived by `numpy.bincount` from the values
    themselves rather than from `mixwmap.max() + 1`.  Under NumPy 2
    promotion rules that sum stays in the map's own dtype, so a uint8
    map containing index 255 would wrap around to a length of 0.

    :param mixwmap: array of non-negative integer codebook indices, of
        any shape.
    :return: 1-D array of C ints of length `mixwmap.max() + 1` whose
        entry `i` is the number of occurrences of index `i`; an empty
        array when `mixwmap` has no elements.
    """
    flat = numpy.asarray(mixwmap).ravel()
    if flat.size == 0:
        return numpy.zeros(0, 'i')
    return numpy.bincount(flat).astype('i')
|
|
|
|
|
|
# If the optional `qmwx` module can be imported, its versions of these
# three helpers rebind the pure-Python definitions given earlier in this
# file.  NOTE(review): presumably a compiled accelerator -- confirm.
try:
    from qmwx import mixw_kmeans_iter, map_mixw_cb, mixw_freq
except ImportError:
    # qmwx is unavailable: keep using the definitions from this file.
    pass
|
|
|
|
|
|
def quantize_mixw_kmeans(mixw, k, zero=0.0, max_iter=10, tol=0.01):
    """Train a scalar k-means codebook on log mixture weights.

    Only weights strictly greater than `zero` take part.  Codewords are
    initialised uniformly at random (via `numpy.random`) inside the
    observed range of log weights and refined with `mixw_kmeans_iter`
    until the relative drop in total distortion falls below `tol` or
    `max_iter` iterations have run.  Progress is printed to stdout.

    :param mixw: numpy array of mixture weights (any shape).
    :param k: number of codewords to train.
    :param zero: weights at or below this value are ignored.
    :param max_iter: upper bound on the number of k-means iterations.
    :param tol: stop once (previous - current) / previous distortion is
        smaller than this value.
    :return: 1-D numpy array of `k` codewords in the log domain.
    :raises ValueError: if no weight in `mixw` exceeds `zero`.
    """
    mw = mixw.ravel()
    lmw = numpy.log(mw[mw > zero])
    if lmw.size == 0:
        raise ValueError(
            "no mixture weights greater than %g to quantize" % zero)
    mmw = lmw.min()
    rmw = lmw.max() - mmw
    print("min log mixw: %f range: %f" % (mmw, rmw))
    cb = numpy.random.random(k) * rmw + mmw
    pdist = 1e+50
    for _ in range(max_iter):
        tdist = mixw_kmeans_iter(lmw, cb)
        # A previous distortion of zero cannot improve any further; treat
        # it as converged instead of evaluating 0 / 0.
        conv = (pdist - tdist) / pdist if pdist > 0 else 0.0
        print("Total distortion: %e convergence ratio: %e" % (tdist, conv))
        if conv < tol:
            print("Training has converged, stopping")
            break
        pdist = tdist
    return cb
|
|
|
|
|
|
def hb_encode(mixw):
    """Pack a sequence of small integer codes two to a value.

    Consecutive pairs are combined with the first code in the low four
    bits and the second shifted into the bits above them.  When the
    sequence has odd length, the final code is appended unchanged.
    Codes are not masked before being combined.

    :param mixw: indexable sequence of integer codes.
    :return: list of packed values.
    """
    packed = [(mixw[pos + 1] << 4) | mixw[pos]
              for pos in range(0, len(mixw) - 1, 2)]
    if len(mixw) % 2 == 1:
        packed.append(mixw[-1])
    return packed
|
|
|
|
|
|
# Header template for sendump files with 4-bit clustered mixture weights.
# The four %d fields are filled, in order, with the feature count, the
# mixture (Gaussian) count, the model (senone) count and the cluster
# count.  The writers in this file split the text on newlines and store
# each line as a length-prefixed, NUL-terminated string, so every
# character here ends up in the output file.
fmtdesc3 = \
"""BEGIN FILE FORMAT DESCRIPTION
(int32) <length(string)> (including trailing 0)
<string> (including trailing 0)
... preceding 2 items repeated any number of times
(int32) 0 (length(string)=0 terminates the header)
cluster_count centroids
cluster index array (feature_count x mixture_count x model_count)
... preceding 2 items repeated codebook_count times
END FILE FORMAT DESCRIPTION
feature_count %d
codebook_count 1
mixture_count %d
model_count %d
cluster_count %d
cluster_bits 4
logbase 1.0001
mixw_shift 10"""
|
|
|
|
|
|
def write_sendump_hb(mixwmap, cb, outfile):
    """Write a sendump file with 4-bit ("half-byte") mixture weights.

    File layout, in order:

    * the `fmtdesc3` header, one record per line: a big-endian uint32
      length (counting the trailing NUL), the ASCII text, and a NUL;
    * a uint32 zero that terminates the header;
    * `len(cb) + 1` codebook bytes: each codeword converted to negated
      log-base-1.0001 units and shifted right by 10 bits, followed by
      the value 159 for the extra "zero" index;
    * for every (feature, Gaussian) pair, the column of codebook
      indices over all senones, packed two per byte by `hb_encode`.

    The header is encoded to bytes before being written, because the
    file is opened in binary mode and Python 3 rejects `str` there.

    :param mixwmap: 3-D array of codebook indices, indexed as
        (senone, feature, Gaussian).
    :param cb: 1-D numpy array of codewords in the natural-log domain.
    :param outfile: path of the file to create.
    """
    n_sen, n_feat, n_gau = mixwmap.shape
    header = fmtdesc3 % (n_feat, n_gau, n_sen, len(cb))
    # Add one extra index to the end to hold the "zero" value
    qcb = numpy.resize(-(cb / numpy.log(1.0001)).astype('i') >> 10,
                       len(cb) + 1)
    qcb[-1] = 159
    with open(outfile, "wb") as fh:
        # Write the header
        for line in header.split('\n'):
            record = line.encode('ascii') + b'\0'
            fh.write(struct.pack('>I', len(record)))
            fh.write(record)
        fh.write(struct.pack('>I', 0))
        qcb.astype('uint8').tofile(fh)
        for f in range(n_feat):
            for g in range(n_gau):
                mm = numpy.array(hb_encode(mixwmap[:, f, g]), 'uint8')
                mm.tofile(fh)
|
|
|
|
|
|
# Header template describing the Huffman-coded sendump layout.  The four
# %d fields take, in order, the feature count, the mixture (Gaussian)
# count, the model (senone) count and the cluster count.
fmtdesc4 = \
"""BEGIN FILE FORMAT DESCRIPTION
(int32) <length(string)> (including trailing 0)
<string> (including trailing 0)
... preceding 2 items repeated any number of times
(int32) 0 (length(string)=0 terminates the header)
codebook_count <codebook_count>
mixture_count <mixture_count>
model_count <model_count>
cluster_count <cluster_count>
huffman_coded 1
logbase <logarithm_base>
mixw_shift <log_bits_downshifted>
<huffman codebook>
<compressed arrays of mixture weights>
END FILE FORMAT DESCRIPTION
feature_count %d
codebook_count 1
mixture_count %d
model_count %d
cluster_count %d
huffman_coded 1
logbase 1.0001
mixw_shift 10"""
|
|
|
|
|
|
def write_sendump_huff(mixwmap, cb, outfile):
    """Write a sendump file with Huffman-coded mixture weights.

    The header comes from `fmtdesc4`, the template that declares
    `huffman_coded 1`, so the file describes itself as Huffman-coded
    rather than as a 4-bit clustered file.  Header lines are encoded to
    bytes because the file is opened in binary mode.

    After the header, the Huffman codebook built from the quantized
    codewords and their frequencies in `mixwmap` is written.  Then, for
    every (feature, Gaussian) pair, the quantized weights (not their
    codebook indices) of all senones are Huffman-coded to the file.

    :param mixwmap: 3-D array of codebook indices, indexed as
        (senone, feature, Gaussian).
    :param cb: 1-D numpy array of codewords in the natural-log domain.
    :param outfile: path of the file to create.
    """
    n_sen, n_feat, n_gau = mixwmap.shape
    header = fmtdesc4 % (n_feat, n_gau, n_sen, len(cb))
    # If there's an extra "floor" value then add it to the codebook
    has_floor = mixwmap.max() == len(cb)
    qcb = numpy.resize(-(cb / numpy.log(1.0001)).astype('i') >> 10,
                       len(cb) + 1 if has_floor else len(cb))
    if has_floor:
        qcb[-1] = 159
    # Histogram the mixture weight map and build a Huffman codebook
    hist = mixw_freq(mixwmap)
    with open(outfile, "wb") as fh:
        # Write the header
        for line in header.split('\n'):
            record = line.encode('ascii') + b'\0'
            fh.write(struct.pack('>I', len(record)))
            fh.write(record)
        # Terminate it with a null entry
        fh.write(struct.pack('>I', 0))
        # Write the codebook (we code directly to quantized mixw values)
        huff = sphinxbase.HuffCode(list(zip(qcb, hist)))
        huff.write(fh)
        # Now Huffman code the mixture weights (not their codebook
        # indices mind you!) to the output file.
        huff.attach(fh, "wb")
        for f in range(n_feat):
            for g in range(n_gau):
                syms = [qcb[x] for x in mixwmap[:, f, g]]
                huff.encode_to_file(syms)
        huff.detach()
|
|
|
|
|
|
def norm_floor_mixw(mixw, floor=1e-7):
    """Normalize mixture weights along the last axis and clip them.

    Each vector along the last axis of `mixw` is divided by its sum, and
    the result is limited to the interval [`floor`, 1.0].

    :param mixw: numpy array of unnormalized mixture weights.
    :param floor: smallest value allowed in the result.
    :return: new array with the same shape as `mixw`.
    """
    flipped = mixw.T
    totals = flipped.sum(0)
    normalized = (flipped / totals).T
    return normalized.clip(floor, 1.0)
|
|
|
|
|
|
if __name__ == '__main__':
    # Usage: <script> INPUT OUTPUT
    # INPUT is read as a sendump file when its base name starts with
    # "sendump"; otherwise it is read with s3mixw, normalized and
    # floored.  OUTPUT receives the 4-bit quantized sendump file.
    if len(sys.argv) != 3:
        sys.exit("Usage: %s INPUT_MIXW_OR_SENDUMP OUTPUT_SENDUMP"
                 % sys.argv[0])
    ifn, ofn = sys.argv[1:]
    mixw_floor = 1e-7  # shared by normalization, training and mapping
    n_codewords = 15  # one more index is reserved for floored weights
    if os.path.basename(ifn).startswith('sendump'):
        mixw = sendump.Sendump(ifn).mixw()
    else:
        mixw = norm_floor_mixw(s3mixw.open(ifn).getall(), mixw_floor)
    cb = quantize_mixw_kmeans(mixw, n_codewords, mixw_floor)
    mwmap = map_mixw_cb(mixw, cb, mixw_floor)
    write_sendump_hb(mwmap, cb, ofn)
|