from typing import List

import torch
import torchvision
import torch_tensorrt

# torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)


def mask_to_index(tensor: torch.Tensor):
    """
    Convert a mask into an index.
    For multi-dimensional masks, the index is valid on a tensor flattened over those dimensions.

    Notes:
    - For smaller tensors (~5000 elements) on the GPU, it is faster to transfer the mask to the CPU and back again.
    - Break-even is around 2 uses of the index.

    Arguments:
        tensor {torch.Tensor} -- The mask.

    Returns:
        torch.Tensor -- The index.
    """
    # return torch.arange(tensor.numel(), device=tensor.device).view(tensor.shape).masked_select(tensor)
    # The masked_select variant above triggers https://github.com/pytorch/pytorch/issues/60539 on the CPU for tensors bigger than 2**15 elements.
    return torch.arange(tensor.numel(), device=tensor.device).view(tensor.shape)[tensor]


def standardise_boxes(boxes: torch.Tensor) -> torch.Tensor:
    """
    Standardise the boxes (x1 < x2, y1 < y2). The input tensor `boxes` may be modified.

    Note: break-even for CUDA vs CPU seems to be around 3000 boxes.

    Arguments:
        boxes {torch.Tensor} -- (N, K*H*W, 4) K*H*W boxes for each image in the batch (N).
            Coordinates are swapped automatically where necessary.

    Returns:
        torch.Tensor -- The same boxes, but standardised.
    """
    # Small tensors are processed on the CPU: the transfer is cheaper than indexing on the GPU.
    b = boxes.to(boxes.device if boxes.device.type == 'cuda' and boxes.numel() > 12000 else torch.device('cpu')).flatten(end_dim=-2).clone()
    index_x, index_y = mask_to_index(b[:, 2] < b[:, 0]), mask_to_index(b[:, 3] < b[:, 1])
    # Swap x1/x2 and y1/y2 wherever they are out of order.
    b[:, 0][index_x], b[:, 2][index_x] = b[:, 2][index_x], b[:, 0][index_x]
    b[:, 1][index_y], b[:, 3][index_y] = b[:, 3][index_y], b[:, 1][index_y]
    return b.view(boxes.shape).to(boxes.device)  # Send the tensor back to the appropriate device.


def convert_boxes_to_roi_format(boxes: List[torch.Tensor], scale: torch.Tensor = torch.ones(1, dtype=torch.float)):
    """
    Convert boxes to the format expected by roi_align (from torchvision).

    Arguments:
        boxes {List[Tensor]} -- A list of tensors, one for each feature_map entry.
        scale {Tensor[float]} -- Scale factor for the boxes to align with the feature_map size.
            Should be of length 1, 2 or 4.

    Returns:
        Tensor -- The collated boxes with their feature_map index.
    """
    if scale.numel() == 2:
        # A (sx, sy) scale is broadcast to (sx, sy, sx, sy) to match the (x1, y1, x2, y2) columns.
        scale = torch.cat((scale, scale))
    concat_boxes = (boxes[0] if len(boxes) == 1 else torch.cat([b for b in boxes], dim=0)) * scale.to(device=boxes[0].device, dtype=boxes[0].dtype)
    # Prepend the feature_map index of each box as the first column.
    temp = []
    for i, b in enumerate(boxes):
        temp.append(torch.full_like(b[:, :1], i))
    ids = temp[0] if len(temp) == 1 else torch.cat(temp, dim=0)
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        boxes = convert_boxes_to_roi_format([standardise_boxes(torch.rand(10, 4)) for i in range(2)]).to(x.device)
        return torchvision.ops.roi_align(x, boxes, output_size=(6, 9), spatial_scale=1.0, sampling_ratio=2)


m = MyModel()
print(m(torch.rand((2, 3, 10, 10), device='cuda')).shape)
scr_m = torch.jit.script(m)
print(scr_m(torch.rand((2, 3, 10, 10), device='cuda')).shape)
trt_m = torch_tensorrt.compile(scr_m, inputs=[torch_tensorrt.Input((2, 3, 10, 10))], truncate_long_and_double=True)
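
# Illustrative sketch (not part of the original repro; the names below are hypothetical):
# how a 2-element `scale` is applied by convert_boxes_to_roi_format. Uncomment to run
# independently of the torch_tensorrt compile step above.
# example_boxes = [standardise_boxes(torch.rand(5, 4)) for _ in range(3)]
# example_rois = convert_boxes_to_roi_format(example_boxes, scale=torch.tensor([2.0, 0.5]))
# print(example_rois.shape)  # torch.Size([15, 5]); column 0 is the feature_map index.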