diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index c362b53981f..f0ff8ef17cb 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1019,13 +1019,14 @@ def dtd(root, config):
     return num_samples_map[config["split"], config["fold"]]
 
 
-# @register_mock
-def fer2013(info, root, config):
-    num_samples = 5 if config.split == "train" else 3
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def fer2013(root, config):
+    split = config["split"]
+    num_samples = 5 if split == "train" else 3
 
-    path = root / f"{config.split}.csv"
+    path = root / f"{split}.csv"
     with open(path, "w", newline="") as file:
-        field_names = ["emotion"] if config.split == "train" else []
+        field_names = ["emotion"] if split == "train" else []
         field_names.append("pixels")
 
         file.write(",".join(field_names) + "\n")
@@ -1035,7 +1036,7 @@ def fer2013(info, root, config):
             rowdict = {
                 "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)])
             }
-            if config.split == "train":
+            if split == "train":
                 rowdict["emotion"] = int(torch.randint(7, ()))
 
             writer.writerow(rowdict)
diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py
index a5bfa681d02..ca30b78e609 100644
--- a/torchvision/prototype/datasets/_builtin/fer2013.py
+++ b/torchvision/prototype/datasets/_builtin/fer2013.py
@@ -1,11 +1,10 @@
-from typing import Any, Dict, List, cast
+import pathlib
+from typing import Any, Dict, List, cast, Union
 
 import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     OnlineResource,
     KaggleDownloadResource,
 )
@@ -15,26 +14,40 @@
 )
 from torchvision.prototype.features import Label, Image
 
+from .._api import register_dataset, register_info
+
+NAME = "fer2013"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))
 
-class FER2013(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "fer2013",
-            homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
-            categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
-            valid_options=dict(split=("train", "test")),
-        )
+
+@register_dataset(NAME)
+class FER2013(Dataset2):
+    """FER 2013 Dataset
+    homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _CHECKSUMS = {
         "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
         "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         archive = KaggleDownloadResource(
-            cast(str, self.info.homepage),
-            file_name=f"{config.split}.csv.zip",
-            sha256=self._CHECKSUMS[config.split],
+            "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
+            file_name=f"{self._split}.csv.zip",
+            sha256=self._CHECKSUMS[self._split],
         )
         return [archive]
 
@@ -43,17 +56,15 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
         return dict(
             image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
-            label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
+            label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 28_709 if self._split == "train" else 3_589