Files
sphinxtrain/python/cmusphinx/lattice_error.py
2020-11-09 22:21:54 -05:00

47 lines
1.2 KiB
Python

#!/usr/bin/env python
import sys
import os
from cmusphinx import lattice
# Lattice word-error report.
#
# Usage: lattice_error.py CTLFILE REFFILE LATDIR [PRUNE]
#
# For every utterance ID in CTLFILE (paired line-by-line with REFFILE),
# load the utterance's lattice from LATDIR, align the reference against it
# with lattice.Dag.minimum_error, print the alignment and the utterance's
# error percentage, and finally print the total error percentage.


def parse_reference(line):
    """Turn one reference transcription line into a sentence word list.

    The last whitespace-separated token of the line is always dropped
    (presumably the trailing utterance ID -- confirm against the
    transcription format). ``<s>`` and ``</s>`` are added if missing.

    Returns ``(words, nw)``, where ``nw`` is the number of tokens between
    the sentence markers. Filler tokens are still included in that count.
    """
    words = line.split()
    # A blank line has no trailing token to drop; the original script
    # crashed here with IndexError.
    if words:
        del words[-1]
    if not words or words[0] != '<s>':
        words.insert(0, '<s>')
    if words[-1] != '</s>':
        words.append('</s>')
    return words, len(words) - 2


def error_rate_string(errors, words):
    """Format ``errors / words`` as a percentage string like ``"12.50%"``.

    Returns ``"0.00%"`` when ``words`` is 0, instead of dividing by zero.
    """
    rate = float(errors) / words * 100 if words else 0
    return "%.2f%%" % rate


def load_lattice(latdir, uttid):
    """Load the lattice for ``uttid`` from directory ``latdir``.

    Tries, in order, ``<uttid>.lat.gz`` and ``<uttid>.lat`` (read with
    ``sphinx2dag``), then ``<uttid>.slf`` (read with ``htk2dag``). An
    IOError from the last attempt propagates to the caller.
    """
    dag = lattice.Dag()
    try:
        dag.sphinx2dag(os.path.join(latdir, uttid + ".lat.gz"))
    except IOError:
        try:
            dag.sphinx2dag(os.path.join(latdir, uttid + ".lat"))
        except IOError:
            dag.htk2dag(os.path.join(latdir, uttid + ".slf"))
    return dag


def print_alignment(bt):
    """Print the first and second element of every backtrace entry as two
    rows.

    Each column is right-aligned to the width of its longest element.
    """
    widths = [max(len(item) for item in entry) for entry in bt]
    print(" ".join("%*s" % (w, entry[0]) for w, entry in zip(widths, bt)))
    print(" ".join("%*s" % (w, entry[1]) for w, entry in zip(widths, bt)))


def main(argv=None):
    """Run the lattice error report; return the process exit status.

    ``argv`` defaults to ``sys.argv``. Returns 2 after printing a usage
    message if fewer than three arguments are given; otherwise returns 0.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 4:
        prog = argv[0] if argv else "lattice_error.py"
        sys.stderr.write("Usage: %s CTLFILE REFFILE LATDIR [PRUNE]\n" % prog)
        return 2
    ctl_path, ref_path, latdir = argv[1:4]
    # 0 disables pruning. The value is passed negated to posterior_prune;
    # presumably PRUNE is a positive beam -- confirm against lattice.Dag.
    prune = float(argv[4]) if len(argv) > 4 else 0

    wordcount = 0
    errcount = 0
    with open(ctl_path) as ctl_file, open(ref_path) as ref_file:
        # zip stops at the shorter file; extra lines in either are ignored.
        for ctl_line, ref_line in zip(ctl_file, ref_file):
            uttid = ctl_line.strip()
            ref_words, nw = parse_reference(ref_line)
            # NOTE(review): nw was counted before fillers are removed here,
            # so fillers inflate the denominator -- confirm this is intended.
            ref_words = [w for w in ref_words if not lattice.is_filler(w)]

            dag = load_lattice(latdir, uttid)
            if prune:
                dag.posterior_prune(-prune)
            err, bt = dag.minimum_error(ref_words)

            print_alignment(bt)
            print("Error: " + error_rate_string(err, nw))
            print()
            wordcount += nw
            errcount += err

    # Guarded like the per-utterance line: no words -> 0.00%, not a crash.
    print("TOTAL Error: " + error_rate_string(errcount, wordcount))
    return 0


if __name__ == "__main__":
    sys.exit(main())