diff --git a/effdet/bench.py b/effdet/bench.py
index a35df694..ff0e024c 100644
--- a/effdet/bench.py
+++ b/effdet/bench.py
@@ -86,7 +86,7 @@ def __init__(self, model, config):
             config.num_scales, config.aspect_ratios,
             config.anchor_scale, config.image_size)
         self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
-        self.loss_fn = DetectionLoss(self.config)
+        self.loss_fn = DetectionLoss(self.config, self.anchors)
 
     def forward(self, x, target):
         class_out, box_out = self.model(x)
diff --git a/effdet/iou_loss.py b/effdet/iou_loss.py
new file mode 100644
index 00000000..fc13aeb6
--- /dev/null
+++ b/effdet/iou_loss.py
@@ -0,0 +1,122 @@
+'''
+Based on:
+ https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/86a370aa2cadea6ba7e5dffb2efc4bacc4c863ea/utils/box/box_utils.py#L47
+
+ Distance-IoU Loss: Faster and Better Learning for Bounding Box Regression
+ https://arxiv.org/pdf/1911.08287.pdf
+ Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression
+ https://giou.stanford.edu/GIoU.pdf
+ UnitBox: An Advanced Object Detection Network
+ https://arxiv.org/pdf/1608.01471.pdf
+
+ Important!!! (in case of c_iou_loss)
+ targets -> bboxes1, preds -> bboxes2
+ '''
+
+import torch
+from torch import nn
+import numpy as np
+
+eps = 10e-16
+
+
+def compute_iou(bboxes1, bboxes2):
+    "bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]"
+    assert bboxes1.size(0) == bboxes2.size(0)
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
+    min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2])
+    max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0])
+    min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3])
+    max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1])
+
+    inter = torch.where(min_x2 - max_x1 > 0, min_x2 - max_x1, torch.tensor(0.)) * \
+            torch.where(min_y2 - max_y1 > 0, min_y2 - max_y1, torch.tensor(0.))
+    union = area1 + area2 - inter
+    iou = inter / union
+    iou = torch.clamp(iou, min=0, max=1.0)
+    return iou
+
+
+def compute_g_iou(bboxes1, bboxes2):
+    "box1 of shape [N, 4] and box2 of shape [N, 4]"
+    #assert bboxes1.size(0) == bboxes2.size(0)
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
+    min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2])
+    max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0])
+    min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3])
+    max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1])
+    inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0)
+    union = area1 + area2 - inter
+    C = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) * \
+        (torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1]))
+    g_iou = inter / union - (C - union) / C
+    g_iou = torch.clamp(g_iou, min=0, max=1.0)
+    return g_iou
+
+
+def compute_d_iou(bboxes1, bboxes2):
+    "bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]"
+    #assert bboxes1.size(0) == bboxes2.size(0)
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
+    min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2])
+    max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0])
+    min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3])
+    max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1])
+    inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0)
+    union = area1 + area2 - inter
+    center_x1 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2
+    center_y1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2
+    center_x2 = (bboxes2[:, 2] + bboxes2[:, 0]) / 2
+    center_y2 = (bboxes2[:, 3] + bboxes2[:, 1]) / 2
+
+    # squared euclidian distance between the target and predicted bboxes
+    d_2 = (center_x1 - center_x2) ** 2 + (center_y1 - center_y2) ** 2
+    # squared length of the diagonal of the minimum bbox that encloses both bboxes
+    c_2 = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) ** 2 + (
+            torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])) ** 2
+    d_iou = inter / union - d_2 / c_2
+    d_iou = torch.clamp(d_iou, min=-1.0, max=1.0)
+
+    return d_iou
+
+
+def compute_c_iou(bboxes1, bboxes2):
+    "bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]"
+    #assert bboxes1.size(0) == bboxes2.size(0)
+    w1 = bboxes1[:, 2] - bboxes1[:, 0]
+    h1 = bboxes1[:, 3] - bboxes1[:, 1]
+    w2 = bboxes2[:, 2] - bboxes2[:, 0]
+    h2 = bboxes2[:, 3] - bboxes2[:, 1]
+    area1 = w1 * h1
+    area2 = w2 * h2
+    min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2])
+    max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0])
+    min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3])
+    max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1])
+
+    inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0)
+    union = area1 + area2 - inter
+
+    center_x1 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2
+    center_y1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2
+    center_x2 = (bboxes2[:, 2] + bboxes2[:, 0]) / 2
+    center_y2 = (bboxes2[:, 3] + bboxes2[:, 1]) / 2
+    # squared euclidian distance between the target and predicted bboxes
+    d_2 = (center_x1 - center_x2) ** 2 + (center_y1 - center_y2) ** 2
+    # squared length of the diagonal of the minimum bbox that encloses both bboxes
+    c_2 = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) ** 2 + (
+            torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])) ** 2
+    iou = inter / union
+    v = 4 / np.pi ** 2 * (np.arctan(w1 / h1) - np.arctan(w2 / h2)) ** 2
+    with torch.no_grad():
+        S = 1 - iou
+        alpha = v / (S + v + eps)
+    c_iou = iou - (d_2 / c_2 + alpha * v)
+    c_iou = torch.clamp(c_iou, min=-1.0, max=1.0)
+    return c_iou
+
+
+
diff --git a/effdet/loss.py b/effdet/loss.py
index cce77875..c409b5be 100644
--- a/effdet/loss.py
+++ b/effdet/loss.py
@@ -3,7 +3,8 @@
 import torch.nn.functional as F
 
 from typing import Optional, List
-
+from .anchors import decode_box_outputs
+from .iou_loss import *
 
 def focal_loss(logits, targets, alpha: float, gamma: float, normalizer):
     """Compute the focal loss between `logits` and the golden `target` values.
@@ -119,8 +120,35 @@ def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
     return box_loss
 
 
+
+class IouLoss(nn.Module):
+
+    def __init__(self, losstype='Giou', reduction='mean'):
+        super(IouLoss, self).__init__()
+        self.reduction = reduction
+        self.loss = losstype
+
+    def forward(self, target_bboxes, pred_bboxes):
+        num = target_bboxes.shape[0]
+        if self.loss == 'Iou':
+            loss = torch.sum(1.0 - compute_iou(target_bboxes, pred_bboxes))
+        else:
+            if self.loss == 'Giou':
+                loss = torch.sum(1.0 - compute_g_iou(target_bboxes, pred_bboxes))
+            else:
+                if self.loss == 'Diou':
+                    loss = torch.sum(1.0 - compute_d_iou(target_bboxes, pred_bboxes))
+                else:
+                    loss = torch.sum(1.0 - compute_c_iou(target_bboxes, pred_bboxes))
+
+        if self.reduction == 'mean':
+            return loss / num
+        else:
+            return loss
+
+
 class DetectionLoss(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, anchors, use_iou_loss = False):
         super(DetectionLoss, self).__init__()
         self.config = config
         self.num_classes = config.num_classes
@@ -128,6 +156,10 @@ def __init__(self, config):
         self.gamma = config.gamma
         self.delta = config.delta
         self.box_loss_weight = config.box_loss_weight
+        self.use_iou_loss = use_iou_loss
+        if self.use_iou_loss:
+            self.anchors = anchors
+            self.iou_loss = IouLoss()
 
     def forward(
             self, cls_outputs: List[torch.Tensor], box_outputs: List[torch.Tensor],
@@ -161,6 +193,11 @@ def forward(
 
         cls_losses = []
         box_losses = []
+        if self.use_iou_loss:
+            box_outputs_list = []
+            cls_targets_list = []
+            box_targets_list = []
+
         for l in range(levels):
             cls_targets_at_level = cls_targets[l]
             box_targets_at_level = box_targets[l]
@@ -182,12 +219,29 @@ def forward(
             cls_loss = cls_loss.view(bs, height, width, -1, self.num_classes)
             cls_loss *= (cls_targets_at_level != -2).unsqueeze(-1).float()
             cls_losses.append(cls_loss.sum())
+            if not self.use_iou_loss:
+                box_losses.append(_box_loss(
+                    box_outputs[l].permute(0, 2, 3, 1),
+                    box_targets_at_level,
+                    num_positives_sum,
+                    delta=self.delta))
+
+            else:
+                box_outputs_list.append(box_outputs[l].permute(0, 2, 3, 1).reshape([bs, -1, 4]))
+                cls_targets_list.append(cls_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 1]))
+                box_targets_list.append(box_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 4]))
+
+
+        if self.use_iou_loss:
+            # apply bounding box regression to anchors
+            for k in range(box_outputs_list.shape[0]):
+                pred_boxes = decode_box_outputs(box_outputs_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True)
+                target_boxes = decode_box_outputs(box_targets_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True)
+                # indices where an anchor is assigned to target box
+                indices = box_targets_list[k] == 0.0
+                pred_boxes = torch.clamp(pred_boxes, 0)
+                box_losses.append(self.iou_loss(target_boxes[indices.view(-1)], pred_boxes[indices.view(-1)]))
 
-            box_losses.append(_box_loss(
-                box_outputs[l].permute(0, 2, 3, 1),
-                box_targets_at_level,
-                num_positives_sum,
-                delta=self.delta))
 
         # Sum per level losses to total loss.
         cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)