Description
The preferred way to structure the transformation classes is to put the sampling of random weights/params in a static `get_params()` method. The method should receive any hyperparameters necessary for the sampling, and it should return all the necessary random variables. `forward()` should then call this method during the transformation. Here is an example of how this looks:
`vision/torchvision/transforms/transforms.py`, lines 530 to 531 in 9e71fda:

```python
@staticmethod
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
```

`vision/torchvision/transforms/transforms.py`, line 589 in 9e71fda:

```python
i, j, h, w = self.get_params(img, self.size)
```
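To make the `get_params()` pattern concrete, here is a minimal, dependency-free sketch. It is an illustration under stated assumptions, not torchvision code: the class name `RandomCropSketch` is hypothetical, the image is assumed to be a plain list of pixel rows instead of a `Tensor`, and Python's `random` module stands in for `torch.rand`. The point is the split of responsibilities: `get_params()` takes the hyperparameters and returns every sampled value, while `forward()` never touches the RNG itself.

```python
import random
from typing import List, Tuple

# Hypothetical stand-in for an image tensor: a list of pixel rows.
Image = List[List[int]]


class RandomCropSketch:
    """Hypothetical transform illustrating the static get_params() pattern."""

    def __init__(self, size: Tuple[int, int]) -> None:
        self.size = size

    @staticmethod
    def get_params(img: Image, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        # All random sampling lives here: hyperparameters come in as
        # arguments and every sampled value is returned to the caller.
        h, w = len(img), len(img[0])
        th, tw = output_size
        i = random.randint(0, h - th)  # randint includes both endpoints
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def forward(self, img: Image) -> Image:
        # forward() only consumes the sampled parameters.
        i, j, h, w = self.get_params(img, self.size)
        return [row[j:j + w] for row in img[i:i + h]]
```

Because `get_params()` is static, a caller can also sample the parameters once, without an instance, and reuse them, e.g. `RandomCropSketch.get_params(img, (2, 3))`.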
Unfortunately, many `forward()` methods call `torch.rand` directly. Here are a few examples:

`vision/torchvision/transforms/transforms.py`, lines 452 to 453 in 9e71fda:

```python
def forward(self, img):
    if self.p < torch.rand(1):
```

In the same file at 9e71fda, each of the following lines contains `if torch.rand(1) < self.p:`:

- line 619
- line 649
- line 700
- line 1454
- line 1560
There may be others. We should refactor the codebase so that all of the above calls happen within a static `get_params()` method. See #3065 (`RandomInvert`) for an example of how to structure it.
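The following is a dependency-free sketch of one shape this refactor could take for the `if torch.rand(1) < self.p:` checks listed above. It is an illustration, not the structure used in #3065: the class name `RandomHorizontalFlipSketch`, the `get_params(p)` signature returning a `bool`, and the list-of-rows image are all hypothetical, and `random.random()` stands in for `torch.rand(1)`.

```python
import random
from typing import List

# Hypothetical stand-in for an image tensor: a list of pixel rows.
Image = List[List[int]]


class RandomHorizontalFlipSketch:
    """Hypothetical transform with the p-based check moved into get_params()."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    @staticmethod
    def get_params(p: float) -> bool:
        # Replaces the inline `torch.rand(1) < self.p` check in forward().
        # random.random() is in [0, 1), so p=1.0 always applies and p=0.0 never does.
        return random.random() < p

    def forward(self, img: Image) -> Image:
        if self.get_params(self.p):
            return [row[::-1] for row in img]  # reverse each row (horizontal flip)
        return img
```

Under the same assumptions, the inverted check on lines 452 to 453 (`if self.p < torch.rand(1):`, which returns the input unchanged) would read `if not self.get_params(self.p): return img`.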
cc @vfdev-5