Skip to content

Commit 1cd7ee9

Browse files
authored
[src] Incremental determinization [cleaned up/rewrite] (#3737)
1 parent 413c7c8 commit 1cd7ee9

23 files changed

+3747
-48
lines changed

src/bin/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \
2222
matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \
2323
vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \
2424
transform-vec align-text matrix-dim post-to-smat compile-graph \
25-
compare-int-vector compute-gop
25+
compare-int-vector latgen-incremental-mapped compute-gop
2626

2727

2828
OBJFILES =

src/bin/latgen-incremental-mapped.cc

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// bin/latgen-incremental-mapped.cc
2+
3+
// Copyright 2019 Zhehuai Chen
4+
5+
// See ../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABILITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "base/kaldi-common.h"
21+
#include "util/common-utils.h"
22+
#include "tree/context-dep.h"
23+
#include "hmm/transition-model.h"
24+
#include "fstext/fstext-lib.h"
25+
#include "decoder/decoder-wrappers.h"
26+
#include "decoder/decodable-matrix.h"
27+
#include "base/timer.h"
28+
29+
int main(int argc, char *argv[]) {
30+
try {
31+
using namespace kaldi;
32+
typedef kaldi::int32 int32;
33+
using fst::SymbolTable;
34+
using fst::Fst;
35+
using fst::StdArc;
36+
37+
const char *usage =
38+
"Generate lattices, reading log-likelihoods as matrices\n"
39+
" (model is needed only for the integer mappings in its transition-model)\n"
40+
"The lattice determinization algorithm here can operate\n"
41+
"incrementally.\n"
42+
"Usage: latgen-incremental-mapped [options] trans-model-in "
43+
"(fst-in|fsts-rspecifier) loglikes-rspecifier"
44+
" lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
45+
ParseOptions po(usage);
46+
Timer timer;
47+
bool allow_partial = false;
48+
BaseFloat acoustic_scale = 0.1;
49+
LatticeIncrementalDecoderConfig config;
50+
51+
std::string word_syms_filename;
52+
config.Register(&po);
53+
po.Register("acoustic-scale", &acoustic_scale,
54+
"Scaling factor for acoustic likelihoods");
55+
56+
po.Register("word-symbol-table", &word_syms_filename,
57+
"Symbol table for words [for debug output]");
58+
po.Register("allow-partial", &allow_partial,
59+
"If true, produce output even if end state was not reached.");
60+
61+
po.Read(argc, argv);
62+
63+
if (po.NumArgs() < 4 || po.NumArgs() > 6) {
64+
po.PrintUsage();
65+
exit(1);
66+
}
67+
68+
std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2),
69+
feature_rspecifier = po.GetArg(3), lattice_wspecifier = po.GetArg(4),
70+
words_wspecifier = po.GetOptArg(5),
71+
alignment_wspecifier = po.GetOptArg(6);
72+
73+
TransitionModel trans_model;
74+
ReadKaldiObject(model_in_filename, &trans_model);
75+
76+
bool determinize = true;
77+
CompactLatticeWriter compact_lattice_writer;
78+
LatticeWriter lattice_writer;
79+
if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier)
80+
: lattice_writer.Open(lattice_wspecifier)))
81+
KALDI_ERR << "Could not open table for writing lattices: "
82+
<< lattice_wspecifier;
83+
84+
Int32VectorWriter words_writer(words_wspecifier);
85+
86+
Int32VectorWriter alignment_writer(alignment_wspecifier);
87+
88+
fst::SymbolTable *word_syms = NULL;
89+
if (word_syms_filename != "")
90+
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
91+
KALDI_ERR << "Could not read symbol table from file " << word_syms_filename;
92+
93+
double tot_like = 0.0;
94+
kaldi::int64 frame_count = 0;
95+
int num_success = 0, num_fail = 0;
96+
97+
if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) {
98+
SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier);
99+
// Input FST is just one FST, not a table of FSTs.
100+
Fst<StdArc> *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str);
101+
timer.Reset();
102+
103+
{
104+
LatticeIncrementalDecoder decoder(*decode_fst, trans_model, config);
105+
106+
for (; !loglike_reader.Done(); loglike_reader.Next()) {
107+
std::string utt = loglike_reader.Key();
108+
Matrix<BaseFloat> loglikes(loglike_reader.Value());
109+
loglike_reader.FreeCurrent();
110+
if (loglikes.NumRows() == 0) {
111+
KALDI_WARN << "Zero-length utterance: " << utt;
112+
num_fail++;
113+
continue;
114+
}
115+
116+
DecodableMatrixScaledMapped decodable(trans_model, loglikes,
117+
acoustic_scale);
118+
119+
double like;
120+
if (DecodeUtteranceLatticeIncremental(
121+
decoder, decodable, trans_model, word_syms, utt, acoustic_scale,
122+
determinize, allow_partial, &alignment_writer, &words_writer,
123+
&compact_lattice_writer, &lattice_writer, &like)) {
124+
tot_like += like;
125+
frame_count += loglikes.NumRows();
126+
num_success++;
127+
} else {
128+
num_fail++;
129+
}
130+
}
131+
}
132+
delete decode_fst; // delete this only after decoder goes out of scope.
133+
} else { // We have different FSTs for different utterances.
134+
SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_in_str);
135+
RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier);
136+
for (; !fst_reader.Done(); fst_reader.Next()) {
137+
std::string utt = fst_reader.Key();
138+
if (!loglike_reader.HasKey(utt)) {
139+
KALDI_WARN << "Not decoding utterance " << utt
140+
<< " because no loglikes available.";
141+
num_fail++;
142+
continue;
143+
}
144+
const Matrix<BaseFloat> &loglikes = loglike_reader.Value(utt);
145+
if (loglikes.NumRows() == 0) {
146+
KALDI_WARN << "Zero-length utterance: " << utt;
147+
num_fail++;
148+
continue;
149+
}
150+
LatticeIncrementalDecoder decoder(fst_reader.Value(), trans_model, config);
151+
DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale);
152+
double like;
153+
if (DecodeUtteranceLatticeIncremental(
154+
decoder, decodable, trans_model, word_syms, utt, acoustic_scale,
155+
determinize, allow_partial, &alignment_writer, &words_writer,
156+
&compact_lattice_writer, &lattice_writer, &like)) {
157+
tot_like += like;
158+
frame_count += loglikes.NumRows();
159+
num_success++;
160+
} else {
161+
num_fail++;
162+
}
163+
}
164+
}
165+
166+
double elapsed = timer.Elapsed();
167+
KALDI_LOG << "Time taken " << elapsed
168+
<< "s: real-time factor assuming 100 frames/sec is "
169+
<< (elapsed * 100.0 / frame_count);
170+
KALDI_LOG << "Done " << num_success << " utterances, failed for " << num_fail;
171+
KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count)
172+
<< " over " << frame_count << " frames.";
173+
174+
delete word_syms;
175+
if (num_success != 0)
176+
return 0;
177+
else
178+
return 1;
179+
} catch (const std::exception &e) {
180+
std::cerr << e.what();
181+
return -1;
182+
}
183+
}

src/decoder/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ TESTFILES =
77

88
OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \
99
lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \
10-
decoder-wrappers.o grammar-fst.o decodable-matrix.o
10+
decoder-wrappers.o grammar-fst.o decodable-matrix.o \
11+
lattice-incremental-decoder.o lattice-incremental-online-decoder.o
1112

1213
LIBNAME = kaldi-decoder
1314

src/decoder/decoder-wrappers.cc

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void DecodeUtteranceLatticeFasterClass::operator () () {
6868
success_ = true;
6969
using fst::VectorFst;
7070
if (!decoder_->Decode(decodable_)) {
71-
KALDI_WARN << "Failed to decode file " << utt_;
71+
KALDI_WARN << "Failed to decode utterance with id " << utt_;
7272
success_ = false;
7373
}
7474
if (!decoder_->ReachedFinal()) {
@@ -195,6 +195,92 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() {
195195
delete decodable_;
196196
}
197197

198+
template <typename FST>
199+
bool DecodeUtteranceLatticeIncremental(
200+
LatticeIncrementalDecoderTpl<FST> &decoder, // not const but is really an input.
201+
DecodableInterface &decodable, // not const but is really an input.
202+
const TransitionModel &trans_model,
203+
const fst::SymbolTable *word_syms,
204+
std::string utt,
205+
double acoustic_scale,
206+
bool determinize,
207+
bool allow_partial,
208+
Int32VectorWriter *alignment_writer,
209+
Int32VectorWriter *words_writer,
210+
CompactLatticeWriter *compact_lattice_writer,
211+
LatticeWriter *lattice_writer,
212+
double *like_ptr) { // puts utterance's like in like_ptr on success.
213+
using fst::VectorFst;
214+
if (!decoder.Decode(&decodable)) {
215+
KALDI_WARN << "Failed to decode utterance with id " << utt;
216+
return false;
217+
}
218+
if (!decoder.ReachedFinal()) {
219+
if (allow_partial) {
220+
KALDI_WARN << "Outputting partial output for utterance " << utt
221+
<< " since no final-state reached\n";
222+
} else {
223+
KALDI_WARN << "Not producing output for utterance " << utt
224+
<< " since no final-state reached and "
225+
<< "--allow-partial=false.\n";
226+
return false;
227+
}
228+
}
229+
230+
// Get lattice
231+
CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), true);
232+
if (clat.NumStates() == 0)
233+
KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
234+
235+
double likelihood;
236+
LatticeWeight weight;
237+
int32 num_frames;
238+
{ // First do some stuff with word-level traceback...
239+
CompactLattice decoded_clat;
240+
CompactLatticeShortestPath(clat, &decoded_clat);
241+
Lattice decoded;
242+
fst::ConvertLattice(decoded_clat, &decoded);
243+
244+
if (decoded.Start() == fst::kNoStateId)
245+
// Shouldn't really reach this point as already checked success.
246+
KALDI_ERR << "Failed to get traceback for utterance " << utt;
247+
248+
std::vector<int32> alignment;
249+
std::vector<int32> words;
250+
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
251+
num_frames = alignment.size();
252+
KALDI_ASSERT(num_frames == decoder.NumFramesDecoded());
253+
if (words_writer->IsOpen())
254+
words_writer->Write(utt, words);
255+
if (alignment_writer->IsOpen())
256+
alignment_writer->Write(utt, alignment);
257+
if (word_syms != NULL) {
258+
std::cerr << utt << ' ';
259+
for (size_t i = 0; i < words.size(); i++) {
260+
std::string s = word_syms->Find(words[i]);
261+
if (s == "")
262+
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
263+
std::cerr << s << ' ';
264+
}
265+
std::cerr << '\n';
266+
}
267+
likelihood = -(weight.Value1() + weight.Value2());
268+
}
269+
270+
// We'll write the lattice without acoustic scaling.
271+
if (acoustic_scale != 0.0)
272+
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
273+
Connect(&clat);
274+
compact_lattice_writer->Write(utt, clat);
275+
KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
276+
<< (likelihood / num_frames) << " over "
277+
<< num_frames << " frames.";
278+
KALDI_VLOG(2) << "Cost for utterance " << utt << " is "
279+
<< weight.Value1() << " + " << weight.Value2();
280+
*like_ptr = likelihood;
281+
return true;
282+
}
283+
198284

199285
// Takes care of output. Returns true on success.
200286
template <typename FST>
@@ -215,7 +301,7 @@ bool DecodeUtteranceLatticeFaster(
215301
using fst::VectorFst;
216302

217303
if (!decoder.Decode(&decodable)) {
218-
KALDI_WARN << "Failed to decode file " << utt;
304+
KALDI_WARN << "Failed to decode utterance with id " << utt;
219305
return false;
220306
}
221307
if (!decoder.ReachedFinal()) {
@@ -296,6 +382,37 @@ bool DecodeUtteranceLatticeFaster(
296382
}
297383

298384
// Instantiate the template above for the two required FST types.
385+
// Explicit instantiation of DecodeUtteranceLatticeIncremental for decoders
// whose graph type is the generic fst::Fst<fst::StdArc>.  The template is
// defined in this .cc file, so callers can only link against the
// instantiations listed here.
template bool DecodeUtteranceLatticeIncremental(
386+
LatticeIncrementalDecoderTpl<fst::Fst<fst::StdArc> > &decoder,
387+
DecodableInterface &decodable,
388+
const TransitionModel &trans_model,
389+
const fst::SymbolTable *word_syms,
390+
std::string utt,
391+
double acoustic_scale,
392+
bool determinize,
393+
bool allow_partial,
394+
Int32VectorWriter *alignment_writer,
395+
Int32VectorWriter *words_writer,
396+
CompactLatticeWriter *compact_lattice_writer,
397+
LatticeWriter *lattice_writer,
398+
double *like_ptr);
399+
400+
// Explicit instantiation of DecodeUtteranceLatticeIncremental for decoders
// whose graph type is fst::GrammarFst.
template bool DecodeUtteranceLatticeIncremental(
401+
LatticeIncrementalDecoderTpl<fst::GrammarFst> &decoder,
402+
DecodableInterface &decodable,
403+
const TransitionModel &trans_model,
404+
const fst::SymbolTable *word_syms,
405+
std::string utt,
406+
double acoustic_scale,
407+
bool determinize,
408+
bool allow_partial,
409+
Int32VectorWriter *alignment_writer,
410+
Int32VectorWriter *words_writer,
411+
CompactLatticeWriter *compact_lattice_writer,
412+
LatticeWriter *lattice_writer,
413+
double *like_ptr);
414+
415+
299416
template bool DecodeUtteranceLatticeFaster(
300417
LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> > &decoder,
301418
DecodableInterface &decodable,
@@ -345,7 +462,7 @@ bool DecodeUtteranceLatticeSimple(
345462
using fst::VectorFst;
346463

347464
if (!decoder.Decode(&decodable)) {
348-
KALDI_WARN << "Failed to decode file " << utt;
465+
KALDI_WARN << "Failed to decode utterance with id " << utt;
349466
return false;
350467
}
351468
if (!decoder.ReachedFinal()) {

src/decoder/decoder-wrappers.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "itf/options-itf.h"
2424
#include "decoder/lattice-faster-decoder.h"
25+
#include "decoder/lattice-incremental-decoder.h"
2526
#include "decoder/lattice-simple-decoder.h"
2627

2728
// This header contains declarations from various convenience functions that are called
@@ -88,6 +89,23 @@ void AlignUtteranceWrapper(
8889
void ModifyGraphForCarefulAlignment(
8990
fst::VectorFst<fst::StdArc> *fst);
9091

92+
/// Decodes one utterance with a LatticeIncrementalDecoderTpl decoder and takes
/// care of the output.  Returns true on success.
///
/// As implemented in decoder-wrappers.cc, the function:
///  - calls decoder.Decode(&decodable) and returns false, after a warning, if
///    that fails;
///  - returns false if no final state was reached and allow_partial is false
///    (with allow_partial true it warns and carries on);
///  - obtains the compact lattice for all decoded frames from
///    decoder.GetLattice() and takes its best path;
///  - writes the best path's words and alignment to words_writer and
///    alignments_writer, each only if that writer is open;
///  - if word_syms is non-NULL, prints the utterance id and the best-path
///    words to std::cerr;
///  - if acoustic_scale is nonzero, scales the lattice's acoustic scores by
///    1.0 / acoustic_scale, then writes the lattice to compact_lattice_writer;
///  - stores the negated total cost of the best path in *like_ptr.
///
/// NOTE(review): trans_model, determinize and lattice_writer are not referenced
/// by the implementation; presumably they are kept so that the signature
/// parallels DecodeUtteranceLatticeFaster -- confirm.
93+
template <typename FST>
94+
bool DecodeUtteranceLatticeIncremental(
95+
LatticeIncrementalDecoderTpl<FST> &decoder, // not const but is really an input.
96+
DecodableInterface &decodable, // not const but is really an input.
97+
const TransitionModel &trans_model,
98+
const fst::SymbolTable *word_syms,
99+
std::string utt,
100+
double acoustic_scale,
101+
bool determinize,
102+
bool allow_partial,
103+
Int32VectorWriter *alignments_writer,
104+
Int32VectorWriter *words_writer,
105+
CompactLatticeWriter *compact_lattice_writer,
106+
LatticeWriter *lattice_writer,
107+
double *like_ptr); // puts utterance's likelihood in like_ptr on success.
108+
91109

92110
/// This function DecodeUtteranceLatticeFaster is used in several decoders, and
93111
/// we have moved it here. Note: this is really "binary-level" code as it

0 commit comments

Comments
 (0)