diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index ff6ef63e87..be99e762ba 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -3,6 +3,7 @@
 and creates each encoder and decoder accordingly.
 """
 import re
+from functools import partial
 import torch
 import torch.nn as nn
 from torch.nn.init import xavier_uniform_
@@ -14,6 +15,7 @@
 from onmt.decoders import str2dec
 
 from onmt.modules import Embeddings, VecEmbedding, CopyGenerator
+from onmt.modules.sparse_activations import LogSparsemax, LogEntmax15
 from onmt.modules.util_class import Cast
 from onmt.utils.misc import use_gpu
 from onmt.utils.logging import logger
@@ -173,10 +175,11 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
 
     # Build Generator.
     if not model_opt.copy_attn:
-        if model_opt.generator_function == "sparsemax":
-            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
-        else:
-            gen_func = nn.LogSoftmax(dim=-1)
+        gen_funcs = {"softmax": nn.LogSoftmax,
+                     "sparsemax": partial(LogSparsemax, k=512),
+                     "entmax15": partial(LogEntmax15, k=512)}
+        assert model_opt.generator_function in gen_funcs
+        gen_func = gen_funcs[model_opt.generator_function](dim=-1)
         generator = nn.Sequential(
             nn.Linear(model_opt.dec_rnn_size,
                       len(fields["tgt"].base_field.vocab)),
diff --git a/onmt/modules/global_attention.py b/onmt/modules/global_attention.py
index 74a085e077..98995a318d 100644
--- a/onmt/modules/global_attention.py
+++ b/onmt/modules/global_attention.py
@@ -1,9 +1,8 @@
 """Global attention modules (Luong / Bahdanau)"""
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
-from onmt.modules.sparse_activations import sparsemax
+from entmax import Sparsemax, Entmax15
 from onmt.utils.misc import aeq, sequence_mask
 
 # This class is mainly used by decoder.py for RNNs but also
@@ -77,9 +76,12 @@ def __init__(self, dim, coverage=False, attn_type="dot",
             "Please select a valid attention type (got {:s}).".format(
                 attn_type))
         self.attn_type = attn_type
-        assert attn_func in ["softmax", "sparsemax"], (
-            "Please select a valid attention function.")
-        self.attn_func = attn_func
+        attn_funcs = {"softmax": nn.Softmax,
+                      "sparsemax": Sparsemax,
+                      "entmax15": Entmax15}
+        assert attn_func in attn_funcs, \
+            "Unknown attention function {}".format(attn_func)
+        self.attn_func = attn_funcs[attn_func](dim=-1)
 
         if self.attn_type == "general":
             self.linear_in = nn.Linear(dim, dim, bias=False)
@@ -135,11 +137,11 @@ def score(self, h_t, h_s):
 
             return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
 
-    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
+    def forward(self, query, memory_bank, memory_lengths=None, coverage=None):
         """
 
         Args:
-          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
+          query (FloatTensor): query vectors ``(batch, tgt_len, dim)``
           memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
           memory_lengths (LongTensor): the source context lengths ``(batch,)``
           coverage (FloatTensor): None (not supported yet)
@@ -152,22 +154,20 @@ def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
             ``(tgt_len, batch, src_len)``
         """
 
-        # one step input
-        if source.dim() == 2:
-            one_step = True
-            source = source.unsqueeze(1)
-        else:
-            one_step = False
+        one_step = query.dim() == 2
+        if one_step:
+            # compute attention for only one tgt step on this call
+            query = query.unsqueeze(1)
 
-        batch, source_l, dim = memory_bank.size()
-        batch_, target_l, dim_ = source.size()
-        aeq(batch, batch_)
-        aeq(dim, dim_)
-        aeq(self.dim, dim)
+        batch, source_l, memory_dim = memory_bank.size()
+        query_batch, target_l, query_dim = query.size()
+        aeq(batch, query_batch)
+        aeq(self.dim, memory_dim, query_dim)
+
         if coverage is not None:
-            batch_, source_l_ = coverage.size()
-            aeq(batch, batch_)
-            aeq(source_l, source_l_)
+            coverage_batch, coverage_l = coverage.size()
+            aeq(batch, coverage_batch)
+            aeq(source_l, coverage_l)
 
         if coverage is not None:
             cover = coverage.view(-1).unsqueeze(1)
@@ -175,18 +175,16 @@ def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
             memory_bank = torch.tanh(memory_bank)
 
         # compute attention scores, as in Luong et al.
-        align = self.score(source, memory_bank)
+        align = self.score(query, memory_bank)
 
         if memory_lengths is not None:
             mask = sequence_mask(memory_lengths, max_len=align.size(-1))
             mask = mask.unsqueeze(1)  # Make it broadcastable.
             align.masked_fill_(~mask, -float('inf'))
 
-        # Softmax or sparsemax to normalize attention weights
-        if self.attn_func == "softmax":
-            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
-        else:
-            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
+        n_vectors = batch * target_l
+        # normalize attention weights
+        align_vectors = self.attn_func(align.view(n_vectors, source_l))
         align_vectors = align_vectors.view(batch, target_l, source_l)
 
         # each context vector c_t is the weighted average
@@ -194,34 +192,27 @@ def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
         c = torch.bmm(align_vectors, memory_bank)
 
         # concatenate
-        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
-        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
+        concat_c = torch.cat([c, query], 2).view(n_vectors, self.dim * 2)
+        attn_h = self.linear_out(concat_c).view(batch, target_l, self.dim)
         if self.attn_type in ["general", "dot"]:
             attn_h = torch.tanh(attn_h)
 
         if one_step:
             attn_h = attn_h.squeeze(1)
             align_vectors = align_vectors.squeeze(1)
 
-            # Check output sizes
-            batch_, dim_ = attn_h.size()
-            aeq(batch, batch_)
-            aeq(dim, dim_)
-            batch_, source_l_ = align_vectors.size()
-            aeq(batch, batch_)
-            aeq(source_l, source_l_)
-
+            align_batch, align_src_l = align_vectors.size()
+            attn_batch, attn_dim = attn_h.size()
         else:
             attn_h = attn_h.transpose(0, 1).contiguous()
             align_vectors = align_vectors.transpose(0, 1).contiguous()
             # Check output sizes
-            target_l_, batch_, dim_ = attn_h.size()
-            aeq(target_l, target_l_)
-            aeq(batch, batch_)
-            aeq(dim, dim_)
-            target_l_, batch_, source_l_ = align_vectors.size()
-            aeq(target_l, target_l_)
-            aeq(batch, batch_)
-            aeq(source_l, source_l_)
+            attn_l, attn_batch, attn_dim = attn_h.size()
+            align_tgt_l, align_batch, align_src_l = align_vectors.size()
+            aeq(target_l, attn_l, align_tgt_l)
+
+        aeq(batch, attn_batch, align_batch)
+        aeq(source_l, align_src_l)
+        aeq(self.dim, attn_dim)
 
         return attn_h, align_vectors
diff --git a/onmt/modules/sparse_activations.py b/onmt/modules/sparse_activations.py
index b301d3adab..7a8addf80d 100644
--- a/onmt/modules/sparse_activations.py
+++ b/onmt/modules/sparse_activations.py
@@ -1,97 +1,12 @@
-"""
-An implementation of sparsemax (Martins & Astudillo, 2016). See
-:cite:`DBLP:journals/corr/MartinsA16` for detailed description.
-
-By Ben Peters and Vlad Niculae
-"""
-
 import torch
-from torch.autograd import Function
-import torch.nn as nn
-
-
-def _make_ix_like(input, dim=0):
-    d = input.size(dim)
-    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
-    view = [1] * input.dim()
-    view[0] = -1
-    return rho.view(view).transpose(0, dim)
-
-
-def _threshold_and_support(input, dim=0):
-    """Sparsemax building block: compute the threshold
-
-    Args:
-        input: any dimension
-        dim: dimension along which to apply the sparsemax
-
-    Returns:
-        the threshold value
-    """
-
-    input_srt, _ = torch.sort(input, descending=True, dim=dim)
-    input_cumsum = input_srt.cumsum(dim) - 1
-    rhos = _make_ix_like(input, dim)
-    support = rhos * input_srt > input_cumsum
-
-    support_size = support.sum(dim=dim).unsqueeze(dim)
-    tau = input_cumsum.gather(dim, support_size - 1)
-    tau /= support_size.to(input.dtype)
-    return tau, support_size
-
-
-class SparsemaxFunction(Function):
-
-    @staticmethod
-    def forward(ctx, input, dim=0):
-        """sparsemax: normalizing sparse transform (a la softmax)
-
-        Parameters:
-            input (Tensor): any shape
-            dim: dimension along which to apply sparsemax
-
-        Returns:
-            output (Tensor): same shape as input
-        """
-        ctx.dim = dim
-        max_val, _ = input.max(dim=dim, keepdim=True)
-        input -= max_val  # same numerical stability trick as for softmax
-        tau, supp_size = _threshold_and_support(input, dim=dim)
-        output = torch.clamp(input - tau, min=0)
-        ctx.save_for_backward(supp_size, output)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        supp_size, output = ctx.saved_tensors
-        dim = ctx.dim
-        grad_input = grad_output.clone()
-        grad_input[output == 0] = 0
-
-        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
-        v_hat = v_hat.unsqueeze(dim)
-        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
-        return grad_input, None
-
-
-sparsemax = SparsemaxFunction.apply
-
-
-class Sparsemax(nn.Module):
-
-    def __init__(self, dim=0):
-        self.dim = dim
-        super(Sparsemax, self).__init__()
-
-    def forward(self, input):
-        return sparsemax(input, self.dim)
+from entmax import Entmax15, Sparsemax
 
 
-class LogSparsemax(nn.Module):
+class LogSparsemax(Sparsemax):
+    def forward(self, *args, **kwargs):
+        return torch.log(super().forward(*args, **kwargs))
 
-    def __init__(self, dim=0):
-        self.dim = dim
-        super(LogSparsemax, self).__init__()
 
-    def forward(self, input):
-        return torch.log(sparsemax(input, self.dim))
+class LogEntmax15(Entmax15):
+    def forward(self, *args, **kwargs):
+        return torch.log(super().forward(*args, **kwargs))
diff --git a/onmt/modules/sparse_losses.py b/onmt/modules/sparse_losses.py
deleted file mode 100644
index b98dc5c11a..0000000000
--- a/onmt/modules/sparse_losses.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import torch
-import torch.nn as nn
-from torch.autograd import Function
-from onmt.modules.sparse_activations import _threshold_and_support
-from onmt.utils.misc import aeq
-
-
-class SparsemaxLossFunction(Function):
-
-    @staticmethod
-    def forward(ctx, input, target):
-        """
-        input (FloatTensor): ``(n, num_classes)``.
-        target (LongTensor): ``(n,)``, the indices of the target classes
-        """
-        input_batch, classes = input.size()
-        target_batch = target.size(0)
-        aeq(input_batch, target_batch)
-
-        z_k = input.gather(1, target.unsqueeze(1)).squeeze()
-        tau_z, support_size = _threshold_and_support(input, dim=1)
-        support = input > tau_z
-        x = torch.where(
-            support, input**2 - tau_z**2,
-            torch.tensor(0.0, device=input.device)
-        ).sum(dim=1)
-        ctx.save_for_backward(input, target, tau_z)
-        # clamping necessary because of numerical errors: loss should be lower
-        # bounded by zero, but negative values near zero are possible without
-        # the clamp
-        return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, target, tau_z = ctx.saved_tensors
-        sparsemax_out = torch.clamp(input - tau_z, min=0)
-        delta = torch.zeros_like(sparsemax_out)
-        delta.scatter_(1, target.unsqueeze(1), 1)
-        return sparsemax_out - delta, None
-
-
-sparsemax_loss = SparsemaxLossFunction.apply
-
-
-class SparsemaxLoss(nn.Module):
-    """
-    An implementation of sparsemax loss, first proposed in
-    :cite:`DBLP:journals/corr/MartinsA16`. If using
-    a sparse output layer, it is not possible to use negative log likelihood
-    because the loss is infinite in the case the target is assigned zero
-    probability. Inputs to SparsemaxLoss are arbitrary dense real-valued
-    vectors (like in nn.CrossEntropyLoss), not probability vectors (like in
-    nn.NLLLoss).
-    """
-
-    def __init__(self, weight=None, ignore_index=-100,
-                 reduction='elementwise_mean'):
-        assert reduction in ['elementwise_mean', 'sum', 'none']
-        self.reduction = reduction
-        self.weight = weight
-        self.ignore_index = ignore_index
-        super(SparsemaxLoss, self).__init__()
-
-    def forward(self, input, target):
-        loss = sparsemax_loss(input, target)
-        if self.ignore_index >= 0:
-            ignored_positions = target == self.ignore_index
-            size = float((target.size(0) - ignored_positions.sum()).item())
-            loss.masked_fill_(ignored_positions, 0.0)
-        else:
-            size = float(target.size(0))
-        if self.reduction == 'sum':
-            loss = loss.sum()
-        elif self.reduction == 'elementwise_mean':
-            loss = loss.sum() / size
-        return loss
diff --git a/onmt/opts.py b/onmt/opts.py
index cea4a37bc6..77d525900f 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -136,7 +136,8 @@ def model_opts(parser):
               help="The attention type to use: "
                    "dotprod or general (Luong) or MLP (Bahdanau)")
     group.add('--global_attention_function', '-global_attention_function',
-              type=str, default="softmax", choices=["softmax", "sparsemax"])
+              type=str, default="softmax",
+              choices=["softmax", "sparsemax", "entmax15"])
     group.add('--self_attn_type', '-self_attn_type',
               type=str, default="scaled-dot",
               help='Self attention type in Transformer decoder '
@@ -163,7 +164,7 @@ def model_opts(parser):
               help="The copy attention type to use. Leave as None to use "
                    "the same as -global_attention.")
     group.add('--generator_function', '-generator_function', default="softmax",
-              choices=["softmax", "sparsemax"],
+              choices=["softmax", "sparsemax", "entmax15"],
               help="Which function to use for generating "
                    "probabilities over the target vocabulary (choices: "
                    "softmax, sparsemax)")
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index 5148b88847..a954a29b04 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -3,13 +3,13 @@
 sharded loss compute stuff.
 """
 from __future__ import division
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 import onmt
-from onmt.modules.sparse_losses import SparsemaxLoss
-from onmt.modules.sparse_activations import LogSparsemax
+from entmax import SparsemaxLoss, Entmax15Loss
 
 
 def build_loss_compute(model, tgt_field, opt, train=True):
@@ -39,16 +39,19 @@ def build_loss_compute(model, tgt_field, opt, train=True):
         criterion = LabelSmoothingLoss(
             opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx
         )
-    elif isinstance(model.generator[-1], LogSparsemax):
-        criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum')
     else:
-        criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')
+        # future idea: use nn.CrossEntropyLoss for softmax, to parallel entmax
+        loss_funcs = {"softmax": nn.NLLLoss,
+                      "sparsemax": partial(SparsemaxLoss, k=512),
+                      "entmax15": partial(Entmax15Loss, k=512)}
+        assert opt.generator_function in loss_funcs
+        loss_func = loss_funcs[opt.generator_function]
+        criterion = loss_func(ignore_index=padding_idx, reduction='sum')
 
     # if the loss function operates on vectors of raw logits instead of
-    # probabilities, only the first part of the generator needs to be
-    # passed to the NMTLossCompute. At the moment, the only supported
-    # loss function of this kind is the sparsemax loss.
-    use_raw_logits = isinstance(criterion, SparsemaxLoss)
+    # probabilities, the second part of the generator is not passed to
+    # NMTLossCompute. This is true for sparsemax and entmax losses.
+    use_raw_logits = opt.generator_function != "softmax"
     loss_gen = model.generator[0] if use_raw_logits else model.generator
     if opt.copy_attn:
         compute = onmt.modules.CopyGeneratorLossCompute(
diff --git a/requirements.txt b/requirements.txt
index ab66f3f280..9f6b07e6b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ torch>=1.1
 git+https://github.com/pytorch/text.git@master#wheel=torchtext
 future
 configargparse
+entmax>=1.0