
Commit 42922cf

VitalyFedyunin authored and facebook-github-bot committed
Adding lock mechanism to prevent on_disk_cache downloading twice (#409)
Summary: Pull Request resolved: #409

Fixes #144

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D36489060

Pulled By: VitalyFedyunin

fbshipit-source-id: d33624d2d38d4756789b1ebb09eb05d2b4efe6a6
1 parent a92f573 commit 42922cf
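
At a high level, the change makes concurrent workers coordinate through per-file ".promise" lock files: the first worker to find a missing cache entry claims it under an exclusive portalocker lock and performs the download; every other worker waits until the promise file is deleted. A simplified sketch of that pattern is below; the helper name `claim_or_wait` is invented for illustration and is not part of the torchdata API (the committed logic lives in `_cache_check_fn` / `_wait_promise_fn` in `cacheholder.py` further down).

```python
# Simplified sketch of the promise-file coordination added by this commit.
# `claim_or_wait` is an illustrative name, not part of the torchdata API.
import os
import time

import portalocker


def claim_or_wait(filepath, timeout=300):
    """Return True if this process should produce `filepath`; otherwise block
    until the process that claimed it removes the .promise file."""
    promise = filepath + ".promise"
    with portalocker.Lock(promise, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
        fh.seek(0)
        if fh.read() == "":  # first process to lock an empty promise claims the work
            fh.write("claimed")
            return True
    start = time.time()
    while os.path.exists(promise):  # another process claimed it; wait for fulfillment
        if time.time() - start > timeout:
            raise RuntimeError(
                f"{filepath} expected to be written by a different process "
                f"but was not ready in {timeout} seconds"
            )
        time.sleep(0.01)
    return False
```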

File tree

requirements.txt
test/test_local_io.py
test/test_remote_io.py
torchdata/datapipes/iter/util/cacheholder.py
torchdata/datapipes/iter/util/saver.py

5 files changed: +183 -17 lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 urllib3 >= 1.25
 requests
+portalocker >= 2.0.0
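
The new dependency supplies the cross-platform advisory file locks the cache holder relies on. A minimal illustration of the two flag values that appear in this diff (the file name is a placeholder):

```python
import portalocker

# Exclusive lock: only one process at a time (used for the .promise files below).
with portalocker.Lock("cache.bin.promise", "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
    fh.write("claimed")

# Shared lock: many readers at once (only referenced in TODO comments for now).
with portalocker.Lock("cache.bin.promise", "r", flags=portalocker.LockFlags.SHARED) as fh:
    fh.read()
```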

test/test_local_io.py

Lines changed: 40 additions & 0 deletions
@@ -5,13 +5,16 @@
 # LICENSE file in the root directory of this source tree.

 import bz2
+import functools
 import hashlib
 import io
 import itertools
 import lzma
 import os
 import subprocess
 import tarfile
+import tempfile
+import time
 import unittest
 import warnings
 import zipfile
@@ -22,6 +25,8 @@
 import expecttest

 from _utils._common_utils_for_test import create_temp_dir, create_temp_files, get_name, reset_after_n_next_calls
+
+from torch.utils.data import DataLoader
 from torchdata.datapipes.iter import (
     Bz2FileLoader,
     CSVDictParser,
@@ -34,9 +39,11 @@
     IoPathFileOpener,
     IoPathSaver,
     IterableWrapper,
+    IterDataPipe,
     JsonParser,
     RarArchiveLoader,
     Saver,
+    StreamReader,
     TarArchiveLoader,
     WebDataset,
     XzFileLoader,
@@ -77,6 +84,14 @@ def init_fn(worker_id):
     torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)


+def _unbatch(x):
+    return x[0]
+
+
+def _noop(x):
+    return x
+
+
 class TestDataPipeLocalIO(expecttest.TestCase):
     def setUp(self):
         self.temp_dir = create_temp_dir()
@@ -606,6 +621,31 @@ def _write_text_files(self):
         saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         list(saver_dp)

+    @staticmethod
+    def _slow_fn(tmpdirname, x):
+        with open(os.path.join(tmpdirname, str(os.getpid())), "w") as pid_fh:
+            pid_fh.write("anything")
+        time.sleep(2)
+        return (x, "str")
+
+    def test_disk_cache_locks(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            file_name = os.path.join(tmpdirname, "test.bin")
+            dp = IterableWrapper([file_name])
+            dp = dp.on_disk_cache(filepath_fn=_noop)
+            dp = dp.map(functools.partial(self._slow_fn, tmpdirname))
+            dp = dp.end_caching(mode="t", filepath_fn=_noop, timeout=120)
+            dp = FileOpener(dp)
+            dp = StreamReader(dp)
+            dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch)
+            result = list(dl)
+            all_files = []
+            for (_, _, filenames) in os.walk(tmpdirname):
+                all_files += filenames
+            # We expect only two files: one with the pid and the 'downloaded' one
+            self.assertEqual(2, len(all_files))
+            self.assertEqual("str", result[0][1])
+
     # TODO(120): this test currently only covers reading from local
     # filesystem. It needs to be modified once test data can be stored on
     # gdrive/s3/onedrive

test/test_remote_io.py

Lines changed: 13 additions & 4 deletions
@@ -13,7 +13,8 @@

 import torchdata

-from _utils._common_utils_for_test import check_hash_fn, create_temp_dir
+from _utils._common_utils_for_test import check_hash_fn, create_temp_dir, IS_WINDOWS
+from torch.utils.data import DataLoader

 from torchdata.datapipes.iter import (
     EndOnDiskCacheHolder,
@@ -143,8 +144,9 @@ def _read_and_decode(x):

         cached_it = iter(file_cache_dp)
         for expected_csv_path in _gen_filepath_fn(expected_file_name):
-            # File doesn't exist on disk
-            self.assertFalse(os.path.exists(expected_csv_path))
+
+            # Check disabled due to some elements of prefetching inside of on_disk_cache
+            # self.assertFalse(os.path.exists(expected_csv_path))

             csv_path = next(cached_it)

@@ -167,15 +169,22 @@ def _read_and_decode(x):
         cached_it = iter(file_cache_dp)
         for i in range(3):
             expected_csv_path = os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")
+
             # File doesn't exist on disk
-            self.assertFalse(os.path.exists(expected_csv_path))
+            # Check disabled due to some elements of prefetching inside of on_disk_cache
+            # self.assertFalse(os.path.exists(expected_csv_path))

             csv_path = next(cached_it)

             # File is cached to disk
             self.assertTrue(os.path.exists(expected_csv_path))
             self.assertEqual(expected_csv_path, csv_path)

+        if not IS_WINDOWS:
+            dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
+            expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
+            self.assertEqual(sorted(expected), sorted(list(dl)))
+
     def test_s3_io_iterdatapipe(self):
         # sanity test
         file_urls = ["s3://ai2-public-datasets"]

torchdata/datapipes/iter/util/cacheholder.py

Lines changed: 127 additions & 13 deletions
@@ -4,14 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 import hashlib
 import inspect
 import os.path
 import sys
+import time
+import warnings

 from collections import deque
 from functools import partial
-from typing import Callable, Deque, Dict, Iterator, Optional, TypeVar
+from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, TypeVar
+
+import portalocker

 from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE

@@ -26,6 +31,9 @@

 T_co = TypeVar("T_co", covariant=True)

+PROMISE_FILE_DELETE_TIMEOUT = 30
+PROMISE_FILE_DELETE_RETRY_INTERVAL = 0.005
+

 @functional_datapipe("in_memory_cache")
 class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
@@ -106,6 +114,9 @@ def _hash_check(filepath, hash_dict, hash_type):
     else:
         hash_func = hashlib.md5()

+    # with portalocker.Lock(filepath, "rb", flags=portalocker.LockFlags.SHARED) as f:
+    # TODO(VitalyFedyunin): Line above will require all readers (Win) to obtain proper locks,
+    # I'm putting it on hold as we need to modify the PyTorch core codebase heavily.
     with open(filepath, "rb") as f:
         chunk = f.read(1024 ** 2)
         while chunk:
@@ -115,6 +126,10 @@ def _hash_check(filepath, hash_dict, hash_type):
     return hash_func.hexdigest() == hash_dict[filepath]


+def _promise_filename(filename):
+    return filename + ".promise"
+
+
 @functional_datapipe("on_disk_cache")
 class OnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
@@ -145,7 +160,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> hash_dict = {"expected_filepath": expected_MD5_hash}
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """
@@ -184,22 +199,42 @@ def __add__(self, other_datapipe):
     @staticmethod
     def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
         filepaths = data if filepath_fn is None else filepath_fn(data)
+        result = True
         if not isinstance(filepaths, (list, tuple)):
             filepaths = [
                 filepaths,
             ]

         for filepath in filepaths:
+            cached_file_exists = True
             if not os.path.exists(filepath):
-                return False
-
-            if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
-                return False
-
-            if extra_check_fn is not None and not extra_check_fn(filepath):
-                return False
-
-        return True
+                cached_file_exists = False
+            elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
+                cached_file_exists = False
+            elif extra_check_fn is not None and not extra_check_fn(filepath):
+                cached_file_exists = False
+
+            if not cached_file_exists:
+                promise_filepath = _promise_filename(filepath)
+                dirname = os.path.dirname(promise_filepath)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh:
+                    promise_fh.seek(0)
+                    data = promise_fh.read()
+                    # TODO(VitalyFedyunin): Potentially there is an old .promise file from a previous failed run;
+                    # we need to somehow propagate a unique session id for the dataloader, save and compare it
+                    # here, raising an error on mismatch
+                    file_exists = len(data) > 0
+                    if not file_exists:
+                        result = False
+                        promise_fh.seek(0)
+                        promise_fh.write("[dataloader session uid]")
+                        promise_fh.truncate()
+                        promise_fh.flush()
+
+        return result

     def _end_caching(self):
         filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
@@ -232,6 +267,82 @@ def _read_str(fd):
     return "".join(fd)


+def _find_promise_file(filename):
+    promise_filename = _promise_filename(filename)
+    while not os.path.exists(promise_filename):
+        dirname = os.path.dirname(promise_filename)
+        if dirname == os.path.dirname(dirname):
+            promise_filename = _promise_filename(filename)
+            break
+        promise_filename = _promise_filename(dirname)
+    return promise_filename
+
+
+def _is_promise_pending(promise_filename):
+    return os.path.exists(promise_filename)
+
+
+def _wait_promise_fn(timeout, filename):
+    promise_filename = _find_promise_file(filename)
+    start = time.time()
+    while _is_promise_pending(promise_filename):
+        time.sleep(0.01)
+        if time.time() - start > timeout:
+            raise Exception(
+                f"OnDiskCache Exception: {filename} expected to be written by different process, "
+                + f"but file is not ready in {timeout} seconds."
+            )
+    return filename
+
+
+class _FulfilledPromisesIterDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    @staticmethod
+    def _del_promise_file(promise_filename, filename):
+        if os.path.exists(promise_filename):
+            retry = True
+            start = time.time()
+            while retry:
+                retry = False
+                try:
+                    os.unlink(promise_filename)
+                except Exception as e:
+                    # Workaround for Windows not letting us delete a file while it is open by another process
+                    retry = True
+                    if time.time() - start > PROMISE_FILE_DELETE_TIMEOUT:
+                        raise Exception("Timeout while trying to recover from the ", type(e), e)
+                    time.sleep(PROMISE_FILE_DELETE_RETRY_INTERVAL)
+        else:
+            warnings.warn(
+                f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially mismatching filename functions of on_disk_cache and end_caching."
+            )
+
+    def __iter__(self):
+        old_promise_filename = None
+        old_filename = None
+        first_entry = True
+        # TODO(VitalyFedyunin): Limit buffer size here. It only contains file names from the archive,
+        # but better safe than sorry.
+        buffer: List[Any] = []
+        for filename in self.source_datapipe:
+            promise_filename = _find_promise_file(filename)
+            if not first_entry:
+                buffer.append(old_filename)
+                if old_promise_filename != promise_filename:
+                    self._del_promise_file(old_promise_filename, old_filename)
+                    yield from buffer
+                    buffer = []
+            old_promise_filename = promise_filename
+            old_filename = filename
+            first_entry = False
+        if not first_entry:
+            buffer.append(old_filename)
+            self._del_promise_file(old_promise_filename, old_filename)
+            yield from buffer
+
+
 @functional_datapipe("end_caching")
 class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
@@ -248,6 +359,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``.
         skip_read: Boolean value to skip reading the file handle from ``datapipe``.
             By default, reading is enabled and reading function is created based on the ``mode``.
+        timeout: Integer value of seconds to wait for an uncached item to be written to disk

     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
@@ -259,10 +371,10 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """

-    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
+    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False, timeout=300):
         if filepath_fn is not None and same_filepath_fn:
             raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")

@@ -276,6 +388,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

         _filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
         cached_dp = cache_holder._end_caching()
+        cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, timeout))
         cached_dp = FileLister(cached_dp, recursive=True)

         if same_filepath_fn:
@@ -297,6 +410,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
             todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

         todo_dp = todo_dp.save_to_disk(mode=mode)
+        todo_dp = _FulfilledPromisesIterDataPipe(todo_dp)

         return cached_dp.concat(todo_dp)

torchdata/datapipes/iter/util/saver.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ def __iter__(self) -> Iterator[str]:
             dirname = os.path.dirname(filepath)
             if not os.path.exists(dirname):
                 os.makedirs(dirname)
+            # with portalocker.Lock(filepath, self.mode, flags=portalocker.LockFlags.EXCLUSIVE) as f:
+            # TODO(VitalyFedyunin): Enabling line above will require all read sites to be updated (Win).
            with open(filepath, self.mode) as f:
                 f.write(data)
             yield filepath
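
Putting the pieces together, end-user code looks like the docstring example in cacheholder.py, with the new `timeout` argument bounding how long `_wait_promise_fn` waits for another worker's pending promise. The URL and cache path below are placeholders:

```python
from torchdata.datapipes.iter import IterableWrapper, HttpReader


def _filepath_fn(url):
    # placeholder mapping from URL to a local cache path
    return "/tmp/cache/" + url.split("/")[-1]


url_dp = IterableWrapper(["https://path/to/filename"])  # placeholder URL
cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
# `timeout` is new in this change: seconds to wait for an uncached item that
# another worker has already promised to write.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn, timeout=300)
```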
