diff --git a/data/data_features/tgt-train-with-feats.txt b/data/data_features/tgt-train-with-feats.txt
new file mode 100644
index 0000000000..42cd4995a4
--- /dev/null
+++ b/data/data_features/tgt-train-with-feats.txt
@@ -0,0 +1,3 @@
+however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
+however,│A according│B to│C the│D logs,│E
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/data/data_features/tgt-val-with-feats.txt b/data/data_features/tgt-val-with-feats.txt
new file mode 100644
index 0000000000..4a41985c0d
--- /dev/null
+++ b/data/data_features/tgt-val-with-feats.txt
@@ -0,0 +1 @@
+she│C is│B a│A hard-working.│B
\ No newline at end of file
diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py
index 6f2611a5e8..35f0ed5be8 100644
--- a/onmt/bin/build_vocab.py
+++ b/onmt/bin/build_vocab.py
@@ -45,6 +45,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
     sub_counter_src = Counter()
     sub_counter_tgt = Counter()
     sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    sub_counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
     datasets_iterables = build_corpora_iters(
         corpora, transforms, opts.data,
         skip_empty_level=opts.skip_empty_level,
@@ -63,26 +64,36 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
             sub_counter_src.update(src_line.split(' '))
             sub_counter_tgt.update(tgt_line.split(' '))
+            src_feats_lines = []
             if 'feats' in maybe_example['src']:
                 src_feats_lines = maybe_example['src']['feats']
                 for i in range(opts.n_src_feats):
                     sub_counter_src_feats[i].update(
                         src_feats_lines[i].split(' '))
-            else:
-                src_feats_lines = []
+
+            tgt_feats_lines = []
+            if maybe_example["tgt"] is not None:
+                if 'feats' in maybe_example['tgt']:
+                    tgt_feats_lines = maybe_example['tgt']['feats']
+                    for i in range(opts.n_tgt_feats):
+                        sub_counter_tgt_feats[i].update(
+                            tgt_feats_lines[i].split(' '))
             if opts.dump_samples:
                 src_pretty_line = append_features_to_text(
                     src_line, src_feats_lines)
+                tgt_pretty_line = append_features_to_text(
+                    tgt_line, tgt_feats_lines)
                 build_sub_vocab.queues[c_name][offset].put(
-                    (i, src_pretty_line, tgt_line))
+                    (i, src_pretty_line, tgt_pretty_line))
             if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
                 if opts.dump_samples:
                     build_sub_vocab.queues[c_name][offset].put("break")
                 break
         if opts.dump_samples:
             build_sub_vocab.queues[c_name][offset].put("break")
-    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
+    return (sub_counter_src, sub_counter_tgt,
+            sub_counter_src_feats, sub_counter_tgt_feats)
 
 
 def init_pool(queues):
@@ -107,6 +118,7 @@ def build_vocab(opts, transforms, n_sample=3):
     counter_src = Counter()
     counter_tgt = Counter()
     counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
+    counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
    from functools import partial
     queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                        for i in range(opts.num_threads)]
@@ -123,15 +135,18 @@ def build_vocab(opts, transforms, n_sample=3):
     func = partial(
         build_sub_vocab, corpora, transforms,
         opts, n_sample, opts.num_threads)
-    for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
+    for (sub_counter_src, sub_counter_tgt,
+         sub_counter_src_feats, sub_counter_tgt_feats) in p.imap(
             func, range(0, opts.num_threads)):
         counter_src.update(sub_counter_src)
         counter_tgt.update(sub_counter_tgt)
         for i in range(opts.n_src_feats):
             counter_src_feats[i].update(sub_counter_src_feats[i])
+        for i in range(opts.n_tgt_feats):
+            counter_tgt_feats[i].update(sub_counter_tgt_feats[i])
     if opts.dump_samples:
         write_process.join()
-    return counter_src, counter_tgt, counter_src_feats
+    return counter_src, counter_tgt, counter_src_feats, counter_tgt_feats
 
 
 def build_vocab_main(opts):
@@ -157,13 +172,15 @@ def build_vocab_main(opts):
     transforms = make_transforms(opts, transforms_cls, None)
 
     logger.info(f"Counter vocab from {opts.n_sample} samples.")
-    src_counter, tgt_counter, src_feats_counter = build_vocab(
-        opts, transforms, n_sample=opts.n_sample)
+    src_counter, tgt_counter, src_feats_counter, tgt_feats_counter = \
+        build_vocab(opts, transforms, n_sample=opts.n_sample)
 
     logger.info(f"Counters src: {len(src_counter)}")
     logger.info(f"Counters tgt: {len(tgt_counter)}")
     for i, feat_counter in enumerate(src_feats_counter):
         logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
+    for i, feat_counter in enumerate(tgt_feats_counter):
+        logger.info(f"Counters tgt feat_{i}: {len(feat_counter)}")
 
     def save_counter(counter, save_path):
         check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -182,6 +199,8 @@ def save_counter(counter, save_path):
 
     for i, c in enumerate(src_feats_counter):
         save_counter(c, f"{opts.src_vocab}_feat{i}")
+    for i, c in enumerate(tgt_feats_counter):
+        save_counter(c, f"{opts.tgt_vocab}_feat{i}")
 
 
 def _get_parser():
diff --git a/onmt/decoders/ensemble.py b/onmt/decoders/ensemble.py
index 31ad6f2509..53b76b7dc6 100644
--- a/onmt/decoders/ensemble.py
+++ b/onmt/decoders/ensemble.py
@@ -99,14 +99,29 @@ def forward(self, hidden, attn=None, src_map=None):
         by averaging distributions from models in the ensemble.
         All models in the ensemble must share a target vocabulary.
         """
-        distributions = torch.stack(
-            [mg(h) if attn is None else mg(h, attn, src_map)
-             for h, mg in zip(hidden, self.model_generators)]
-        )
+        distributions, feats_distributions = [], []
+        n_feats = len(self.model_generators[0].feats_generators)
+        for h, mg in zip(hidden, self.model_generators):
+            scores, feats_scores = \
+                (mg(h) if attn is None else mg(h, attn, src_map))
+            distributions.append(scores)
+            feats_distributions.append(feats_scores)
+
+        distributions = torch.stack(distributions)
+
+        stacked_feats_distributions = []
+        for i in range(n_feats):
+            stacked_feats_distributions.append(
+                torch.stack([feat_distribution[i]
+                             for feat_distribution in feats_distributions]))
         if self._raw_probs:
-            return torch.log(torch.exp(distributions).mean(0))
+            return (torch.log(torch.exp(distributions).mean(0)),
+                    [torch.log(torch.exp(d).mean(0))
+                     for d in stacked_feats_distributions])
         else:
-            return distributions.mean(0)
+            return (distributions.mean(0),
+                    [d.mean(0) for d in stacked_feats_distributions])
 
 
 class EnsembleModel(NMTModel):
diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py
index 65463c713b..dc354d4900 100644
--- a/onmt/inputters/inputter.py
+++ b/onmt/inputters/inputter.py
@@ -34,10 +34,11 @@ def build_vocab(opt, specials):
     """ Build vocabs dict to be stored in the checkpoint
         based on vocab files having each line [token, count]
     Args:
-        opt: src_vocab, tgt_vocab, n_src_feats
+        opt: src_vocab, tgt_vocab, n_src_feats, n_tgt_feats
     Return:
         vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
                  'src_feats' : [pyonmttok.Vocab, ...]},
+                 'tgt_feats' : [pyonmttok.Vocab, ...]},
                  'data_task': seq2seq or lm
                  }
     """
@@ -103,6 +104,25 @@ def _pad_vocab_to_multiple(vocab, multiple):
             src_feats_vocabs.append(src_f_vocab)
         vocabs["src_feats"] = src_feats_vocabs
 
+    if opt.n_tgt_feats > 0:
+        tgt_feats_vocabs = []
+        for i in range(opt.n_tgt_feats):
+            tgt_f_vocab = _read_vocab_file(f"{opt.tgt_vocab}_feat{i}", 1)
+            tgt_f_vocab = pyonmttok.build_vocab_from_tokens(
+                tgt_f_vocab,
+                maximum_size=0,
+                minimum_frequency=1,
+                special_tokens=[DefaultTokens.UNK,
+                                DefaultTokens.PAD,
+                                DefaultTokens.BOS,
+                                DefaultTokens.EOS])
+            tgt_f_vocab.default_id = tgt_f_vocab[DefaultTokens.UNK]
+            if opt.vocab_size_multiple > 1:
+                tgt_f_vocab = _pad_vocab_to_multiple(tgt_f_vocab,
+                                                     opt.vocab_size_multiple)
+            tgt_feats_vocabs.append(tgt_f_vocab)
+        vocabs["tgt_feats"] = tgt_feats_vocabs
+
     vocabs['data_task'] = opt.data_task
 
     return vocabs
@@ -147,6 +167,9 @@ def vocabs_to_dict(vocabs):
     if 'src_feats' in vocabs.keys():
         vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens
                                     for feat_vocab in vocabs['src_feats']]
+    if 'tgt_feats' in vocabs.keys():
+        vocabs_dict['tgt_feats'] = [feat_vocab.ids_to_tokens
+                                    for feat_vocab in vocabs['tgt_feats']]
     vocabs_dict['data_task'] = vocabs['data_task']
     return vocabs_dict
 
@@ -168,4 +191,9 @@ def dict_to_vocabs(vocabs_dict):
         for feat_vocab in vocabs_dict['src_feats']:
             vocabs['src_feats'].append(
                 pyonmttok.build_vocab_from_tokens(feat_vocab))
+    if 'tgt_feats' in vocabs_dict.keys():
+        vocabs['tgt_feats'] = []
+        for feat_vocab in vocabs_dict['tgt_feats']:
+            vocabs['tgt_feats'].append(
+                pyonmttok.build_vocab_from_tokens(feat_vocab))
     return vocabs
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py
index e57261d565..5e1926d1ae 100644
--- a/onmt/inputters/text_corpus.py
+++ b/onmt/inputters/text_corpus.py
@@ -40,7 +40,8 @@ class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""
 
     def __init__(self, name, src, tgt, align=None,
-                 n_src_feats=0, src_feats_defaults=None):
+                 n_src_feats=0, src_feats_defaults=None,
+                 n_tgt_feats=0, tgt_feats_defaults=None):
         """Initialize src & tgt side file path."""
         self.id = name
         self.src = src
@@ -48,6 +49,8 @@ def __init__(self, name, src, tgt, align=None,
         self.align = align
         self.n_src_feats = n_src_feats
         self.src_feats_defaults = src_feats_defaults
+        self.n_tgt_feats = n_tgt_feats
+        self.tgt_feats_defaults = tgt_feats_defaults
 
     def load(self, offset=0, stride=1):
         """
@@ -68,6 +71,10 @@ def load(self, offset=0, stride=1):
                     defaults=self.src_feats_defaults)
                 if tline is not None:
                     tline = tline.decode('utf-8')
+                    tline, tfeats = parse_features(
+                        tline,
+                        n_feats=self.n_tgt_feats,
+                        defaults=self.tgt_feats_defaults)
                 # 'src_original' and 'tgt_original' store the
                 # original line before tokenization. These
                # fields are used later on in the feature
@@ -82,7 +89,9 @@ def load(self, offset=0, stride=1):
                     example['align'] = align.decode('utf-8')
 
                 if sfeats is not None:
-                    example['src_feats'] = [f for f in sfeats]
+                    example['src_feats'] = sfeats
+                if tline is not None and tfeats is not None:
+                    example['tgt_feats'] = tfeats
                 yield example
 
     def __str__(self):
@@ -90,7 +99,9 @@ def __str__(self):
         cls_name = type(self).__name__
         return f'{cls_name}({self.id}, {self.src}, {self.tgt}, ' \
                f'align={self.align}, ' \
                f'n_src_feats={self.n_src_feats}, ' \
-               f'src_feats_defaults="{self.src_feats_defaults}")'
+               f'src_feats_defaults="{self.src_feats_defaults}", ' \
+               f'n_tgt_feats={self.n_tgt_feats}, ' \
+               f'tgt_feats_defaults="{self.tgt_feats_defaults}")'
 
 
 def get_corpora(opts, task=CorpusTask.TRAIN):
@@ -104,7 +115,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
                 corpus_dict["path_tgt"],
                 corpus_dict["path_align"],
                 n_src_feats=opts.n_src_feats,
-                src_feats_defaults=opts.src_feats_defaults)
+                src_feats_defaults=opts.src_feats_defaults,
+                n_tgt_feats=opts.n_tgt_feats,
+                tgt_feats_defaults=opts.tgt_feats_defaults)
     elif task == CorpusTask.VALID:
         if CorpusName.VALID in opts.data.keys():
             corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -113,7 +126,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
                 opts.data[CorpusName.VALID]["path_tgt"],
                 opts.data[CorpusName.VALID]["path_align"],
                 n_src_feats=opts.n_src_feats,
-                src_feats_defaults=opts.src_feats_defaults)
+                src_feats_defaults=opts.src_feats_defaults,
+                n_tgt_feats=opts.n_tgt_feats,
+                tgt_feats_defaults=opts.tgt_feats_defaults)
         else:
             return None
     else:
@@ -122,7 +137,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
             opts.src, opts.tgt,
             n_src_feats=opts.n_src_feats,
-            src_feats_defaults=opts.src_feats_defaults)
+            src_feats_defaults=opts.src_feats_defaults,
+            n_tgt_feats=opts.n_tgt_feats,
+            tgt_feats_defaults=opts.tgt_feats_defaults)
     return corpora_dict
diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index 0ccde14af0..4c386ad579 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -11,7 +11,7 @@
 from onmt.encoders import str2enc
 from onmt.decoders import str2dec
 from onmt.inputters.inputter import dict_to_vocabs
-from onmt.modules import Embeddings, CopyGenerator
+from onmt.modules import Embeddings, Generator
 from onmt.utils.misc import use_gpu
 from onmt.utils.logging import logger
 from onmt.utils.parse import ArgumentParser
@@ -193,12 +193,11 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint):
                     emb_name
                 ][old_i]
                 if side == 'tgt':
-                    generator.state_dict()['weight'][i] = checkpoint[
-                        'generator'
-                    ]['weight'][old_i]
-                    generator.state_dict()['bias'][i] = checkpoint[
-                        'generator'
-                    ]['bias'][old_i]
+                    # TODO: check feats generators
+                    generator.state_dict()['tgt_generator.weight'][i] = \
+                        checkpoint['generator']['tgt_generator.weight'][old_i]
+                    generator.state_dict()['tgt_generator.bias'][i] = \
+                        checkpoint['generator']['tgt_generator.bias'][old_i]
             else:
                 # Just for debugging purposes
                 new_tokens.append(tok)
@@ -206,7 +205,37 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint):
         # Remove old vocabulary associated embeddings
         del checkpoint['model'][emb_name]
-    del checkpoint['generator']['weight'], checkpoint['generator']['bias']
+    del checkpoint['generator']['tgt_generator.weight']
+    del checkpoint['generator']['tgt_generator.bias']
+
+
+def build_generator(model_opt, vocabs, decoder):
+    gen_sizes = [len(vocabs['tgt'])]
+    if 'tgt_feats' in vocabs:
+        gen_sizes += [len(feat_vocab) for feat_vocab in vocabs['tgt_feats']]
+
+    if model_opt.share_decoder_embeddings:
+        hid_sizes = ([model_opt.dec_hid_size -
+                      (model_opt.feat_vec_size * (len(gen_sizes) - 1))] +
+                     [model_opt.feat_vec_size] * (len(gen_sizes) - 1))
+    else:
+        hid_sizes = [model_opt.dec_hid_size] * len(gen_sizes)
+
+    pad_idx = vocabs['tgt'][DefaultTokens.PAD]
+    generator = Generator(hid_sizes, gen_sizes,
+                          shared=model_opt.share_decoder_embeddings,
+                          copy_attn=model_opt.copy_attn,
+                          pad_idx=pad_idx)
+
+    if model_opt.share_decoder_embeddings:
+        if not model_opt.copy_attn:
+            generator.tgt_generator.weight = \
+                decoder.embeddings.word_lut.weight
+        else:
+            generator.tgt_generator.linear.weight = \
+                decoder.embeddings.word_lut.weight
+
+    return generator
 
 
 def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):
@@ -243,18 +272,9 @@ def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):
 
     model = build_task_specific_model(model_opt, vocabs)
 
-    # Build Generator.
-    if not model_opt.copy_attn:
-        generator = nn.Linear(model_opt.dec_hid_size,
-                              len(vocabs['tgt']))
-        if model_opt.share_decoder_embeddings:
-            generator.weight = model.decoder.embeddings.word_lut.weight
-    else:
-        vocab_size = len(vocabs['tgt'])
-        pad_idx = vocabs['tgt'][DefaultTokens.PAD]
-        generator = CopyGenerator(model_opt.dec_hid_size, vocab_size, pad_idx)
-        if model_opt.share_decoder_embeddings:
-            generator.linear.weight = model.decoder.embeddings.word_lut.weight
+    # Build generators:
+    # next-token prediction and, possibly, target-feature generators.
+    generator = build_generator(model_opt, vocabs, model.decoder)
 
     # Load the model states from checkpoint or initialize them.
     if checkpoint is None or model_opt.update_vocab:
diff --git a/onmt/models/model.py b/onmt/models/model.py
index 7246b2f9ee..91f6839aa6 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -66,7 +66,7 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False):
             * enc_out + enc_final_hs in the case of CNNs
             * src in the case of Transformer
         """
-        dec_in = tgt[:, :-1, :]
+        dec_in = tgt[:, :-1, :1]
         enc_out, enc_final_hs, src_len = self.encoder(src, src_len)
         if not bptt:
             self.decoder.init_state(src, enc_out, enc_final_hs)
diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py
index 44f3d51c9d..25f18578b9 100644
--- a/onmt/modules/__init__.py
+++ b/onmt/modules/__init__.py
@@ -3,6 +3,7 @@
 from onmt.modules.gate import context_gate_factory, ContextGate
 from onmt.modules.global_attention import GlobalAttention
 from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention
+from onmt.modules.generator import Generator
 from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss
 from onmt.modules.multi_headed_attn import MultiHeadedAttention
 from onmt.modules.embeddings import Embeddings, PositionalEncoding
@@ -11,7 +12,7 @@
 
 __all__ = ["Elementwise", "context_gate_factory", "ContextGate",
            "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
-           "CopyGeneratorLoss",
+           "CopyGeneratorLoss", "Generator",
            "MultiHeadedAttention", "Embeddings", "PositionalEncoding",
            "WeightNormConv2d", "AverageAttention",
            "CopyGeneratorLMLossCompute"]
diff --git a/onmt/modules/criterions.py b/onmt/modules/criterions.py
new file mode 100644
index 0000000000..002006a813
--- /dev/null
+++ b/onmt/modules/criterions.py
@@ -0,0 +1,39 @@
+import onmt
+from onmt.constants import DefaultTokens
+from onmt.modules.sparse_losses import SparsemaxLoss
+import torch.nn as nn
+
+
+class Criterions:
+
+    def __init__(self, opt, vocabs):
+        tgt_vocab = vocabs['tgt']
+        padding_idx = tgt_vocab[DefaultTokens.PAD]
+        unk_idx = tgt_vocab[DefaultTokens.UNK]
+
+        if opt.copy_attn:
+            self.tgt_criterion = onmt.modules.CopyGeneratorLoss(
+                len(tgt_vocab), opt.copy_attn_force,
+                unk_index=unk_idx, ignore_index=padding_idx
+            )
+        else:
+            if opt.generator_function == 'sparsemax':
+                self.tgt_criterion = SparsemaxLoss(
+                    ignore_index=padding_idx,
+                    reduction='sum')
+            else:
+                self.tgt_criterion = nn.CrossEntropyLoss(
+                    ignore_index=padding_idx,
+                    reduction='sum',
+                    label_smoothing=opt.label_smoothing)
+
+        # Add as many criterions as we have tgt features
+        self.feats_criterions = []
+        if 'tgt_feats' in vocabs:
+            for feat_vocab in vocabs["tgt_feats"]:
+                padding_idx = feat_vocab[DefaultTokens.PAD]
+                self.feats_criterions.append(
+                    nn.CrossEntropyLoss(
+                        ignore_index=padding_idx,
+                        reduction='sum')
+                )
\ No newline at end of file
diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py
new file mode 100644
index 0000000000..aa79496232
--- /dev/null
+++ b/onmt/modules/generator.py
@@ -0,0 +1,40 @@
+"""Generator modules: next-token prediction plus target-feature heads."""
+import torch.nn as nn
+
+from onmt.modules.copy_generator import CopyGenerator
+
+
+class Generator(nn.Module):
+
+    def __init__(self, hid_sizes, gen_sizes,
+                 shared=False, copy_attn=False, pad_idx=None):
+        super(Generator, self).__init__()
+        self.feats_generators = nn.ModuleList()
+        self.shared = shared
+        self.hid_sizes = hid_sizes
+        self.gen_sizes = gen_sizes
+
+        def simple_generator(hid_size, gen_size):
+            return nn.Linear(hid_size, gen_size)
+
+        # First generator: next token prediction
+        if copy_attn:
+            self.tgt_generator = \
+                CopyGenerator(hid_sizes[0], gen_sizes[0], pad_idx)
+        else:
+            self.tgt_generator = \
+                simple_generator(hid_sizes[0], gen_sizes[0])
+
+        # Additional generators: target features
+        for hid_size, gen_size in zip(hid_sizes[1:], gen_sizes[1:]):
+            self.feats_generators.append(
+                simple_generator(hid_size, gen_size))
+
+    def forward(self, dec_out, *args):
+        scores = self.tgt_generator(dec_out, *args)
+
+        feats_scores = []
+        for generator in self.feats_generators:
+            feats_scores.append(generator(dec_out))
+
+        return scores, feats_scores
\ No newline at end of file
diff --git a/onmt/opts.py b/onmt/opts.py
index 5104d4d6c5..71c32710e7 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -135,6 +135,11 @@ def _add_features_opts(parser):
     group.add("-src_feats_defaults", "--src_feats_defaults",
               help="Default features to apply in source in case "
                    "there are not annotated")
+    group.add("-n_tgt_feats", "--n_tgt_feats", type=int,
+              default=0, help="Number of target feats.")
+    group.add("-tgt_feats_defaults", "--tgt_feats_defaults",
+              help="Default features to apply in target in case "
+                   "they are not annotated")
 
 
 def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
diff --git a/onmt/train_single.py b/onmt/train_single.py
index 624e3e5d3b..fa7c0d1b30 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -158,6 +158,9 @@ def main(opt, device_id):
     if "src_feats" in vocabs:
         for i, feat_vocab in enumerate(vocabs["src_feats"]):
             logger.info(f'* src_feat {i} vocab size = {len(feat_vocab)}')
+    if "tgt_feats" in vocabs:
+        for i, feat_vocab in enumerate(vocabs["tgt_feats"]):
+            logger.info(f'* tgt_feat {i} vocab size = {len(feat_vocab)}')
 
     # Build optimizer.
     optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 516fffadd8..71ad5297e7 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -35,8 +35,8 @@ def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
             used to save the model
     """
-    train_loss = LossCompute.from_opts(opt, model, vocabs['tgt'])
-    valid_loss = LossCompute.from_opts(opt, model, vocabs['tgt'], train=False)
+    train_loss = LossCompute.from_opts(opt, model, vocabs)
+    valid_loss = LossCompute.from_opts(opt, model, vocabs, train=False)
 
     scoring_preparator = ScoringPreparator(vocabs=vocabs, opt=opt)
     validset_transforms = opt.data.get("valid", {}).get("transforms", None)
diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py
index 638ac8a065..8a4cacaee3 100644
--- a/onmt/transforms/features.py
+++ b/onmt/transforms/features.py
@@ -29,19 +29,35 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
             # Do nothing
             return example
 
-        if self.reversible_tokenization == "joiner":
-            original_src = example["src_original"]
-            word_to_subword_mapping = subword_map_by_joiner(
-                example["src"], original_subwords=original_src)
-        else:  # Spacer
-            word_to_subword_mapping = subword_map_by_spacer(example["src"])
-
-        new_src_feats = [[] for _ in range(len(example["src_feats"]))]
-        for subword, word_id in zip(example["src"], word_to_subword_mapping):
-            for i, feat_values in enumerate(example["src_feats"]):
-                inferred_feat = feat_values[word_id]
-                new_src_feats[i].append(inferred_feat)
-        example["src_feats"] = new_src_feats
+        if "src_feats" in example:
+            if self.reversible_tokenization == "joiner":
+                original_src = example["src_original"]
+                word_to_subword_mapping = subword_map_by_joiner(
+                    example["src"], original_subwords=original_src)
+            else:  # Spacer
+                word_to_subword_mapping = subword_map_by_spacer(example["src"])
+
+            new_src_feats = [[] for _ in range(len(example["src_feats"]))]
+            for subword, word_id in zip(example["src"],
+                                        word_to_subword_mapping):
+                for i, feat_values in enumerate(example["src_feats"]):
+                    inferred_feat = feat_values[word_id]
+                    new_src_feats[i].append(inferred_feat)
+            example["src_feats"] = new_src_feats
+
+        if "tgt_feats" in example:
+            if self.reversible_tokenization == "joiner":
+                original_tgt = example["tgt_original"]
+                word_to_subword_mapping = subword_map_by_joiner(
+                    example["tgt"], original_subwords=original_tgt)
+            else:  # Spacer
+                word_to_subword_mapping = subword_map_by_spacer(example["tgt"])
+
+            new_tgt_feats = [[] for _ in range(len(example["tgt_feats"]))]
+            for subword, word_id in zip(example["tgt"],
+                                        word_to_subword_mapping):
+                for i, feat_values in enumerate(example["tgt_feats"]):
+                    inferred_feat = feat_values[word_id]
+                    new_tgt_feats[i].append(inferred_feat)
+            example["tgt_feats"] = new_tgt_feats
 
         return example
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index 4f253ff2f3..8431d41297 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -58,11 +58,11 @@ class BeamSearchBase(DecodeStrategy):
     def __init__(self, beam_size, batch_size, pad, bos, eos, unk, start,
                  n_best, global_scorer, min_length, max_length,
                  return_attention, block_ngram_repeat, exclusion_tokens,
-                 stepwise_penalty, ratio, ban_unk_token):
+                 stepwise_penalty, ratio, ban_unk_token, n_tgt_feats):
         super(BeamSearchBase, self).__init__(
             pad, bos, eos, unk, start, batch_size, beam_size, global_scorer,
             min_length, block_ngram_repeat, exclusion_tokens,
-            return_attention, max_length, ban_unk_token)
+            return_attention, max_length, ban_unk_token, n_tgt_feats)
         # beam parameters
         self.beam_size = beam_size
         self.n_best = n_best
@@ -117,7 +117,7 @@ def initialize_(self, enc_out, src_len, src_map, device,
 
     @property
     def current_predictions(self):
-        return self.alive_seq[:, -1]
+        return self.alive_seq[:, 0, -1]
 
     @property
     def current_backptr(self):
@@ -152,6 +152,17 @@ def _pick(self, log_probs, out=None):
         topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
         return topk_scores, topk_ids
 
+    def _pick_features(self, log_probs):
+        if len(log_probs) > 0:
+            features_id = []
+            for probs in log_probs:
+                _, topk_ids = probs.topk(1, dim=-1)
+                features_id.append(topk_ids)
+            features_id = torch.cat(features_id, dim=-1)
+            return features_id
+        else:
+            return None
+
     def update_finished(self):
         # Penalize beams that finished.
         _B_old = self.topk_log_probs.shape[0]
@@ -161,7 +172,7 @@ def update_finished(self):
         # it's faster to not move this back to the original device
         self.is_finished = self.is_finished.to('cpu')
         self.top_beam_finished |= self.is_finished[:, 0].eq(1)
-        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
+        predictions = self.alive_seq.view(_B_old, self.beam_size, -1, step)
         attention = (
             self.alive_attn.view(
                 step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
@@ -178,9 +189,12 @@ def update_finished(self):
                     self.best_scores[b] = s
                 self.hypotheses[b].append((
                     self.topk_scores[i, j],
-                    predictions[i, j, 1:],  # Ignore start_token.
+                    predictions[i, j, 0, 1:],  # Ignore start_token.
                     attention[:, i, j, :self.src_len[i]]
-                    if attention is not None else None))
+                    if attention is not None else None,
+                    [predictions[i, j, 1+k, 1:]
+                     for k in range(self.n_tgt_feats)]
+                    if predictions.size(-2) > 1 else None))
 
             # End condition is the top beam finished and we can return
             # n_best hypotheses.
             if self.ratio > 0:
@@ -194,11 +208,12 @@ def update_finished(self):
                 best_hyp = sorted(
                     self.hypotheses[b],
                     key=lambda x: x[0], reverse=True)[:self.n_best]
-                for n, (score, pred, attn) in enumerate(best_hyp):
+                for n, (score, pred, attn, feats) in enumerate(best_hyp):
                     self.scores[b].append(score)
                     self.predictions[b].append(pred)  # ``(batch, n_best,)``
                     self.attention[b].append(
                         attn if attn is not None else [])
+                    self.features[b].append(feats if feats is not None else [])
             else:
                 non_finished_batch.append(i)
 
@@ -224,7 +239,7 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished,
         self._batch_index = self._batch_index.index_select(0, non_finished)
         self.select_indices = self._batch_index.view(_B_new * self.beam_size)
         self.alive_seq = predictions.index_select(0, non_finished) \
-            .view(-1, self.alive_seq.size(-1))
+            .view(-1, self.alive_seq.size(-2), self.alive_seq.size(-1))
         self.topk_scores = self.topk_scores.index_select(0, non_finished)
         self.topk_ids = self.topk_ids.index_select(0, non_finished)
         self.maybe_update_target_prefix(self.select_indices)
@@ -241,7 +256,11 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished,
             self._prev_penalty = self._prev_penalty.index_select(
                 0, non_finished)
 
-    def advance(self, log_probs, attn):
+    def advance(self, log_probs, attn, feats_log_probs):
+        # Pick up candidates for target features;
+        # we take top 1 for feats.
+        features_ids = self._pick_features(feats_log_probs)
+
         vocab_size = log_probs.size(-1)
 
         # using integer division to get an integer _B without casting
@@ -285,10 +304,18 @@ def advance(self, log_probs, attn):
         self.select_indices = self._batch_index.view(_B * self.beam_size)
         self.topk_ids.fmod_(vocab_size)  # resolve true word ids
 
+        # Concatenate topk_ids for tokens and feats.
+        if features_ids is not None:
+            topk_ids = torch.cat((
+                self.topk_ids.view(_B * self.beam_size, 1),
+                features_ids), dim=1)
+        else:
+            topk_ids = self.topk_ids.view(_B * self.beam_size, 1)
+
         # Append last prediction.
         self.alive_seq = torch.cat(
             [self.alive_seq.index_select(0, self.select_indices),
-             self.topk_ids.view(_B * self.beam_size, 1)], -1)
+             topk_ids.unsqueeze(-1)], -1)
 
         self.maybe_update_forbidden_tokens()
diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py
index 7eeb4bb365..586d4dcbb6 100644
--- a/onmt/translate/decode_strategy.py
+++ b/onmt/translate/decode_strategy.py
@@ -68,7 +68,7 @@ class DecodeStrategy(object):
     def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
                  global_scorer, min_length, block_ngram_repeat,
                  exclusion_tokens, return_attention, max_length,
-                 ban_unk_token):
+                 ban_unk_token, n_tgt_feats):
 
         # magic indices
         self.pad = pad
@@ -86,6 +86,7 @@ def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
         self.scores = [[] for _ in range(batch_size)]
         self.attention = [[] for _ in range(batch_size)]
         self.hypotheses = [[] for _ in range(batch_size)]
+        self.features = [[] for _ in range(batch_size)]
 
         self.alive_attn = None
 
@@ -102,6 +103,8 @@ def __init__(self, pad, bos, eos, unk, start, batch_size, parallel_paths,
 
         self.done = False
 
+        self.n_tgt_feats = n_tgt_feats
+
     def get_device_from_enc_out(self, enc_out):
         if isinstance(enc_out, tuple):
             mb_device = enc_out[0].device
@@ -140,8 +143,8 @@ def initialize(self, enc_out, src_len, src_map=None, device=None,
             device = torch.device('cpu')
         # Here we set the decoder to start with self.start (BOS or EOS)
         self.alive_seq = torch.full(
-            [self.batch_size * self.parallel_paths, 1], self.start,
-            dtype=torch.long, device=device)
+            [self.batch_size * self.parallel_paths, self.n_tgt_feats+1, 1],
+            self.start, dtype=torch.long, device=device)
         self.is_finished = torch.zeros(
             [self.batch_size, self.parallel_paths],
             dtype=torch.uint8, device=device)
@@ -161,13 +164,15 @@ def initialize(self, enc_out, src_len, src_map=None, device=None,
         return None, enc_out, src_len, src_map
 
     def __len__(self):
-        return self.alive_seq.shape[1]
+        return self.alive_seq.shape[-1]
 
     def ensure_min_length(self, log_probs):
+        # TODO: check whether this is needed for target feature probs
         if len(self) <= self.min_length:
             log_probs[:, self.eos] = -1e20
 
     def ensure_unk_removed(self, log_probs):
+        # TODO: check whether this is needed for target feature probs
         if self.ban_unk_token:
             log_probs[:, self.unk] = -1e20
@@ -296,7 +301,7 @@ def maybe_update_target_prefix(self, select_index):
             return
         self.target_prefix = self.target_prefix.index_select(0, select_index)
 
-    def advance(self, log_probs, attn):
+    def advance(self, log_probs, attn, feats_log_probs):
         """DecodeStrategy subclasses should override :func:`advance()`.
 
         Advance is used to update ``self.alive_seq``, ``self.is_finished``,
diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py
index ba51241f36..07d2e8d6e0 100644
--- a/onmt/translate/translation.py
+++ b/onmt/translate/translation.py
@@ -43,19 +43,28 @@ def _build_source_tokens(self, src):
                 break
         return tokens
 
-    def _build_target_tokens(self, src, src_raw, pred, attn):
+    def _build_target_tokens(self, src, src_raw, pred, attn, feats):
         tokens = []
-        for tok in pred:
+        if feats is not None:
+            pred_iter = zip(pred, *feats)
+        else:
+            pred_iter = [(item,) for item in pred]
+
+        for tok, *tok_feats in pred_iter:
             if tok < len(self.vocabs['tgt']):
-                tokens.append(self.vocabs['tgt'].lookup_index(tok))
+                token = self.vocabs['tgt'].lookup_index(tok)
             else:
                 vl = len(self.vocabs['tgt'])
-                tokens.append(self.vocabs['src'].lookup_index(tok - vl))
-            if tokens[-1] == DefaultTokens.EOS:
-                tokens = tokens[:-1]
+                token = self.vocabs['src'].lookup_index(tok - vl)
+            if token == DefaultTokens.EOS:
                 break
+            if len(tok_feats) > 0:
+                for feat, fv in zip(tok_feats, self.vocabs['tgt_feats']):
+                    token += "│" + fv.lookup_index(feat)
+            tokens.append(token)
 
         if self.replace_unk and attn is not None and src is not None:
+            assert False, "TODO"
             for i in range(len(tokens)):
                 if tokens[i] == DefaultTokens.UNK:
                     _, max_index = attn[i][:len(src_raw)].max(0)
@@ -72,8 +81,9 @@ def from_batch(self, translation_batch):
                len(translation_batch["predictions"]))
         batch_size = len(batch['srclen'])
 
-        preds, pred_score, attn, align, gold_score, indices = list(zip(
+        preds, feats, pred_score, attn, align, gold_score, indices = list(zip(
             *sorted(zip(translation_batch["predictions"],
+                        translation_batch["features"],
                         translation_batch["scores"],
                         translation_batch["attention"],
                         translation_batch["alignment"],
@@ -104,7 +114,8 @@ def from_batch(self, translation_batch):
                 src[b, :] if src is not None else None,
                 src_raw,
                 preds[b][n],
-                align[b][n] if align[b] is not None else attn[b][n])
+                align[b][n] if align[b] is not None else attn[b][n],
+                feats[b][n] if len(feats[0]) > 0 else None)
             for n in range(self.n_best)]
         gold_sent = None
         if tgt is not None:
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 895099a2bb..6485b52c7d 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -129,7 +129,8 @@ def __init__(
         logger=None,
         seed=-1,
         with_score=False,
-        decoder_start_token=DefaultTokens.BOS
+        decoder_start_token=DefaultTokens.BOS,
+        n_tgt_feats=0
     ):
         self.model = model
         self.vocabs = vocabs
@@ -208,6 +209,7 @@ def __init__(
             set_random_seed(seed, self._use_cuda)
 
         self.with_score = with_score
+        self.n_tgt_feats = n_tgt_feats
 
     @classmethod
     def from_opt(
@@ -273,7 +275,8 @@ def from_opt(
             logger=logger,
             seed=opt.seed,
             with_score=opt.with_score,
-            decoder_start_token=opt.decoder_start_token
+            decoder_start_token=opt.decoder_start_token,
+            n_tgt_feats=opt.n_tgt_feats
         )
 
     def _log(self, msg):
@@ -551,17 +554,21 @@ def _decode_and_generate(
             else:
                 attn = None
 
-            scores = self.model.generator(dec_out.squeeze(1))
+            scores, feats_scores = self.model.generator(dec_out.squeeze(1))
             log_probs = F.log_softmax(scores.to(torch.float32), dim=-1)
+            feats_log_probs = [F.log_softmax(s.to(torch.float32), dim=-1)
+                               for s in feats_scores]
             # returns [(batch_size x beam_size) , vocab ] when 1 step
             # or [batch_size, tgt_len, vocab ] when full sentence
         else:
             attn = dec_attn["copy"]
-            scores = self.model.generator(
+            scores, feats_scores = self.model.generator(
                 dec_out.view(-1, dec_out.size(2)),
                 attn.view(-1, attn.size(2)),
                 src_map,
             )
+            # TODO: allow target feats inference with the copy mechanism
+            assert not feats_scores
             # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
             if batch_offset is None:
                 scores = scores.view(-1, len(batch['srclen']),
@@ -581,7 +588,7 @@ def _decode_and_generate(
             log_probs = scores.squeeze(0).log()
             # returns [(batch_size x beam_size) , vocab ] when 1 step
             # or [batch_size, tgt_len, vocab ] when full sentence
-        return log_probs, attn
+        return log_probs, attn, feats_log_probs
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
@@ -603,16 +610,14 @@ def report_results(
         decode_strategy,
     ):
         results = {
-            "predictions": None,
-            "scores": None,
-            "attention": None,
+            "predictions": decode_strategy.predictions,
+            "scores": decode_strategy.scores,
+            "attention": decode_strategy.attention,
+            "features": decode_strategy.features,
             "batch": batch,
             "gold_score": gold_score,
         }
 
-        results["scores"] = decode_strategy.scores
-        results["predictions"] = decode_strategy.predictions
-        results["attention"] = decode_strategy.attention
         if self.report_align:
             results["alignment"] = self._align_forward(
                 batch, decode_strategy.predictions
@@ -726,6 +731,7 @@ def translate_batch(self, batch, attn_debug):
                 stepwise_penalty=self.stepwise_penalty,
                 ratio=self.ratio,
                 ban_unk_token=self.ban_unk_token,
+                n_tgt_feats=self.n_tgt_feats,
             )
         return self._translate_batch_with_strategy(
             batch, decode_strategy
@@ -803,11 +809,9 @@ def _translate_batch_with_strategy(
 
         # (3) Begin decoding step by step:
         for step in range(decode_strategy.max_length):
-            # decoder_input = decode_strategy.current_predictions.view(1, -1,
-            #                                                          1)
             decoder_input = decode_strategy.current_predictions.view(-1, 1, 1)
 
-            log_probs, attn = self._decode_and_generate(
+            log_probs, attn, feats_log_probs = self._decode_and_generate(
                 decoder_input,
                 enc_out,
                 batch,
@@ -817,7 +821,7 @@ def _translate_batch_with_strategy(
                 batch_offset=decode_strategy.batch_offset,
             )
 
-            decode_strategy.advance(log_probs, attn)
+            decode_strategy.advance(log_probs, attn, feats_log_probs)
             any_finished = decode_strategy.is_finished.any()
             if any_finished:
                 decode_strategy.update_finished()
@@ -861,7 +865,7 @@ def _score_target(
         tgt = batch['tgt']
         tgt_in = tgt[:, :-1, :]
 
-        log_probs, attn = self._decode_and_generate(
+        log_probs, attn, feats_log_probs = self._decode_and_generate(
             tgt_in,
             enc_out,
             batch,
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index a274d53f82..c670715f62 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -11,6 +11,7 @@
 from onmt.constants import ModelTask, DefaultTokens
 from onmt.modules.copy_generator import collapse_copy_scores
 from onmt.model_builder import load_test_model
+from onmt.modules.criterions import Criterions
 try:
     import ctranslate2
 except ImportError:
@@ -36,13 +37,13 @@ class LossCompute(nn.Module):
         lm_prior_lambda (float): weight of LM model in loss
         lm_prior_tau (float): scaler for LM loss
     """
-    def __init__(self, criterion, generator,
+    def __init__(self, criterions, generator,
                  copy_attn=False, lambda_coverage=0.0, lambda_align=0.0,
                  tgt_shift_index=1, vocab=None, lm_generator=None,
                  lm_prior_lambda=None, lm_prior_tau=None, lm_prior_model=None):
         super(LossCompute, self).__init__()
-        self.criterion = criterion
+        self.criterions = criterions
         self.generator = generator
         self.lambda_coverage = lambda_coverage
         self.lambda_align = lambda_align
@@ -55,7 +56,7 @@ def __init__(self, criterions, generator,
         self.lm_prior_model = lm_prior_model
 
     @classmethod
-    def from_opts(cls, opt, model, vocab, train=True):
+    def from_opts(cls, opt, model, vocabs, train=True):
         """
         Returns a subclass which wraps around an nn.Module subclass
         (such as nn.NLLLoss) which defines the loss criterion. The LossCompute
@@ -66,8 +67,7 @@ def from_opts(cls, opt, model, vocabs, train=True):
         device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt)
                               else "cpu")
 
-        padding_idx = vocab[DefaultTokens.PAD]
-        unk_idx = vocab[DefaultTokens.UNK]
+        tgt_vocab = vocabs['tgt']
 
         if opt.lambda_coverage != 0:
             assert opt.coverage_attn, "--coverage_attn needs to be set in " \
@@ -75,21 +75,7 @@ def from_opts(cls, opt, model, vocabs, train=True):
 
         tgt_shift_idx = 1 if opt.model_task == ModelTask.SEQ2SEQ else 0
 
-        if opt.copy_attn:
-            criterion = onmt.modules.CopyGeneratorLoss(
-                len(vocab), opt.copy_attn_force,
-                unk_index=unk_idx, ignore_index=padding_idx
-            )
-        else:
-            if opt.generator_function == 'sparsemax':
-                criterion = SparsemaxLoss(ignore_index=padding_idx,
-                                          reduction='sum')
-            else:
-                criterion = nn.CrossEntropyLoss(
-                    ignore_index=padding_idx,
-                    reduction='sum',
-                    label_smoothing=opt.label_smoothing
-                )
+        criterions = Criterions(opt, vocabs)
 
         lm_prior_lambda = opt.lm_prior_lambda
         lm_prior_tau = opt.lm_prior_tau
@@ -116,12 +102,12 @@ def from_opts(cls, opt, model, vocabs, train=True):
             lm_generator = None
             lm_prior_model = None
 
-        compute = cls(criterion, model.generator,
+        compute = cls(criterions, model.generator,
                       copy_attn=opt.copy_attn,
                       lambda_coverage=opt.lambda_coverage,
                       lambda_align=opt.lambda_align,
                       tgt_shift_index=tgt_shift_idx,
-                      vocab=vocab, lm_generator=lm_generator,
+                      vocab=tgt_vocab, lm_generator=lm_generator,
                       lm_prior_lambda=lm_prior_lambda,
                       lm_prior_tau=lm_prior_tau,
                       lm_prior_model=lm_prior_model)
@@ -131,7 +117,7 @@ def from_opts(cls, opt, model, vocabs, train=True):
 
     @property
     def padding_idx(self):
-        return self.criterion.ignore_index
+        return self.criterions.tgt_criterion.ignore_index
 
     def _compute_coverage_loss(self, std_attn, cov_attn, tgt):
         """compute coverage loss"""
@@ -164,12 +150,12 @@ def _compute_copy_loss(self, batch, output, target, align, attns):
         Returns:
             A tuple with the loss and raw scores.
""" - scores = self.generator(self._bottle(output), - self._bottle(attns['copy']), - batch['src_map']) - loss = self.criterion(scores, align, target).sum() + scores, feats_scores = self.generator(self._bottle(output), + self._bottle(attns['copy']), + batch['src_map']) + loss = self.criterions.tgt_criterion(scores, align, target).sum() - return loss, scores + return loss, scores, feats_scores def _compute_lm_loss_ct2(self, output, target): """ @@ -277,8 +263,8 @@ def forward(self, batch, output, attns, align = batch['alignment'][ :, trunc_range[0]:trunc_range[1] ].contiguous().view(-1) - loss, scores = self._compute_copy_loss(batch, output, flat_tgt, - align, attns) + loss, scores, feats_scores = \ + self._compute_copy_loss(batch, output, flat_tgt, align, attns) scores_data = collapse_copy_scores( self._unbottle(scores.clone(), len(batch['srclen'])), batch, self.vocab, None) @@ -287,7 +273,7 @@ def forward(self, batch, output, attns, # tgt[i] = align[i] + len(tgt_vocab) # for i such that tgt[i] == 0 and align[i] != 0 target_data = flat_tgt.clone() - unk = self.criterion.unk_index + unk = self.criterions.tgt_criterion.unk_index correct_mask = (target_data == unk) & (align != unk) offset_align = align[correct_mask] + len(self.vocab) target_data[correct_mask] += offset_align @@ -296,10 +282,12 @@ def forward(self, batch, output, attns, else: - scores = self.generator(self._bottle(output)) - if isinstance(self.criterion, SparsemaxLoss): + scores, feats_scores = self.generator(self._bottle(output)) + if isinstance(self.criterions.tgt_criterion, SparsemaxLoss): scores = LogSparsemax(scores.to(torch.float32), dim=-1) - loss = self.criterion(scores.to(torch.float32), flat_tgt) + + loss = self.criterions.tgt_criterion( + scores.to(torch.float32), flat_tgt) if self.lambda_align != 0.0: align_head = attns['align'] @@ -319,6 +307,15 @@ def forward(self, batch, output, attns, align_head=align_head, ref_align=ref_align) loss += align_loss + # Compute target features losses + assert len(feats_scores) == \ + len(self.criterions.feats_criterions) # Security check + for i, (feat_scores, criterion) in enumerate( + zip(feats_scores, self.criterions.feats_criterions)): + loss += criterion( + feat_scores.to(torch.float32), + target[:, :, i+1].contiguous().view(-1)) + if self.lambda_coverage != 0.0: coverage_loss = self._compute_coverage_loss( attns['std'], attns['coverage'], flat_tgt) @@ -332,6 +329,7 @@ def forward(self, batch, output, attns, lm_loss = self._compute_lm_loss(output, batch['tgt']) loss = loss + lm_loss * self.lm_prior_lambda + # TODO: pass feat scores to stats stats = self._stats(len(batch['srclen']), loss.sum().item(), scores, flat_tgt) diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 29483043d3..795e76d48a 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -77,11 +77,11 @@ def _validate_data(cls, opt): corpus['weight'] = 1 # Check features - if opt.n_src_feats > 0: + if opt.n_src_feats > 0 or opt.n_tgt_feats > 0: if 'inferfeats' not in corpus["transforms"]: raise ValueError( "'inferfeats' transform is required " - "when setting source features") + "when setting source or target features") logger.info(f"Parsed {len(corpora)} corpora from -data.") opt.data = corpora