Commit e560ef7

Fixing tutorial with DataLoader
Updating the tutorial and README with more relevant/correct information. Minor fix to one part of `MapDataPipe` documentation as well. Fixes #352
1 parent 4b5e1da commit e560ef7

File tree: 3 files changed (+87, -17 lines)

README.md

Lines changed: 17 additions & 4 deletions
```diff
@@ -177,17 +177,30 @@ Q: What should I do if the existing set of DataPipes does not do what I need?
 A: You can
 [implement your own custom DataPipe](https://pytorch.org/data/main/tutorial.html#implementing-a-custom-datapipe). If you
 believe your use case is common enough such that the community can benefit from having your custom DataPipe added to
-this library, feel free to open a GitHub issue.
+this library, feel free to open a GitHub issue. We will be happy to discuss!
 
-Q: What happens when the `Shuffler`/`Batcher` DataPipes are used with DataLoader?
+Q: What happens when the `Shuffler` DataPipe is used with DataLoader?
 
-A: If you choose those DataPipes while setting `shuffle=True`/`batch_size>1` for DataLoader, your samples will be
-shuffled/batched more than once. You should choose one or the other.
+A: In order to enable shuffling, you need to add a `Shuffler` to your DataPipe line. Then, by default, shuffling will
+happen at the point where you specified, as long as you do not set `shuffle=False` within DataLoader.
+
+Q: What happens when the `Batcher` DataPipe is used with DataLoader?
+
+A: If you choose to use `Batcher` while setting `batch_size > 1` for DataLoader, your samples will be batched more than
+once. You should choose one or the other.
+
+Q: Why are there fewer built-in `MapDataPipes` than `IterDataPipes`?
+
+A: By design, there are fewer `MapDataPipes` than `IterDataPipes` to avoid duplicate implementations of the same
+functionalities as `MapDataPipe`. We encourage users to use the built-in `IterDataPipe` for various functionalities, and
+convert it to `MapDataPipe` as needed.
 
 Q: How is multiprocessing handled with DataPipes?
 
 A: Multi-process data loading is still handled by DataLoader, see the
 [DataLoader documentation for more details](https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading).
+If you would like to shard data across processes, use `ShardingFilter` and provide a `worker_init_fn` as shown in the
+[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader).
 
 Q: What is the upcoming plan for DataLoader?
 
```

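To make the `Shuffler`/`Batcher` guidance above concrete, here is a minimal sketch; the toy `IterableWrapper` pipeline is illustrative and not part of this commit.

```python
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

# Add a `Shuffler` to the DataPipe line; DataLoader honors it by default,
# unless you explicitly pass `shuffle=False`.
datapipe = IterableWrapper(range(10)).shuffle()

# Batch in exactly one place: here via `batch_size`, or in the DataPipe line
# via `.batch(...)` with `batch_size=None` -- not both, or samples get batched twice.
dl = DataLoader(datapipe, batch_size=2)

for batch in dl:
    print(batch)
```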
docs/source/torchdata.datapipes.map.rst

Lines changed: 3 additions & 3 deletions
```diff
@@ -15,9 +15,9 @@ corresponding label from a folder on the disk.
 
 By design, there are fewer ``MapDataPipe`` than ``IterDataPipe`` to avoid duplicate implementations of the same
 functionalities as ``MapDataPipe``. We encourage users to use the built-in ``IterDataPipe`` for various functionalities,
-and convert it to ``MapDataPipe`` as needed using ``MapToIterConverter`` or ``.to_iter_datapipe()``.
-If you have any question about usage or best practices while using `MapDataPipe`, feel free to ask on the PyTorch forum
-under the `'data' category <https://discuss.pytorch.org/c/data/37>`_.
+and convert it to ``MapDataPipe`` as needed using :class:`.IterToMapConverter` or ``.to_map_datapipe()``.
+If you have any question about usage or best practices while using ``MapDataPipe``, feel free to ask on the PyTorch
+forum under the `'data' category <https://discuss.pytorch.org/c/data/37>`_.
 
 We are open to add additional ``MapDataPipe`` where the operations can be lazily executed and ``__len__`` can be
 known in advance. Feel free to make suggestions with description of your use case in
```

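As a quick illustration of the corrected conversion path, here is a small sketch (not part of the commit):

```python
from torchdata.datapipes.iter import IterableWrapper

# `.to_map_datapipe()` (IterToMapConverter) expects (key, value) pairs by default.
iter_dp = IterableWrapper([(i, i * 10) for i in range(3)])
map_dp = iter_dp.to_map_datapipe()

print(map_dp[2])    # 20 -- random access works after conversion
print(len(map_dp))  # 3  -- __len__ is now known
```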
docs/source/tutorial.rst

Lines changed: 67 additions & 10 deletions
```diff
@@ -80,23 +80,33 @@ For this example, we will first have a helper function that generates some CSV f
             row_data['label'] = random.randint(0, 9)
             writer.writerow(row_data)
 
-Next, we will build our DataPipes to read and parse through the generated CSV files:
+Next, we will build our DataPipes to read and parse through the generated CSV files. Note that we prefer to
+pass defined functions to DataPipes rather than lambda functions, because the former are serializable with `pickle`.
 
 .. code:: python
 
     import numpy as np
     import torchdata.datapipes as dp
 
+    def filter_for_data(filename):
+        return "sample_data" in filename and filename.endswith(".csv")
+
+    def row_processer(row):
+        return {"label": np.array(row[0], np.int32), "data": np.array(row[1:], dtype=np.float64)}
+
     def build_datapipes(root_dir="."):
         datapipe = dp.iter.FileLister(root_dir)
-        datapipe = datapipe.filter(filter_fn=lambda filename: "sample_data" in filename and filename.endswith(".csv"))
-        datapipe = dp.iter.FileOpener(datapipe, mode='rt')
+        datapipe = datapipe.filter(filter_fn=filter_for_data)
+        datapipe = datapipe.open_files(mode='rt')
         datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
-        datapipe = datapipe.map(lambda row: {"label": np.array(row[0], np.int32),
-                                             "data": np.array(row[1:], dtype=np.float64)})
+        # Shuffle will happen as long as you do NOT set `shuffle=False` later in the DataLoader
+        datapipe = datapipe.shuffle()
+        datapipe = datapipe.map(row_processer)
         return datapipe
 
-Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader.
+Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that
+if you choose to use `Batcher` while setting `batch_size > 1` for DataLoader, your samples will be
+batched more than once. You should choose one or the other.
 
 .. code:: python
 
```
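The preference for defined functions over lambdas matters because DataLoader workers receive the DataPipe graph via `pickle`, which cannot serialize lambdas. A quick sketch (illustrative, not part of the tutorial):

```python
import pickle

def filter_for_data(filename):
    return "sample_data" in filename and filename.endswith(".csv")

pickle.dumps(filter_for_data)  # works: named, module-level functions are picklable

try:
    pickle.dumps(lambda filename: filename.endswith(".csv"))
except (pickle.PicklingError, AttributeError) as e:
    print(f"lambda is not picklable: {e}")
```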
```diff
@@ -105,20 +115,67 @@ Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe
     if __name__ == '__main__':
         num_files_to_generate = 3
         for i in range(num_files_to_generate):
-            generate_csv(file_label=i)
+            generate_csv(file_label=i, num_rows=10, num_features=3)
         datapipe = build_datapipes()
-        dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
+        dl = DataLoader(dataset=datapipe, batch_size=5, num_workers=2)
         first = next(iter(dl))
         labels, features = first['label'], first['data']
         print(f"Labels batch shape: {labels.size()}")
         print(f"Feature batch shape: {features.size()}")
+        print(f"{labels = }\n{features = }")
+        n_sample = 0
+        for row in iter(dl):
+            n_sample += 1
+        print(f"{n_sample = }")
 
 The following statements will be printed to show the shapes of a single batch of labels and features.
 
 .. code::
 
-    Labels batch shape: 50
-    Feature batch shape: torch.Size([50, 20])
+    Labels batch shape: torch.Size([5])
+    Feature batch shape: torch.Size([5, 3])
+    labels = tensor([8, 9, 5, 9, 7], dtype=torch.int32)
+    features = tensor([[0.2867, 0.5973, 0.0730],
+            [0.7890, 0.9279, 0.7392],
+            [0.8930, 0.7434, 0.0780],
+            [0.8225, 0.4047, 0.0800],
+            [0.1655, 0.0323, 0.5561]], dtype=torch.float64)
+    n_sample = 12
+
+The reason why ``n_sample = 12`` is that ``ShardingFilter`` (``datapipe.sharding_filter()``) was not used, such that
+each worker independently returns all samples. In this case, there are 10 rows per file and 3 files; with a
+batch size of 5, that gives us 6 batches per worker. With 2 workers, we get 12 total batches from the ``DataLoader``.
+
+In order for DataPipe sharding to work with ``DataLoader``, we need to add the following. It is crucial to add
+`ShardingFilter` after `Shuffler` to ensure that all worker processes have the same order of data for sharding.
+
+.. code:: python
+
+    def build_datapipes(root_dir="."):
+        datapipe = ...
+        # Add the following line to `build_datapipes`
+        # Note that it is somewhere after `Shuffler` in the DataPipe line
+        datapipe = datapipe.sharding_filter()
+        return datapipe
+
+    def worker_init_fn(worker_id):
+        info = torch.utils.data.get_worker_info()
+        num_workers = info.num_workers
+        datapipe = info.dataset
+        torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
+
+    # Pass `worker_init_fn` into `DataLoader` within '__main__'
+    ...
+    dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2, worker_init_fn=worker_init_fn)
+    ...
+
+When we re-run, we will get:
+
+.. code::
+
+    ...
+    n_sample = 6
+
 
 You can find more DataPipe implementation examples for various research domains `on this page <torchexamples.html>`_.
```

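For reference, the pieces above assemble into the following self-contained sketch. It consolidates the code from the diff and assumes the `generate_csv` helper from earlier in the tutorial has already produced the three CSV files.

```python
import numpy as np
import torch
import torchdata.datapipes as dp
from torch.utils.data import DataLoader

def filter_for_data(filename):
    return "sample_data" in filename and filename.endswith(".csv")

def row_processer(row):
    return {"label": np.array(row[0], np.int32), "data": np.array(row[1:], dtype=np.float64)}

def build_datapipes(root_dir="."):
    datapipe = dp.iter.FileLister(root_dir)
    datapipe = datapipe.filter(filter_fn=filter_for_data)
    datapipe = datapipe.open_files(mode="rt")
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
    datapipe = datapipe.shuffle()          # Shuffler first...
    datapipe = datapipe.sharding_filter()  # ...then ShardingFilter, so all workers shard the same order
    return datapipe.map(row_processer)

def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    torch.utils.data.graph_settings.apply_sharding(info.dataset, info.num_workers, worker_id)

if __name__ == "__main__":
    dl = DataLoader(dataset=build_datapipes(), shuffle=True, batch_size=5,
                    num_workers=2, worker_init_fn=worker_init_fn)
    print(f"n_sample = {sum(1 for _ in dl)}")  # 30 samples / batch_size 5, sharded across 2 workers -> 6
```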