diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 7121ee964fa..699e64ae0c1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -268,8 +268,8 @@ def __init__( def _process_input(self, x: torch.Tensor) -> torch.Tensor: n, c, h, w = x.shape p = self.patch_size - torch._assert(h == self.image_size, "Wrong image height!") - torch._assert(w == self.image_size, "Wrong image width!") + torch._assert(h == self.image_size, f"Wrong image height, expected {self.image_size} but got {h}!") + torch._assert(w == self.image_size, f"Wrong image width, expected {self.image_size} but got {w}!") n_h = h // p n_w = w // p