-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add MovingMNIST #2690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MovingMNIST #2690
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -435,6 +435,74 @@ def get_int(b: bytes) -> int: | |||||||||||
return int(codecs.encode(b, 'hex'), 16) | ||||||||||||
|
||||||||||||
|
||||||||||||
class MovingMNIST(VisionDataset): | ||||||||||||
"""MovingMNIST""" | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please a link (http://www.cs.toronto.edu/~nitish/unsupervised_video/) to the dataset like that
and define docstring |
||||||||||||
url = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy" | ||||||||||||
file = "moving_mnist.pt" | ||||||||||||
|
||||||||||||
def __init__( | ||||||||||||
self, | ||||||||||||
root: str, | ||||||||||||
transform: Optional[Callable] = None, | ||||||||||||
target_transform: Optional[Callable] = None, | ||||||||||||
download: bool = False): | ||||||||||||
super(MovingMNIST, self).__init__(root, transform, target_transform) | ||||||||||||
|
||||||||||||
if download: | ||||||||||||
self.download() | ||||||||||||
|
||||||||||||
if not self._check_exists(): | ||||||||||||
raise RuntimeError('Dataset not found.' + | ||||||||||||
' You can use download=True to download it') | ||||||||||||
|
||||||||||||
self.data = torch.load(os.path.join(self.processed_folder, self.file)) | ||||||||||||
|
||||||||||||
def download(self) -> None: | ||||||||||||
if self._check_exists(): | ||||||||||||
return | ||||||||||||
os.makedirs(self.raw_folder, exist_ok=True) | ||||||||||||
os.makedirs(self.processed_folder, exist_ok=True) | ||||||||||||
|
||||||||||||
filename = self.url.rpartition('/')[2] | ||||||||||||
file_path = os.path.join(self.raw_folder, filename) | ||||||||||||
if not os.path.isfile(file_path): | ||||||||||||
download_url(self.url, root=self.raw_folder, filename=filename) | ||||||||||||
data = read_npy_file(file_path) | ||||||||||||
|
||||||||||||
with open(os.path.join(self.processed_folder, self.file), 'wb') as f: | ||||||||||||
torch.save((data), f) | ||||||||||||
|
||||||||||||
def _check_exists(self) -> bool: | ||||||||||||
return (os.path.exists(os.path.join(self.processed_folder, | ||||||||||||
self.file))) | ||||||||||||
|
||||||||||||
|
||||||||||||
def __getitem__(self, index: int) -> Tuple[Any, Any]: | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without a length, the dataset is not iterable.
Suggested change
|
||||||||||||
|
||||||||||||
def _transform_data(data, transform): | ||||||||||||
transformed_data = None | ||||||||||||
for i in range(data.size(0)): | ||||||||||||
img = Image.fromarray(data[i].numpy(), mode='L') | ||||||||||||
transformed_data = transform(img) if new_data is None else torch.cat([self.transform(img), new_data], dim=0) | ||||||||||||
return transformed_data | ||||||||||||
|
||||||||||||
sequence, target = self.data[index, :10], self.data[index, 10:] | ||||||||||||
|
||||||||||||
if self.transform is not None: | ||||||||||||
sequence = _transform_data(sequence, self.transform) | ||||||||||||
if self.target_transform is not None: | ||||||||||||
target = _transform_data(target, self.target_transform) | ||||||||||||
|
||||||||||||
return sequence, target | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def raw_folder(self) -> str: | ||||||||||||
return os.path.join(self.root, self.__class__.__name__, 'raw') | ||||||||||||
|
||||||||||||
@property | ||||||||||||
def processed_folder(self) -> str: | ||||||||||||
return os.path.join(self.root, self.__class__.__name__, 'processed') | ||||||||||||
|
||||||||||||
def open_maybe_compressed_file(path: Union[str, IO]) -> IO: | ||||||||||||
"""Return a file object that possibly decompresses 'path' on the fly. | ||||||||||||
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. | ||||||||||||
|
@@ -494,3 +562,9 @@ def read_image_file(path: str) -> torch.Tensor: | |||||||||||
assert(x.dtype == torch.uint8) | ||||||||||||
assert(x.ndimension() == 3) | ||||||||||||
return x | ||||||||||||
|
||||||||||||
def read_npy_file(path: str) -> torch.Tensor: | ||||||||||||
with open(path, 'rb') as f: | ||||||||||||
x = torch.tensor(np.load(f)) | ||||||||||||
assert(x.dtype == torch.uint8) | ||||||||||||
return x | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we inherit it from
MNIST
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we use this as a video dataset, we shouldn't. We probably need to use
MNIST
to generate the training split though.