diff --git a/tests/e2e/mnist.py b/tests/e2e/mnist.py index a99589659..91d1e0d21 100644 --- a/tests/e2e/mnist.py +++ b/tests/e2e/mnist.py @@ -19,7 +19,7 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, RandomSampler from torchmetrics import Accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -127,7 +127,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=BATCH_SIZE) + return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000)) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=BATCH_SIZE) @@ -147,10 +147,11 @@ def test_dataloader(self): trainer = Trainer( accelerator="auto", # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs - max_epochs=5, + max_epochs=3, callbacks=[TQDMProgressBar(refresh_rate=20)], num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)), devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)), + replace_sampler_ddp=False, strategy="ddp", )