From ac6c3ea4fa732cb50e94cb7c166fa0fd639d9a95 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 7 Sep 2023 16:47:13 -0700 Subject: [PATCH 1/2] Meta implementation for nms Signed-off-by: Edward Z. Yang --- torchvision/_meta_registrations.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 9831cfdcb45..8a25462ce15 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -2,6 +2,7 @@ import torch import torch.library +import torch._custom_ops # Ensure that torch.ops.torchvision is visible import torchvision.extension # noqa: F401 @@ -48,3 +49,25 @@ def meta_roi_align_backward( ), ) return grad.new_empty((batch_size, channels, height, width)) + +@torch._custom_ops.impl_abstract("torchvision::nms") +def meta_nms(dets, scores, iou_threshold): + torch._check( + dets.dim() == 2, + lambda: f"boxes should be a 2d tensor, got {dets.dim()}D" + ) + torch._check( + dets.size(1) == 4, + lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}" + ) + torch._check( + scores.dim() == 1, + lambda: f"scores should be a 1d tensor, got {scores.dim()}" + ) + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}" + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long) From 7c858e15fdccb481d8e03376b19783fbf2e1e2ba Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 8 Sep 2023 09:25:26 +0200 Subject: [PATCH 2/2] lint --- torchvision/_meta_registrations.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 8a25462ce15..7baece2ae2c 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -1,8 +1,8 @@ import functools import torch -import torch.library import torch._custom_ops +import torch.library # Ensure that torch.ops.torchvision is visible import torchvision.extension # noqa: F401 @@ -50,23 +50,15 @@ def meta_roi_align_backward( ) return grad.new_empty((batch_size, channels, height, width)) + @torch._custom_ops.impl_abstract("torchvision::nms") def meta_nms(dets, scores, iou_threshold): - torch._check( - dets.dim() == 2, - lambda: f"boxes should be a 2d tensor, got {dets.dim()}D" - ) - torch._check( - dets.size(1) == 4, - lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}" - ) - torch._check( - scores.dim() == 1, - lambda: f"scores should be a 1d tensor, got {scores.dim()}" - ) + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") torch._check( dets.size(0) == scores.size(0), - lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}" + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", ) ctx = torch._custom_ops.get_ctx() num_to_keep = ctx.create_unbacked_symint()