Skip to content

Commit 66ba84a

Browse files
committed
remove decoder
1 parent b55b331 commit 66ba84a

File tree

1 file changed

+7
-24
lines changed
  • torchvision/prototype/datasets/_builtin

1 file changed

+7
-24
lines changed

torchvision/prototype/datasets/_builtin/lsun.py

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,18 @@
1-
import functools
21
import io
32
import pathlib
43
import re
5-
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
4+
from typing import Any, Dict, List, Tuple, Iterator
65

7-
import torch
86
from torchdata.datapipes.iter import IterDataPipe, Mapper, OnDiskCacheHolder, Concater, IterableWrapper
97
from torchvision.prototype.datasets.utils import (
108
Dataset,
119
DatasetConfig,
1210
DatasetInfo,
1311
HttpResource,
1412
OnlineResource,
15-
DatasetType,
1613
)
1714
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
18-
from torchvision.prototype.features import Label
15+
from torchvision.prototype.features import Label, EncodedImage
1916

2017
# We need lmdb.Environment as annotation, but lmdb is an optional requirement at import
2118
try:
@@ -79,7 +76,6 @@ class Lsun(Dataset):
7976
def _make_info(self) -> DatasetInfo:
8077
return DatasetInfo(
8178
"lsun",
82-
type=DatasetType.IMAGE,
8379
categories=(
8480
"bedroom",
8581
"bridge",
@@ -140,37 +136,24 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
140136

141137
_FOLDER_PATTERN = re.compile(r"(?P<category>\w*?)_(?P<split>(train|val))_lmdb")
142138

143-
def _collate_and_decode_sample(
144-
self,
145-
data: Tuple[str, bytes, io.BytesIO],
146-
*,
147-
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
148-
) -> Dict[str, Any]:
139+
def _prepare_sample(self, data: Tuple[str, bytes, io.BytesIO]) -> Dict[str, Any]:
149140
path, key, buffer = data
150141

151142
match = self._FOLDER_PATTERN.match(pathlib.Path(path).parent.name)
152-
if match:
153-
category = match["category"]
154-
label = Label(self.categories.index(category), category=category)
155-
else:
156-
label = None
143+
label = Label.from_category(match["category"], categories=self.categories) if match else None
157144

158145
return dict(
159146
path=path,
160147
key=key,
161-
image=decoder(buffer) if decoder else buffer,
148+
image=EncodedImage.from_file(buffer),
162149
label=label,
163150
)
164151

165152
def _filepath_fn(self, path: str) -> str:
166153
return str(pathlib.Path(path) / "keys.cache")
167154

168155
def _make_datapipe(
169-
self,
170-
resource_dps: List[IterDataPipe],
171-
*,
172-
config: DatasetConfig,
173-
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
156+
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
174157
) -> IterDataPipe[Dict[str, Any]]:
175158
dp = Concater(*resource_dps)
176159

@@ -183,4 +166,4 @@ def _make_datapipe(
183166
dp = hint_sharding(dp)
184167
dp = hint_shuffling(dp)
185168
dp = LmdbReader(dp)
186-
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
169+
return Mapper(dp, self._prepare_sample)

0 commit comments

Comments
 (0)