Open
Description
📚 The doc issue
Hi, I'd like to be able to download the prototype datasets to my AIStore cluster and read them with `AISFileLister` and `AISFileLoader`.
Also, how can I use `CifarFileReader` and its associated helper functions when my data lives in AIStore? For example, this is what I currently did:
- Created an AIStore cluster
- Downloaded the data (CIFAR-10) locally
- Put the data into AIStore
- Created the initial datapipe with `AISFileLister` and `AISFileLoader`
- Copied and pasted all the functions and classes defined in the prototype CIFAR dataset file so that I could use `Filter(dp, _is_data_file)`, `Mapper(dp, _unpickle)`, `CifarFileReader(dp, labels_key="labels")`, and `Mapper(dp, _prepare_sample)`
My final script is below, but I assume there must be a more concise way to do this. Is there?
from torchdata.datapipes.iter import AISFileLister, AISFileLoader
# HTTP endpoint of the AIStore cluster, shared by the lister and the loader so the
# address only has to be changed in one place.
AIS_URL = "http://localhost:51080"

# Prefixes on AIStore under which the data objects are stored.
image_prefix = ["ais://cifar10/"]

# List the URLs of all objects whose names start with one of the prefixes above.
dp_urls = AISFileLister(url=AIS_URL, source_datapipe=image_prefix)

# Debug output: show the listed object URLs.
# NOTE(review): this iterates the lister once here; iterating `dp` later presumably
# performs the listing again.
print(list(dp_urls))

# Load every listed object from AIStore.
dp = AISFileLoader(url=AIS_URL, source_datapipe=dp_urls)
import io
import pathlib
import pickle
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
from torchdata.datapipes.iter import TarArchiveLoader, Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Image, Label
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
)
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
    """Split unpickled CIFAR batch dicts into individual ``(image_array, label)`` pairs.

    Each mapping produced by ``datapipe`` must contain a ``"data"`` array, which is
    reshaped to ``(-1, 3, 32, 32)`` (one 3x32x32 array per image), and a sequence
    of integer labels stored under ``labels_key``.
    """

    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
        """Wrap ``datapipe``; ``labels_key`` names the dict entry holding the labels."""
        self.datapipe = datapipe
        self.labels_key = labels_key

    def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Yield one ``(image_array, label)`` pair per image of every batch."""
        for mapping in self.datapipe:
            image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
            category_idcs = mapping[self.labels_key]
            # zip() stops at the shorter input: if a batch had fewer labels than
            # images, the extra images would be dropped silently.
            yield from zip(image_arrays, category_idcs)
# CIFAR-10 class names, passed to ``Label`` together with each integer class index.
categories = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Split to read: "train" keeps archive members whose names start with "data";
# any other value keeps members whose names start with "test".
split = "train"
def _is_data_file(data: Tuple[str, Any]) -> bool:
    """Tell whether an archive member holds image batches for the selected split.

    Only the member path (first tuple element) is inspected. The module-level
    ``split`` is read at call time: ``"train"`` accepts file names starting with
    ``"data"``; any other value accepts file names starting with ``"test"``.
    """
    member_path, _ = data[0], data[1:]
    wanted_prefix = "data" if split == "train" else "test"
    file_name = pathlib.Path(member_path).name
    return file_name.startswith(wanted_prefix)
def _unpickle(data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
    """Unpickle one CIFAR batch file and close its stream.

    Args:
        data: ``(path, stream)`` pair; only the stream is used.

    Returns:
        The unpickled dict. ``encoding="latin1"`` is how ``pickle.load`` decodes
        8-bit strings from Python 2 pickles; presumably needed for the CIFAR batches.

    Raises:
        Whatever ``pickle.load`` raises for a malformed stream (e.g. ``EOFError``).
        The stream is closed in that case as well.
    """
    _, file = data
    try:
        # SECURITY: pickle.load can execute arbitrary code. Only feed it batch
        # files from a trusted source.
        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
    finally:
        # Close even when unpickling fails, so a corrupt file does not leak its stream.
        file.close()
def _prepare_sample(data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
    """Wrap one ``(image_array, category_idx)`` pair as a sample dict.

    The result has an ``"image"`` entry (an ``Image`` built from the array) and a
    ``"label"`` entry (a ``Label`` built from the index and the module-level
    ``categories`` list).
    """
    array, label_idx = data
    image = Image(array)
    label = Label(label_idx, categories=categories)
    return {"image": image, "label": label}
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Open the tar archive(s) loaded from AIStore, keep only the batch files of the
# selected split, and unpickle each one into a dict.
dp = dp.load_from_tar().filter(_is_data_file).map(_unpickle)
# Break every unpickled batch into (image_array, label) pairs.
dp = CifarFileReader(dp, labels_key="labels")
dp = hint_sharding(hint_shuffling(dp))
# Wrap the pairs as sample dicts, then group them into collated batches of 8.
dp = dp.map(_prepare_sample).batch(8).collate()

# Inspect only the first batch (nothing happens if the pipeline is empty).
first_batch = next(iter(dp), None)
if first_batch is not None:
    print(first_batch.keys())
    images = first_batch["image"].to(device)
    labels = first_batch["label"].to(device)
    print(images.shape)
Suggest a potential alternative/fix
Add documentation demonstrating how to use the prototype dataset classes with your own data (for example, data stored in AIStore) when that data is stored in a format matching the original dataset.