Skip to content

Commit 24c0a14

Browse files
authored
use upstream torchdata datapipes in prototype datasets (#5570)
1 parent 3b0b6c0 commit 24c0a14

File tree

3 files changed

+13
-83
lines changed

3 files changed

+13
-83
lines changed

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,16 @@
33
import re
44
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
55

6-
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
7-
from torchdata.datapipes.iter import TarArchiveReader
6+
from torchdata.datapipes.iter import (
7+
IterDataPipe,
8+
LineReader,
9+
IterKeyZipper,
10+
Mapper,
11+
Filter,
12+
Demultiplexer,
13+
TarArchiveReader,
14+
Enumerator,
15+
)
816
from torchvision.prototype.datasets.utils import (
917
Dataset,
1018
DatasetConfig,
@@ -16,7 +24,6 @@
1624
INFINITE_BUFFER_SIZE,
1725
BUILTIN_DIR,
1826
path_comparator,
19-
Enumerator,
2027
getitem,
2128
read_mat,
2229
hint_sharding,

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -6,25 +6,9 @@
66
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
77

88
import torch
9-
from torchdata.datapipes.iter import (
10-
IterDataPipe,
11-
Demultiplexer,
12-
Mapper,
13-
Zipper,
14-
)
15-
from torchvision.prototype.datasets.utils import (
16-
Dataset,
17-
DatasetConfig,
18-
DatasetInfo,
19-
HttpResource,
20-
OnlineResource,
21-
)
22-
from torchvision.prototype.datasets.utils._internal import (
23-
Decompressor,
24-
INFINITE_BUFFER_SIZE,
25-
hint_sharding,
26-
hint_shuffling,
27-
)
9+
from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor
10+
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
11+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
2812
from torchvision.prototype.features import Image, Label
2913
from torchvision.prototype.utils._internal import fromfile
3014

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 0 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,4 @@
1-
import enum
21
import functools
3-
import gzip
4-
import lzma
5-
import os
6-
import os.path
72
import pathlib
83
import pickle
94
from typing import BinaryIO
@@ -16,7 +11,6 @@
1611
TypeVar,
1712
Iterator,
1813
Dict,
19-
Optional,
2014
IO,
2115
Sized,
2216
)
@@ -35,11 +29,9 @@
3529
"BUILTIN_DIR",
3630
"read_mat",
3731
"MappingIterator",
38-
"Enumerator",
3932
"getitem",
4033
"path_accessor",
4134
"path_comparator",
42-
"Decompressor",
4335
"read_flo",
4436
"hint_sharding",
4537
]
@@ -75,15 +67,6 @@ def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
7567
yield from iter(mapping.values() if self.drop_key else mapping.items())
7668

7769

78-
class Enumerator(IterDataPipe[Tuple[int, D]]):
79-
def __init__(self, datapipe: IterDataPipe[D], start: int = 0) -> None:
80-
self.datapipe = datapipe
81-
self.start = start
82-
83-
def __iter__(self) -> Iterator[Tuple[int, D]]:
84-
yield from enumerate(self.datapipe, self.start)
85-
86-
8770
def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
8871
for item in items:
8972
obj = obj[item]
@@ -123,50 +106,6 @@ def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -
123106
return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
124107

125108

126-
class CompressionType(enum.Enum):
127-
GZIP = "gzip"
128-
LZMA = "lzma"
129-
130-
131-
class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
132-
types = CompressionType
133-
134-
_DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
135-
types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
136-
types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
137-
}
138-
139-
def __init__(
140-
self,
141-
datapipe: IterDataPipe[Tuple[str, BinaryIO]],
142-
*,
143-
type: Optional[Union[str, CompressionType]] = None,
144-
) -> None:
145-
self.datapipe = datapipe
146-
if isinstance(type, str):
147-
type = self.types(type.upper())
148-
self.type = type
149-
150-
def _detect_compression_type(self, path: str) -> CompressionType:
151-
if self.type:
152-
return self.type
153-
154-
# TODO: this needs to be more elaborate
155-
ext = os.path.splitext(path)[1]
156-
if ext == ".gz":
157-
return self.types.GZIP
158-
elif ext == ".xz":
159-
return self.types.LZMA
160-
else:
161-
raise RuntimeError("FIXME")
162-
163-
def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
164-
for path, file in self.datapipe:
165-
type = self._detect_compression_type(path)
166-
decompressor = self._DECOMPRESSORS[type]
167-
yield path, decompressor(file)
168-
169-
170109
class PicklerDataPipe(IterDataPipe):
171110
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
172111
self.source_datapipe = source_datapipe

0 commit comments

Comments
 (0)