diff --git a/onmt/models/model.py b/onmt/models/model.py
index 920adcc981..a3ce348413 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -1,5 +1,6 @@
 """ Onmt NMT Model base class definition """
 import torch.nn as nn
+import torch


 class NMTModel(nn.Module):
@@ -17,7 +18,8 @@ def __init__(self, encoder, decoder):
         self.encoder = encoder
         self.decoder = decoder

-    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
+    def forward(self, src, tgt, lengths, bptt=False,
+                with_align=False, encode_tgt=False):
         """Forward propagate a `src` and `tgt` pair for training.
         Possible initialized with a beginning decoder state.

@@ -46,9 +48,20 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False):

         if bptt is False:
             self.decoder.init_state(src, memory_bank, enc_state)
+
         dec_out, attns = self.decoder(dec_in, memory_bank,
                                       memory_lengths=lengths,
                                       with_align=with_align)
+
+        if encode_tgt:
+            # tgt for zero shot alignment loss
+            tgt_lengths = torch.Tensor(tgt.size(1))\
+                .type_as(memory_bank) \
+                .long() \
+                .fill_(tgt.size(0))
+            embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
+            return dec_out, attns, memory_bank, memory_bank_tgt
+
         return dec_out, attns

     def update_dropout(self, dropout):
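
The new `encode_tgt` branch re-runs the shared encoder on the gold target and returns both memory banks, so the loss can compare sentence-level representations of `src` and `tgt`. A minimal standalone sketch of that path follows; `run_encoders` and the `(state, memory_bank, lengths)` return convention are assumptions for illustration, not code from this patch.

import torch

def run_encoders(encoder, src, src_lengths, tgt):
    # encode the source as usual: memory_bank is [src_len, batch, hidden]
    _, memory_bank, _ = encoder(src, src_lengths)
    # the gold target arrives as a padded [tgt_len, batch, ...] tensor, so
    # every example is treated as having the full padded length tgt.size(0)
    tgt_lengths = torch.full((tgt.size(1),), tgt.size(0),
                             dtype=torch.long, device=tgt.device)
    _, memory_bank_tgt, _ = encoder(tgt, tgt_lengths)
    return memory_bank, memory_bank_tgt
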
diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py
index 900096cf4d..b5959ce92e 100644
--- a/onmt/modules/copy_generator.py
+++ b/onmt/modules/copy_generator.py
@@ -186,14 +186,15 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
         self.tgt_vocab = tgt_vocab
         self.normalize_by_length = normalize_by_length

-    def _make_shard_state(self, batch, output, range_, attns):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns):
         """See base class for args description."""
         if getattr(batch, "alignment", None) is None:
             raise AssertionError("using -copy_attn you need to pass in "
                                  "-dynamic_dict during preprocess stage.")

         shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state(
-            batch, output, range_, attns)
+            batch, output, enc_src, enc_tgt, range_, attns)

         shard_state.update({
             "copy_attn": attns.get("copy"),
@@ -201,7 +202,8 @@ def _make_shard_state(self, batch, output, range_, attns):
         })
         return shard_state

-    def _compute_loss(self, batch, output, target, copy_attn, align,
+    def _compute_loss(self, batch, normalization, output, target,
+                      copy_attn, align, enc_src=None, enc_tgt=None,
                       std_attn=None, coverage_attn=None):
         """Compute the loss.
@@ -244,8 +246,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
             offset_align = align[correct_mask] + len(self.tgt_vocab)
             target_data[correct_mask] += offset_align

+        if self.lambda_cosine != 0.0:
+            cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
         # Compute sum of perplexities for stats
-        stats = self._stats(loss.sum().clone(), scores_data, target_data)
+        stats = self._stats(loss.sum().clone(),
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores_data, target_data, num_ex)

         # this part looks like it belongs in CopyGeneratorLoss
         if self.normalize_by_length:
diff --git a/onmt/opts.py b/onmt/opts.py
index e61724b26c..191fc5c151 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -193,6 +193,9 @@ def model_opts(parser):
               help='Train a coverage attention layer.')
     group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0,
               help='Lambda value for coverage loss of See et al (2017)')
+    group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
+              help='Lambda value for cosine alignment loss '
+                   'of https://arxiv.org/abs/1903.07091 ')
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")
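
`-lambda_cosine` weights the new term relative to the usual objective: the cross-entropy part keeps its token/sentence normalization, while the cosine part is averaged over the examples in the batch (and, per the check added in onmt/utils/parse.py further below, the option requires `-max_generator_batches 0`). A rough sketch of the mixing; `mix_losses` is an illustrative helper, not part of this patch.

def mix_losses(nll_sum, normalization, cosine_sum, num_ex, lambda_cosine):
    # nll_sum: summed cross-entropy over the shard
    # cosine_sum / num_ex: average (1 - cosine similarity) over the batch
    loss = nll_sum / float(normalization)
    if lambda_cosine != 0.0:
        loss = loss + lambda_cosine * (cosine_sum / num_ex)
    return loss
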
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 4328ca52ea..d60c1a0ddb 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -70,7 +70,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
                            model_dtype=opt.model_dtype,
                            earlystopper=earlystopper,
                            dropout=dropout,
-                           dropout_steps=dropout_steps)
+                           dropout_steps=dropout_steps,
+                           encode_tgt=True if opt.lambda_cosine > 0 else False)
     return trainer

@@ -107,7 +108,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
                  n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
                  report_manager=None, with_align=False, model_saver=None,
                  average_decay=0, average_every=1, model_dtype='fp32',
-                 earlystopper=None, dropout=[0.3], dropout_steps=[0]):
+                 earlystopper=None, dropout=[0.3], dropout_steps=[0],
+                 encode_tgt=False):
         # Basic attributes.
         self.model = model
         self.train_loss = train_loss
@@ -132,6 +134,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
         self.earlystopper = earlystopper
         self.dropout = dropout
         self.dropout_steps = dropout_steps
+        self.encode_tgt = encode_tgt

         for i in range(len(self.accum_count_l)):
             assert self.accum_count_l[i] > 0
@@ -314,11 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt

                 # F-prop through the model.
-                outputs, attns = valid_model(src, tgt, src_lengths,
-                                             with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None

                 # Compute loss.
-                _, batch_stats = self.valid_loss(batch, outputs, attns)
+                _, batch_stats = self.valid_loss(
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)

                 # Update statistics.
                 stats.update(batch_stats)
@@ -361,8 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             if self.accum_count == 1:
                 self.optim.zero_grad()

-            outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
-                                        with_align=self.with_align)
+            if self.encode_tgt:
+                outputs, attns, enc_src, enc_tgt = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+            else:
+                outputs, attns = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align)
+                enc_src, enc_tgt = None, None
+
             bptt = True

             # 3. Compute loss.
@@ -371,6 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                     batch,
                     outputs,
                     attns,
+                    enc_src=enc_src,
+                    enc_tgt=enc_tgt,
                     normalization=normalization,
                     shard_size=self.shard_size,
                     trunc_start=j,
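
Taken together, the trainer now conditions the whole per-batch flow on whether `encode_tgt` was enabled by `build_trainer`: the model returns the two extra memory banks and they are handed on to the loss object. A condensed sketch of that flow, with `training_step`, `model` and `loss_fn` as placeholder names rather than the trainer's actual attributes:

def training_step(model, loss_fn, batch, src, tgt, src_lengths,
                  normalization, lambda_cosine):
    encode_tgt = lambda_cosine > 0  # mirrors build_trainer above
    if encode_tgt:
        outputs, attns, enc_src, enc_tgt = model(
            src, tgt, src_lengths, encode_tgt=True)
    else:
        outputs, attns = model(src, tgt, src_lengths)
        enc_src, enc_tgt = None, None
    # the loss object receives the extra encoder states only when they exist
    return loss_fn(batch, outputs, attns,
                   enc_src=enc_src, enc_tgt=enc_tgt,
                   normalization=normalization)
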
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index c48f0d3d21..f185e1a567 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True):
     else:
         compute = NMTLossCompute(
             criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
-            lambda_align=opt.lambda_align)
+            lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine)
     compute.to(device)

     return compute
@@ -92,7 +92,8 @@ def __init__(self, criterion, generator):
     def padding_idx(self):
         return self.criterion.ignore_index

-    def _make_shard_state(self, batch, output, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns=None):
         """
         Make shard state dictionary for shards() to return iterable
         shards for efficient loss computation. Subclass must define
@@ -123,6 +124,8 @@ def __call__(self,
                  batch,
                  output,
                  attns,
+                 enc_src=None,
+                 enc_tgt=None,
                  normalization=1.0,
                  shard_size=0,
                  trunc_start=0,
@@ -157,18 +160,20 @@ def __call__(self,
         if trunc_size is None:
             trunc_size = batch.tgt.size(0) - trunc_start
         trunc_range = (trunc_start, trunc_start + trunc_size)
-        shard_state = self._make_shard_state(batch, output, trunc_range, attns)
+        shard_state = self._make_shard_state(
+            batch, output, enc_src, enc_tgt, trunc_range, attns)
         if shard_size == 0:
-            loss, stats = self._compute_loss(batch, **shard_state)
-            return loss / float(normalization), stats
+            loss, stats = self._compute_loss(batch, normalization,
+                                             **shard_state)
+            return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
-            loss, stats = self._compute_loss(batch, **shard)
-            loss.div(float(normalization)).backward()
+            loss, stats = self._compute_loss(batch, normalization, **shard)
+            loss.backward()
             batch_stats.update(stats)
         return None, batch_stats

-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, cosine_loss, scores, target, num_ex):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
@@ -182,7 +187,9 @@ def _stats(self, loss, scores, target):
         non_padding = target.ne(self.padding_idx)
         num_correct = pred.eq(target).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
-        return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
+        return onmt.utils.Statistics(
+            loss.item(), cosine_loss.item() if cosine_loss is not None else 0,
+            num_non_padding, num_correct, num_ex)

     def _bottle(self, _v):
         return _v.view(-1, _v.size(2))
@@ -227,15 +234,17 @@ class NMTLossCompute(LossComputeBase):
     """

     def __init__(self, criterion, generator, normalization="sents",
-                 lambda_coverage=0.0, lambda_align=0.0):
+                 lambda_coverage=0.0, lambda_align=0.0, lambda_cosine=0.0):
         super(NMTLossCompute, self).__init__(criterion, generator)
         self.lambda_coverage = lambda_coverage
         self.lambda_align = lambda_align
+        self.lambda_cosine = lambda_cosine

-    def _make_shard_state(self, batch, output, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns=None):
         shard_state = {
             "output": output,
-            "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
+            "target": batch.tgt[range_[0] + 1: range_[1], :, 0]
         }
         if self.lambda_coverage != 0.0:
             coverage = attns.get("coverage", None)
@@ -273,9 +282,15 @@ def _make_shard_state(self, batch, output, range_, attns=None):
                 "align_head": attn_align,
                 "ref_align": ref_align[:, range_[0] + 1: range_[1], :]
             })
+        if self.lambda_cosine != 0.0:
+            shard_state.update({
+                "enc_src": enc_src,
+                "enc_tgt": enc_tgt
+            })
         return shard_state

-    def _compute_loss(self, batch, output, target, std_attn=None,
+    def _compute_loss(self, batch, normalization, output, target,
+                      enc_src=None, enc_tgt=None, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
@@ -284,6 +299,7 @@ def _compute_loss(self, batch, output, target, std_attn=None,
         gtruth = target.view(-1)

         loss = self.criterion(scores, gtruth)
+
         if self.lambda_coverage != 0.0:
             coverage_loss = self._compute_coverage_loss(
                 std_attn=std_attn, coverage_attn=coverage_attn)
@@ -296,7 +312,20 @@ def _compute_loss(self, batch, output, target, std_attn=None,
             align_loss = self._compute_alignement_loss(
                 align_head=align_head, ref_align=ref_align)
             loss += align_loss
-        stats = self._stats(loss.clone(), scores, gtruth)
+
+        loss = loss / float(normalization)
+
+        if self.lambda_cosine != 0.0:
+            cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
+        stats = self._stats(loss.clone() * normalization,
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores, gtruth, num_ex)

         return loss, stats

@@ -305,6 +334,15 @@ def _compute_coverage_loss(self, std_attn, coverage_attn):
         covloss *= self.lambda_coverage
         return covloss

+    def _compute_cosine_loss(self, enc_src, enc_tgt):
+        max_src = enc_src.max(axis=0)[0]
+        max_tgt = enc_tgt.max(axis=0)[0]
+        cosine_loss = torch.nn.functional.cosine_similarity(
+            max_src.float(), max_tgt.float(), dim=1)
+        cosine_loss = 1 - cosine_loss
+        num_ex = cosine_loss.size(0)
+        return cosine_loss.sum(), num_ex
+
     def _compute_alignement_loss(self, align_head, ref_align):
         """Compute loss between 2 partial alignment matrix."""
         # align_head contains value in [0, 1) presenting attn prob,
@@ -368,7 +406,7 @@ def shards(state, shard_size, eval_only=False):
         # over the shards, not over the keys: therefore, the values need
         # to be re-zipped by shard and then each shard can be paired
         # with the keys.
-        for shard_tensors in zip(*values):
+        for i, shard_tensors in enumerate(zip(*values)):
             yield dict(zip(keys, shard_tensors))

         # Assumed backprop'd
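
`_compute_cosine_loss` max-pools the encoder states over time to get one vector per sentence and penalizes `1 - cos(src, tgt)`, summed over the batch; the caller then divides by the number of examples and scales by `lambda_cosine`. The same computation in isolation, assuming `[seq_len, batch, hidden]` encoder outputs; the function name, example shapes and the `.mean()` shorthand are illustrative only.

import torch
import torch.nn.functional as F

def cosine_alignment_loss(enc_src, enc_tgt, lambda_cosine=1.0):
    # enc_src, enc_tgt: [seq_len, batch, hidden] encoder outputs
    src_vec = enc_src.max(dim=0)[0]   # [batch, hidden], max-pooled over time
    tgt_vec = enc_tgt.max(dim=0)[0]
    dist = 1 - F.cosine_similarity(src_vec.float(), tgt_vec.float(), dim=1)
    return lambda_cosine * dist.mean()  # mean == sum() / num_ex

# e.g. cosine_alignment_loss(torch.randn(12, 4, 256), torch.randn(9, 4, 256))
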
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 273dae3dba..ac6ddf6820 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -120,6 +120,10 @@ def validate_train_opts(cls, opt):
         assert len(opt.attention_dropout) == len(opt.dropout_steps), \
             "Number of attention_dropout values must match accum_steps values"

+        assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \
+            "-lambda_cosine loss is not implemented " \
+            "for max_generator_batches > 0."
+
     @classmethod
     def validate_translate_opts(cls, opt):
         if opt.beam_size != 1 and opt.random_sampling_topk != 1:
diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py
index 896d98c74d..87a1e7f8f1 100644
--- a/onmt/utils/statistics.py
+++ b/onmt/utils/statistics.py
@@ -17,12 +17,15 @@ class Statistics(object):
     * elapsed time
     """

-    def __init__(self, loss=0, n_words=0, n_correct=0):
+    def __init__(self, loss=0, cosine_loss=0, n_words=0,
+                 n_correct=0, num_ex=0):
         self.loss = loss
         self.n_words = n_words
         self.n_correct = n_correct
         self.n_src_words = 0
         self.start_time = time.time()
+        self.cosine_loss = cosine_loss
+        self.num_ex = num_ex

     @staticmethod
     def all_gather_stats(stat, max_size=4096):
@@ -81,6 +84,8 @@ def update(self, stat, update_n_src_words=False):
         self.loss += stat.loss
         self.n_words += stat.n_words
         self.n_correct += stat.n_correct
+        self.cosine_loss += stat.cosine_loss
+        self.num_ex += stat.num_ex

         if update_n_src_words:
             self.n_src_words += stat.n_src_words
@@ -97,6 +102,10 @@ def ppl(self):
         """ compute perplexity """
         return math.exp(min(self.loss / self.n_words, 100))

+    def cos(self):
+        """ normalize cosine distance per example """
+        return self.cosine_loss / self.num_ex
+
     def elapsed_time(self):
         """ compute elapsed time """
         return time.time() - self.start_time
@@ -113,8 +122,12 @@ def output(self, step, num_steps, learning_rate, start):
         step_fmt = "%2d" % step
         if num_steps > 0:
             step_fmt = "%s/%5d" % (step_fmt, num_steps)
+        if self.cosine_loss != 0:
+            cos_log = "cos: %4.2f; " % (self.cos())
+        else:
+            cos_log = ""
         logger.info(
-            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
+            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + cos_log +
              "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
             % (step_fmt, self.accuracy(),