mirror of
https://github.com/cmusphinx/sphinxtrain.git
synced 2026-05-17 13:10:52 +00:00
47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
#!/usr/bin/env python
|
|
|
|
import sys
|
|
import os
|
|
from cmusphinx import lattice
|
|
|
|
def reference_words(line):
    """Turn one reference transcript line into a '<s>' ... '</s>' word list.

    The line is split on whitespace and its final token is dropped
    (presumably the trailing "(utterance-id)" tag -- confirm against the
    transcript format in use).  '<s>' is then inserted at the front and
    '</s>' appended at the end when they are not already there.

    A blank line yields ['<s>', '</s>'] instead of raising IndexError.
    """
    words = line.split()
    if words:
        del words[-1]
    if not words or words[0] != '<s>':
        words.insert(0, '<s>')
    if words[-1] != '</s>':
        words.append('</s>')
    return words


def percent_error(errors, words):
    """Return errors / words as a percentage, or 0.0 when words is zero."""
    if not words:
        return 0.0
    return float(errors) / words * 100


def load_lattice(latdir, uttid):
    """Load the lattice for uttid from latdir into a new lattice.Dag.

    Tries "<uttid>.lat.gz" and then "<uttid>.lat" with sphinx2dag(); if
    both raise IOError, falls back to "<uttid>.slf" with htk2dag().  An
    error from that last attempt propagates to the caller.
    """
    dag = lattice.Dag()
    try:
        dag.sphinx2dag(os.path.join(latdir, uttid + ".lat.gz"))
    except IOError:
        try:
            dag.sphinx2dag(os.path.join(latdir, uttid + ".lat"))
        except IOError:
            dag.htk2dag(os.path.join(latdir, uttid + ".slf"))
    return dag


def main(argv=None):
    """Print the lattice error rate per utterance and in total.

    argv is [ctl, ref, latdir] followed by an optional prune value, and
    defaults to sys.argv[1:].  ctl and ref are read in lockstep, one
    line per utterance: the stripped ctl line names the lattice file and
    the ref line is the reference transcript.  Arguments past the fourth
    are ignored.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 3:
        sys.exit("Usage: %s CTL REF LATDIR [PRUNE]" % sys.argv[0])
    ctl_path, ref_path, latdir = argv[:3]
    prune = float(argv[3]) if len(argv) > 3 else 0

    wordcount = 0
    errcount = 0
    with open(ctl_path) as ctl, open(ref_path) as ref:
        for ctl_line, ref_line in zip(ctl, ref):
            uttid = ctl_line.strip()
            words = reference_words(ref_line)
            # The word count excludes the two sentence markers and is taken
            # before fillers are removed from the reference.
            nw = len(words) - 2
            words = [w for w in words if not lattice.is_filler(w)]

            dag = load_lattice(latdir, uttid)
            if prune:
                dag.posterior_prune(-prune)
            # NOTE(review): bt is used below as a sequence of pairs of
            # strings, presumably the aligned (reference, hypothesis)
            # words -- confirm in cmusphinx.lattice.
            err, bt = dag.minimum_error(words)

            # Right-justify both rows to a shared per-column width so the
            # paired entries line up.
            maxlen = [max(len(y) for y in x) for x in bt]
            print(" ".join(["%*s" % (m, x[0]) for m, x in zip(maxlen, bt)]))
            print(" ".join(["%*s" % (m, x[1]) for m, x in zip(maxlen, bt)]))
            print("Error: %.2f%%" % percent_error(err, nw))
            print()

            wordcount += nw
            errcount += err

    # percent_error() reports 0.00% rather than dividing by zero when no
    # words were scored at all (e.g. an empty control file).
    print("TOTAL Error: %.2f%%" % percent_error(errcount, wordcount))


if __name__ == "__main__":
    main()
|