diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 8cf82c30a0f..22ef536a85c 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -435,6 +435,74 @@ def get_int(b: bytes) -> int: return int(codecs.encode(b, 'hex'), 16) +class MovingMNIST(VisionDataset): + """MovingMNIST""" + url = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy" + file = "moving_mnist.pt" + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False): + super(MovingMNIST, self).__init__(root, transform, target_transform) + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError('Dataset not found.' + + ' You can use download=True to download it') + + self.data = torch.load(os.path.join(self.processed_folder, self.file)) + + def download(self) -> None: + if self._check_exists(): + return + os.makedirs(self.raw_folder, exist_ok=True) + os.makedirs(self.processed_folder, exist_ok=True) + + filename = self.url.rpartition('/')[2] + file_path = os.path.join(self.raw_folder, filename) + if not os.path.isfile(file_path): + download_url(self.url, root=self.raw_folder, filename=filename) + data = read_npy_file(file_path) + + with open(os.path.join(self.processed_folder, self.file), 'wb') as f: + torch.save((data), f) + + def _check_exists(self) -> bool: + return (os.path.exists(os.path.join(self.processed_folder, + self.file))) + + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + + def _transform_data(data, transform): + transformed_data = None + for i in range(data.size(0)): + img = Image.fromarray(data[i].numpy(), mode='L') + transformed_data = transform(img) if new_data is None else torch.cat([self.transform(img), new_data], dim=0) + return transformed_data + + sequence, target = self.data[index, :10], self.data[index, 10:] + + if self.transform is not None: + sequence = _transform_data(sequence, self.transform) + if self.target_transform is not None: + target = _transform_data(target, self.target_transform) + + return sequence, target + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, 'raw') + + @property + def processed_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, 'processed') + def open_maybe_compressed_file(path: Union[str, IO]) -> IO: """Return a file object that possibly decompresses 'path' on the fly. Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. @@ -494,3 +562,9 @@ def read_image_file(path: str) -> torch.Tensor: assert(x.dtype == torch.uint8) assert(x.ndimension() == 3) return x + +def read_npy_file(path: str) -> torch.Tensor: + with open(path, 'rb') as f: + x = torch.tensor(np.load(f)) + assert(x.dtype == torch.uint8) + return x \ No newline at end of file