-import functools
 import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Tuple, Iterator

-import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, OnDiskCacheHolder, Concater, IterableWrapper
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
-from torchvision.prototype.features import Label
+from torchvision.prototype.features import Label, EncodedImage

 # We need lmdb.Environment as annotation, but lmdb is an optional requirement at import
 try:
@@ -79,7 +76,6 @@ class Lsun(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "lsun",
-            type=DatasetType.IMAGE,
             categories=(
                 "bedroom",
                 "bridge",
@@ -140,37 +136,24 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

     _FOLDER_PATTERN = re.compile(r"(?P<category>\w*?)_(?P<split>(train|val))_lmdb")

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, bytes, io.BytesIO],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, bytes, io.BytesIO]) -> Dict[str, Any]:
         path, key, buffer = data

         match = self._FOLDER_PATTERN.match(pathlib.Path(path).parent.name)
-        if match:
-            category = match["category"]
-            label = Label(self.categories.index(category), category=category)
-        else:
-            label = None
+        label = Label.from_category(match["category"], categories=self.categories) if match else None

         return dict(
             path=path,
             key=key,
-            image=decoder(buffer) if decoder else buffer,
+            image=EncodedImage.from_file(buffer),
             label=label,
         )

     def _filepath_fn(self, path: str) -> str:
         return str(pathlib.Path(path) / "keys.cache")

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = Concater(*resource_dps)

@@ -183,4 +166,4 @@ def _make_datapipe(
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
         dp = LmdbReader(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
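For reference, a minimal sketch of what the refactored `_prepare_sample` now produces for a single LMDB record, assuming `Label.from_category` and `EncodedImage.from_file` behave as used in the diff above; the path, key, buffer, and category list below are made-up placeholders, not values from this change:

import io
import pathlib
import re

from torchvision.prototype.features import EncodedImage, Label

# Hypothetical stand-ins for one record yielded by LmdbReader:
# (path inside the lmdb folder, record key, raw encoded image bytes).
categories = ["bedroom", "bridge"]                   # placeholder subset of the LSUN categories
path = "bedroom_train_lmdb/data.mdb"                 # placeholder path
key = b"0001"                                        # placeholder record key
buffer = io.BytesIO(b"<encoded image bytes>")        # placeholder image bytes

folder_pattern = re.compile(r"(?P<category>\w*?)_(?P<split>(train|val))_lmdb")
match = folder_pattern.match(pathlib.Path(path).parent.name)

sample = dict(
    path=path,
    key=key,
    # The image stays as an EncodedImage (raw bytes wrapped in a feature) instead of
    # being eagerly decoded, which is what dropping the `decoder` argument buys us.
    image=EncodedImage.from_file(buffer),
    label=Label.from_category(match["category"], categories=categories) if match else None,
)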