 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 import hashlib
 import inspect
 import os.path
 import sys
+import time
+import warnings

 from collections import deque
 from functools import partial
-from typing import Callable, Deque, Dict, Iterator, Optional, TypeVar
+from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, TypeVar
+
+import portalocker

 from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE

 T_co = TypeVar("T_co", covariant=True)

+PROMISE_FILE_DELETE_TIMEOUT = 30
+PROMISE_FILE_DELETE_RETRY_INTERVAL = 0.005
+

 @functional_datapipe("in_memory_cache")
 class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
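The new ``portalocker`` dependency supplies the cross-process file lock that the promise mechanism below builds on, and the two constants bound the retry loop used when deleting promise files on Windows. A minimal sketch of the claim pattern (editor-added, not part of the patch; ``claim.lock`` is a hypothetical path):

    import portalocker

    # Take an exclusive advisory lock on a shared marker file.
    with portalocker.Lock("claim.lock", "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
        fh.seek(0)
        already_claimed = len(fh.read()) > 0  # non-empty marker: another process claimed the work first
        if not already_claimed:
            fh.write("claimed")  # leave a marker so later processes skip the work
            fh.flush()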
@@ -106,6 +114,9 @@ def _hash_check(filepath, hash_dict, hash_type):
     else:
         hash_func = hashlib.md5()

+    # with portalocker.Lock(filepath, "rb", flags=portalocker.LockFlags.SHARED) as f:
+    # TODO(VitalyFedyunin): The line above would require all readers (on Windows) to obtain proper locks;
+    # it is on hold because it needs heavy changes to the PyTorch core codebase.
     with open(filepath, "rb") as f:
         chunk = f.read(1024 ** 2)
         while chunk:
@@ -115,6 +126,10 @@ def _hash_check(filepath, hash_dict, hash_type):
     return hash_func.hexdigest() == hash_dict[filepath]


+def _promise_filename(filename):
+    return filename + ".promise"
+
+
 @functional_datapipe("on_disk_cache")
 class OnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
@@ -145,7 +160,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> hash_dict = {"expected_filepath": expected_MD5_hash}
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """

     _temp_dict: Dict = {}
@@ -184,22 +199,42 @@ def __add__(self, other_datapipe):
     @staticmethod
     def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
         filepaths = data if filepath_fn is None else filepath_fn(data)
+        result = True
         if not isinstance(filepaths, (list, tuple)):
             filepaths = [
                 filepaths,
             ]

         for filepath in filepaths:
+            cached_file_exists = True
             if not os.path.exists(filepath):
-                return False
-
-            if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
-                return False
-
-            if extra_check_fn is not None and not extra_check_fn(filepath):
-                return False
-
-        return True
+                cached_file_exists = False
+            elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
+                cached_file_exists = False
+            elif extra_check_fn is not None and not extra_check_fn(filepath):
+                cached_file_exists = False
+
+            if not cached_file_exists:
+                promise_filepath = _promise_filename(filepath)
+                dirname = os.path.dirname(promise_filepath)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh:
+                    promise_fh.seek(0)
+                    data = promise_fh.read()
+                    # TODO(VitalyFedyunin): There may be an old .promise file left over from a previous failed run;
+                    # we need to propagate a unique session id for the dataloader, save it, and compare it here,
+                    # raising an error on mismatch.
+                    file_exists = len(data) > 0
+                    if not file_exists:
+                        result = False
+                        promise_fh.seek(0)
+                        promise_fh.write("[dataloader session uid]")
+                        promise_fh.truncate()
+                        promise_fh.flush()
+
+        return result

     def _end_caching(self):
         filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
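With this change, a missing or stale file no longer returns ``False`` immediately. Instead the check takes an exclusive ``portalocker`` lock on the matching ``.promise`` file: if the promise is still empty, this process writes a marker and ``result`` drops to ``False``, meaning the file will be produced here; if the promise already has content, another process has claimed the download, so the entry stays on the cached branch and is waited on later via ``_wait_promise_fn``.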
@@ -232,6 +267,82 @@ def _read_str(fd):
     return "".join(fd)


+def _find_promise_file(filename):
+    promise_filename = _promise_filename(filename)
+    while not os.path.exists(promise_filename):
+        dirname = os.path.dirname(promise_filename)
+        if dirname == os.path.dirname(dirname):
+            promise_filename = _promise_filename(filename)
+            break
+        promise_filename = _promise_filename(dirname)
+    return promise_filename
+
+
+def _is_promise_pending(promise_filename):
+    return os.path.exists(promise_filename)
+
+
+def _wait_promise_fn(timeout, filename):
+    promise_filename = _find_promise_file(filename)
+    start = time.time()
+    while _is_promise_pending(promise_filename):
+        time.sleep(0.01)
+        if time.time() - start > timeout:
+            raise Exception(
+                f"OnDiskCache Exception: {filename} expected to be written by a different process, "
+                + f"but the file is not ready after {timeout} seconds."
+            )
+    return filename
+
+
+class _FulfilledPromisesIterDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    @staticmethod
+    def _del_promise_file(promise_filename, filename):
+        if os.path.exists(promise_filename):
+            retry = True
+            start = time.time()
+            while retry:
+                retry = False
+                try:
+                    os.unlink(promise_filename)
+                except Exception as e:
+                    # Workaround for Windows not allowing a file to be deleted while it is open in another process
+                    retry = True
+                    if time.time() - start > PROMISE_FILE_DELETE_TIMEOUT:
+                        raise Exception("Timeout while trying to recover from the ", type(e), e)
+                    time.sleep(PROMISE_FILE_DELETE_RETRY_INTERVAL)
+        else:
+            warnings.warn(
+                f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially mismatched filename functions of on_disk_cache and end_caching."
+            )
+
+    def __iter__(self):
+        old_promise_filename = None
+        old_filename = None
+        first_entry = True
+        # TODO(VitalyFedyunin): Limit buffer size here. It only contains file names from the archive,
+        # but better safe than sorry.
+        buffer: List[Any] = []
+        for filename in self.source_datapipe:
+            promise_filename = _find_promise_file(filename)
+            if not first_entry:
+                buffer.append(old_filename)
+                if old_promise_filename != promise_filename:
+                    self._del_promise_file(old_promise_filename, old_filename)
+                    yield from buffer
+                    buffer = []
+            old_promise_filename = promise_filename
+            old_filename = filename
+            first_entry = False
+        if not first_entry:
+            buffer.append(old_filename)
+            self._del_promise_file(old_promise_filename, old_filename)
+            yield from buffer
+
+
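``_find_promise_file`` walks up the directory tree when no promise exists for the exact path, so members extracted from a cached archive resolve to the promise written for the archive itself, and ``_wait_promise_fn`` polls every 10 ms until that promise disappears or ``timeout`` expires. A short sketch of how the cached branch is expected to use it, mirroring the wiring added further below (300 here is just the default timeout):

    import functools

    # Every filename coming out of the cache check is blocked until the process that
    # holds the corresponding .promise file finishes writing the real file.
    cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, 300))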
 @functional_datapipe("end_caching")
 class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
@@ -248,6 +359,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``.
         skip_read: Boolean value to skip reading the file handle from ``datapipe``.
             By default, reading is enabled and reading function is created based on the ``mode``.
+        timeout: Integer value of seconds to wait for an uncached item to be written to disk.

     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
@@ -259,10 +371,10 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """

-    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
+    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False, timeout=300):
         if filepath_fn is not None and same_filepath_fn:
             raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")
@@ -276,6 +388,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

         _filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
         cached_dp = cache_holder._end_caching()
+        cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, timeout))
         cached_dp = FileLister(cached_dp, recursive=True)

         if same_filepath_fn:
@@ -297,6 +410,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
             todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

         todo_dp = todo_dp.save_to_disk(mode=mode)
+        todo_dp = _FulfilledPromisesIterDataPipe(todo_dp)

         return cached_dp.concat(todo_dp)
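Taken together, ``end_caching`` now waits on promises for files that another process has claimed and deletes the promises once its own files are written. A rough end-to-end sketch based on the docstring examples above (the URL, file path, and hash value are placeholders, not real data):

    import os
    import tempfile

    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    def _filepath_fn(url):
        # Cache each download under the system temp dir, keyed by the URL's basename.
        return os.path.join(tempfile.gettempdir(), os.path.basename(url))

    url_dp = IterableWrapper(["https://path/to/filename"])
    hash_dict = {"expected_filepath": "expected_MD5_hash"}

    cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=hash_dict, hash_type="md5")
    # A second worker asking for the same file now waits up to `timeout` seconds for the
    # process holding the .promise file to finish writing it, instead of downloading twice.
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn, timeout=300)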