diff --git a/test/test_ops.py b/test/test_ops.py index 1c39468aa5b..cd7f992a7c8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -557,8 +557,10 @@ def test_nms_cuda_float16(self): keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) assert_equal(keep32, keep16) - def test_batched_nms_implementations(self): + @pytest.mark.parametrize("seed", range(10)) + def test_batched_nms_implementations(self, seed): """Make sure that both implementations of batched_nms yield identical results""" + torch.random.manual_seed(seed) num_boxes = 1000 iou_threshold = 0.9