diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ceacd6ca0b7..99a7d33226b 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -57,6 +57,11 @@ Enhancements
 - Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports
   "dayofyear" and "dayofweek" accessors (:issue:`2597`). By `Spencer Clark
   `_.
+- The option ``'warn_for_unclosed_files'`` (False by default) has been added to
+  allow users to enable a warning when files opened by xarray are deallocated
+  but were not explicitly closed. This is mostly useful for debugging; we
+  recommend enabling it in your test suites if you use xarray for IO.
+  By `Stephan Hoyer `_.
 - Support Dask ``HighLevelGraphs`` by `Matthew Rocklin `_.
 - :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` now supports the
   ``loffset`` kwarg just like Pandas.
@@ -68,6 +73,12 @@ Enhancements
 
 Bug fixes
 ~~~~~~~~~
+- Ensure files are automatically closed, if possible, when no longer referenced
+  by a Python variable (:issue:`2560`).
+  By `Stephan Hoyer `_.
+- Fixed possible race conditions when reading/writing to disk in parallel
+  (:issue:`2595`).
+  By `Stephan Hoyer `_.
 - Fix h5netcdf saving scalars with filters or chunks (:issue:`2563`).
   By `Martin Raspaud `_.
 - Fix parsing of ``_Unsigned`` attribute set by OPENDAP servers. (:issue:`2583`).
diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py
index a93285370b2..6362842dd42 100644
--- a/xarray/backends/file_manager.py
+++ b/xarray/backends/file_manager.py
@@ -1,7 +1,10 @@
+import contextlib
 import threading
+import warnings
 
 from ..core import utils
 from ..core.options import OPTIONS
+from .locks import acquire
 from .lru_cache import LRUCache
 
 
@@ -11,6 +14,8 @@
 assert FILE_CACHE.maxsize, 'file cache must be at least size one'
 
+REF_COUNTS = {}
+
 
 _DEFAULT_MODE = utils.ReprObject('<unused>')
 
 
@@ -22,7 +27,7 @@ class FileManager(object):
     many open files and transferring them between multiple processes.
     """
 
-    def acquire(self):
+    def acquire(self, needs_lock=True):
        """Acquire the file object from this manager."""
         raise NotImplementedError
 
@@ -62,6 +67,9 @@ class CachingFileManager(FileManager):
     def __init__(self, opener, *args, **keywords):
         """Initialize a FileManager.
 
+        The cache and ref_counts arguments exist solely to facilitate
+        dependency injection, and should only be set for tests.
+
         Parameters
         ----------
         opener : callable
@@ -90,6 +98,9 @@ def __init__(self, opener, *args, **keywords):
             global variable and contains non-picklable file objects, an
             unpickled FileManager objects will be restored with the default
             cache.
+        ref_counts : dict, optional
+            Optional dict to use for keeping track of the number of references
+            to the same file.
         """
         # TODO: replace with real keyword arguments when we drop Python 2
         # support
@@ -97,6 +108,7 @@ def __init__(self, opener, *args, **keywords):
         kwargs = keywords.pop('kwargs', None)
         lock = keywords.pop('lock', None)
         cache = keywords.pop('cache', FILE_CACHE)
+        ref_counts = keywords.pop('ref_counts', REF_COUNTS)
         if keywords:
             raise TypeError('FileManager() got unexpected keyword arguments: '
                             '%s' % list(keywords))
@@ -105,34 +117,52 @@ def __init__(self, opener, *args, **keywords):
         self._args = args
         self._mode = mode
         self._kwargs = {} if kwargs is None else dict(kwargs)
+
         self._default_lock = lock is None or lock is False
         self._lock = threading.Lock() if self._default_lock else lock
+
+        # cache[self._key] stores the file associated with this object.
         self._cache = cache
         self._key = self._make_key()
 
+        # ref_counts[self._key] stores the number of CachingFileManager objects
+        # in memory referencing this same file. We use this to know if we can
+        # close a file when the manager is deallocated.
+        self._ref_counter = _RefCounter(ref_counts)
+        self._ref_counter.increment(self._key)
+
     def _make_key(self):
         """Make a key for caching files in the LRU cache."""
         value = (self._opener,
                  self._args,
-                 self._mode,
+                 'a' if self._mode == 'w' else self._mode,
                  tuple(sorted(self._kwargs.items())))
         return _HashedSequence(value)
 
-    def acquire(self):
+    @contextlib.contextmanager
+    def _optional_lock(self, needs_lock):
+        """Context manager for optionally acquiring a lock."""
+        if needs_lock:
+            with self._lock:
+                yield
+        else:
+            yield
+
+    def acquire(self, needs_lock=True):
         """Acquiring a file object from the manager.
 
         A new file is only opened if it has expired from the
         least-recently-used cache.
 
-        This method uses a reentrant lock, which ensures that it is
-        thread-safe. You can safely acquire a file in multiple threads at the
-        same time, as long as the underlying file object is thread-safe.
+        This method uses a lock, which ensures that it is thread-safe. You can
+        safely acquire a file in multiple threads at the same time, as long as
+        the underlying file object is thread-safe.
 
         Returns
         -------
         An open file object, as returned by ``opener(*args, **kwargs)``.
         """
-        with self._lock:
+        with self._optional_lock(needs_lock):
             try:
                 file = self._cache[self._key]
             except KeyError:
@@ -144,28 +174,53 @@ def acquire(self):
                 if self._mode == 'w':
                     # ensure file doesn't get overriden when opened again
                     self._mode = 'a'
-                    self._key = self._make_key()
                 self._cache[self._key] = file
             return file
 
-    def _close(self):
-        default = None
-        file = self._cache.pop(self._key, default)
-        if file is not None:
-            file.close()
-
     def close(self, needs_lock=True):
         """Explicitly close any associated file object (if necessary)."""
         # TODO: remove needs_lock if/when we have a reentrant lock in
         # dask.distributed: https://github.com/dask/dask/issues/3832
-        if needs_lock:
-            with self._lock:
-                self._close()
-        else:
-            self._close()
+        with self._optional_lock(needs_lock):
+            default = None
+            file = self._cache.pop(self._key, default)
+            if file is not None:
+                file.close()
+
+    def __del__(self):
+        # If we're the only CachingFileManager referencing an unclosed file,
+        # we should remove it from the cache upon garbage collection.
+        #
+        # Keeping our own count of file references might seem like overkill,
+        # but it's actually pretty common to reopen files with the same
+        # variable name in a notebook or command line environment, e.g., to
+        # fix the parameters used when opening a file:
+        #   >>> ds = xarray.open_dataset('myfile.nc')
+        #   >>> ds = xarray.open_dataset('myfile.nc', decode_times=False)
+        # This second assignment to "ds" drops CPython's ref-count on the first
+        # "ds" argument to zero, which can trigger garbage collections. So if
+        # we didn't check whether another object is referencing 'myfile.nc',
+        # the newly opened file would actually be immediately closed!
+        ref_count = self._ref_counter.decrement(self._key)
+
+        if not ref_count and self._key in self._cache:
+            if acquire(self._lock, blocking=False):
+                # Only close files if we can do so immediately.
+                try:
+                    self.close(needs_lock=False)
+                finally:
+                    self._lock.release()
+
+            if OPTIONS['warn_for_unclosed_files']:
+                warnings.warn(
+                    'deallocating {}, but file is not already closed. '
+                    'This may indicate a bug.'
+                    .format(self), RuntimeWarning, stacklevel=2)
 
     def __getstate__(self):
         """State for pickling."""
+        # cache and ref_counts are intentionally omitted: we don't want to try
+        # to serialize these global objects.
         lock = None if self._default_lock else self._lock
         return (self._opener, self._args, self._mode, self._kwargs, lock)
 
@@ -174,6 +229,34 @@ def __setstate__(self, state):
         opener, args, mode, kwargs, lock = state
         self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock)
 
+    def __repr__(self):
+        args_string = ', '.join(map(repr, self._args))
+        if self._mode is not _DEFAULT_MODE:
+            args_string += ', mode={!r}'.format(self._mode)
+        return '{}({!r}, {}, kwargs={})'.format(
+            type(self).__name__, self._opener, args_string, self._kwargs)
+
+
+class _RefCounter(object):
+    """Class for keeping track of reference counts."""
+    def __init__(self, counts):
+        self._counts = counts
+        self._lock = threading.Lock()
+
+    def increment(self, name):
+        with self._lock:
+            count = self._counts[name] = self._counts.get(name, 0) + 1
+        return count
+
+    def decrement(self, name):
+        with self._lock:
+            count = self._counts[name] - 1
+            if count:
+                self._counts[name] = count
+            else:
+                del self._counts[name]
+        return count
+
 
 class _HashedSequence(list):
     """Speedup repeated look-ups by caching hash values.
@@ -198,7 +281,8 @@ class DummyFileManager(FileManager):
     def __init__(self, value):
         self._value = value
 
-    def acquire(self):
+    def acquire(self, needs_lock=True):
+        del needs_lock  # ignored
         return self._value
 
     def close(self, needs_lock=True):
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index 90f63e88cde..0564df5b167 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -26,8 +26,8 @@ def _getitem(self, key):
         # h5py requires using lists for fancy indexing:
         # https://github.com/h5py/h5py/issues/992
         key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key)
-        array = self.get_array()
         with self.datastore.lock:
+            array = self.get_array(needs_lock=False)
             return array[key]
 
 
@@ -230,9 +230,6 @@ def prepare_variable(self, name, variable, check_encoding=False,
 
     def sync(self):
         self.ds.sync()
-        # if self.autoclose:
-        #     self.close()
-        # super(H5NetCDFStore, self).sync(compute=compute)
 
     def close(self, **kwargs):
         self._manager.close(**kwargs)
diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py
index f633280ef1d..6c135fd1240 100644
--- a/xarray/backends/locks.py
+++ b/xarray/backends/locks.py
@@ -8,6 +8,11 @@
     # no need to worry about serializing the lock
     SerializableLock = threading.Lock
 
+try:
+    from dask.distributed import Lock as DistributedLock
+except ImportError:
+    DistributedLock = None
+
 
 # Locks used by multiple backends.
 # Neither HDF5 nor the netCDF-C library are thread-safe.
@@ -33,16 +38,11 @@ def _get_multiprocessing_lock(key):
     return multiprocessing.Lock()
 
 
-def _get_distributed_lock(key):
-    from dask.distributed import Lock
-    return Lock(key)
-
-
 _LOCK_MAKERS = {
     None: _get_threaded_lock,
     'threaded': _get_threaded_lock,
     'multiprocessing': _get_multiprocessing_lock,
-    'distributed': _get_distributed_lock,
+    'distributed': DistributedLock,
 }
 
 
@@ -113,6 +113,27 @@ def get_write_lock(key):
     return lock_maker(key)
 
 
+def acquire(lock, blocking=True):
+    """Acquire a lock, possibly in a non-blocking fashion.
+
+    Includes backwards compatibility hacks for old versions of Python, dask
+    and dask-distributed.
+    """
+    if blocking:
+        # no arguments needed
+        return lock.acquire()
+    elif DistributedLock is not None and isinstance(lock, DistributedLock):
+        # distributed.Lock doesn't support the blocking argument yet:
+        # https://github.com/dask/distributed/pull/2412
+        return lock.acquire(timeout=0)
+    else:
+        # "blocking" keyword argument not supported for:
+        # - threading.Lock on Python 2.
+        # - dask.SerializableLock with dask v1.0.0 or earlier.
+        # - multiprocessing.Lock calls the argument "block" instead.
+        return lock.acquire(blocking)
+
+
 class CombinedLock(object):
     """A combination of multiple locks.
 
@@ -123,12 +144,12 @@ class CombinedLock(object):
     def __init__(self, locks):
         self.locks = tuple(set(locks))  # remove duplicates
 
-    def acquire(self, *args):
-        return all(lock.acquire(*args) for lock in self.locks)
+    def acquire(self, blocking=True):
+        return all(acquire(lock, blocking=blocking) for lock in self.locks)
 
-    def release(self, *args):
+    def release(self):
         for lock in self.locks:
-            lock.release(*args)
+            lock.release()
 
     def __enter__(self):
         for lock in self.locks:
@@ -138,7 +159,6 @@ def __exit__(self, *args):
         for lock in self.locks:
             lock.__exit__(*args)
 
-    @property
     def locked(self):
         return any(lock.locked for lock in self.locks)
 
@@ -149,10 +169,10 @@ def __repr__(self):
 class DummyLock(object):
     """DummyLock provides the lock API without any actual locking."""
 
-    def acquire(self, *args):
+    def acquire(self, blocking=True):
         pass
 
-    def release(self, *args):
+    def release(self):
         pass
 
     def __enter__(self):
@@ -161,7 +181,6 @@ def __enter__(self):
 
     def __exit__(self, *args):
         pass
 
-    @property
     def locked(self):
         return False
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 08ba085b77e..2dc692e8724 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -48,13 +48,14 @@ def __init__(self, variable_name, datastore):
 
     def __setitem__(self, key, value):
         with self.datastore.lock:
-            data = self.get_array()
+            data = self.get_array(needs_lock=False)
             data[key] = value
             if self.datastore.autoclose:
                 self.datastore.close(needs_lock=False)
 
-    def get_array(self):
-        return self.datastore.ds.variables[self.variable_name]
+    def get_array(self, needs_lock=True):
+        ds = self.datastore._manager.acquire(needs_lock).value
+        return ds.variables[self.variable_name]
 
 
 class NetCDF4ArrayWrapper(BaseNetCDF4Array):
@@ -69,10 +70,9 @@ def _getitem(self, key):
         else:
             getitem = operator.getitem
 
-        original_array = self.get_array()
-
         try:
             with self.datastore.lock:
+                original_array = self.get_array(needs_lock=False)
                 array = getitem(original_array, key)
         except IndexError:
             # Catch IndexError in netCDF4 and return a more informative
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index 606ed5251ac..41bc256835a 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -24,8 +24,9 @@ def __init__(self, variable_name, datastore):
         self.shape = array.shape
         self.dtype = np.dtype(array.dtype)
 
-    def get_array(self):
-        return self.datastore.ds.variables[self.variable_name]
+    def get_array(self, needs_lock=True):
+        ds = self.datastore._manager.acquire(needs_lock)
+        return ds.variables[self.variable_name]
 
     def __getitem__(self, key):
         return indexing.explicit_indexing_adapter(
@@ -33,8 +34,8 @@ def __getitem__(self, key):
             self._getitem)
 
     def _getitem(self, key):
-        array = self.get_array()
         with self.datastore.lock:
+            array = self.get_array(needs_lock=False)
             return array[key]
 
 
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 574fff744e3..b171192ed6a 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -26,18 +26,21 @@ def __init__(self, variable_name, datastore):
         self.shape = array.shape
         self.dtype = np.dtype(array.typecode())
 
-    def get_array(self):
-        return self.datastore.ds.variables[self.variable_name]
+    def get_array(self, needs_lock=True):
+        ds = self.datastore._manager.acquire(needs_lock)
+        return ds.variables[self.variable_name]
 
     def __getitem__(self, key):
         return indexing.explicit_indexing_adapter(
             key, self.shape, indexing.IndexingSupport.BASIC, self._getitem)
 
     def _getitem(self, key):
-        array = self.get_array()
         with self.datastore.lock:
+            array = self.get_array(needs_lock=False)
+
             if key == () and self.ndim == 0:
                 return array.get_value()
+
             return array[key]
 
 
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index 27655438cc3..24874d63f93 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -22,13 +22,16 @@ class RasterioArrayWrapper(BackendArray):
     """A wrapper around rasterio dataset objects"""
-    def __init__(self, manager, vrt_params=None):
+
+    def __init__(self, manager, lock, vrt_params=None):
         from rasterio.vrt import WarpedVRT
         self.manager = manager
+        self.lock = lock
 
         # cannot save riods as an attribute: this would break pickleability
         riods = manager.acquire()
-        riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)
+        if vrt_params is not None:
+            riods = WarpedVRT(riods, **vrt_params)
         self.vrt_params = vrt_params
         self._shape = (riods.count, riods.height, riods.width)
 
@@ -112,9 +115,11 @@ def _getitem(self, key):
                 stop - start for (start, stop) in window)
             out = np.zeros(shape, dtype=self.dtype)
         else:
-            riods = self.manager.acquire()
-            riods = riods if self.vrt_params is None else WarpedVRT(riods,**self.vrt_params)
-            out = riods.read(band_key, window=window)
+            with self.lock:
+                riods = self.manager.acquire(needs_lock=False)
+                if self.vrt_params is not None:
+                    riods = WarpedVRT(riods, **self.vrt_params)
+                out = riods.read(band_key, window=window)
 
         if squeeze_axis:
             out = np.squeeze(out, axis=squeeze_axis)
@@ -221,9 +226,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
                           tolerance=vrt.tolerance,
                           warp_extras=vrt.warp_extras)
 
-    manager = CachingFileManager(rasterio.open, filename, mode='r')
+    if lock is None:
+        lock = RASTERIO_LOCK
+
+    manager = CachingFileManager(rasterio.open, filename, lock=lock, mode='r')
     riods = manager.acquire()
-    riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)
+    if vrt_params is not None:
+        riods = WarpedVRT(riods, **vrt_params)
 
     if cache is None:
         cache = chunks is None
@@ -303,7 +312,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
         else:
             attrs[k] = v
 
-    data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager, vrt_params))
+    data = indexing.LazilyOuterIndexedArray(
+        RasterioArrayWrapper(manager, lock, vrt_params))
 
     # this lets you write arrays loaded with rasterio
     data = indexing.CopyOnWriteArray(data)
@@ -323,10 +333,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
             mtime = None
         token = tokenize(filename, mtime, chunks)
         name_prefix = 'open_rasterio-%s' % token
-        if lock is None:
-            lock = RASTERIO_LOCK
-        result = result.chunk(chunks, name_prefix=name_prefix, token=token,
-                              lock=lock)
+        result = result.chunk(chunks, name_prefix=name_prefix, token=token)
 
     # Make the file closeable
     result._file_obj = manager
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index b009342efb6..157ae44f547 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -11,7 +11,7 @@
 from ..core.pycompat import OrderedDict, basestring, iteritems
 from ..core.utils import Frozen, FrozenOrderedDict
 from .common import BackendArray, WritableCFDataStore
-from .locks import get_write_lock
+from .locks import ensure_lock, get_write_lock
 from .file_manager import CachingFileManager, DummyFileManager
 from .netcdf3 import (
     encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name)
@@ -35,16 +35,17 @@ class ScipyArrayWrapper(BackendArray):
     def __init__(self, variable_name, datastore):
         self.datastore = datastore
         self.variable_name = variable_name
-        array = self.get_array()
+        array = self.get_variable().data
         self.shape = array.shape
         self.dtype = np.dtype(array.dtype.kind +
                               str(array.dtype.itemsize))
 
-    def get_array(self):
-        return self.datastore.ds.variables[self.variable_name].data
+    def get_variable(self, needs_lock=True):
+        ds = self.datastore._manager.acquire(needs_lock)
+        return ds.variables[self.variable_name]
 
     def __getitem__(self, key):
-        data = NumpyIndexingAdapter(self.get_array())[key]
+        data = NumpyIndexingAdapter(self.get_variable().data)[key]
         # Copy data if the source file is mmapped. This makes things consistent
         # with the netCDF4 library by ensuring we can safely read arrays even
         # after closing associated files.
@@ -52,15 +53,16 @@ def __getitem__(self, key):
         return np.array(data, dtype=self.dtype, copy=copy)
 
     def __setitem__(self, key, value):
-        data = self.datastore.ds.variables[self.variable_name]
-        try:
-            data[key] = value
-        except TypeError:
-            if key is Ellipsis:
-                # workaround for GH: scipy/scipy#6880
-                data[:] = value
-            else:
-                raise
+        with self.datastore.lock:
+            data = self.get_variable(needs_lock=False)
+            try:
+                data[key] = value
+            except TypeError:
+                if key is Ellipsis:
+                    # workaround for GH: scipy/scipy#6880
+                    data[:] = value
+                else:
+                    raise
 
 
 def _open_scipy_netcdf(filename, mode, mmap, version):
@@ -140,6 +142,8 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None,
                 isinstance(filename_or_obj, basestring)):
             lock = get_write_lock(filename_or_obj)
 
+        self.lock = ensure_lock(lock)
+
         if isinstance(filename_or_obj, basestring):
             manager = CachingFileManager(
                 _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock,
diff --git a/xarray/core/options.py b/xarray/core/options.py
index ab461ca86bc..400508a5d59 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -6,6 +6,7 @@
 ARITHMETIC_JOIN = 'arithmetic_join'
 ENABLE_CFTIMEINDEX = 'enable_cftimeindex'
 FILE_CACHE_MAXSIZE = 'file_cache_maxsize'
+WARN_FOR_UNCLOSED_FILES = 'warn_for_unclosed_files'
 CMAP_SEQUENTIAL = 'cmap_sequential'
 CMAP_DIVERGENT = 'cmap_divergent'
 KEEP_ATTRS = 'keep_attrs'
@@ -16,6 +17,7 @@
     ARITHMETIC_JOIN: 'inner',
     ENABLE_CFTIMEINDEX: True,
     FILE_CACHE_MAXSIZE: 128,
+    WARN_FOR_UNCLOSED_FILES: False,
     CMAP_SEQUENTIAL: 'viridis',
     CMAP_DIVERGENT: 'RdBu_r',
     KEEP_ATTRS: 'default'
@@ -33,6 +35,7 @@ def _positive_integer(value):
    ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__,
     ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool),
     FILE_CACHE_MAXSIZE: _positive_integer,
+    WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool),
     KEEP_ATTRS: lambda choice: choice in [True, False, 'default']
 }
 
@@ -63,7 +66,8 @@ def _get_keep_attrs(default):
     elif global_choice in [True, False]:
         return global_choice
     else:
-        raise ValueError("The global option keep_attrs must be one of True, False or 'default'.")
+        raise ValueError("The global option keep_attrs must be one of True, "
+                         "False or 'default'.")
 
 
 class set_options(object):
@@ -79,6 +83,9 @@ class set_options(object):
       global least-recently-usage cached. This should be smaller than your
       system's per-process file descriptor limit, e.g., ``ulimit -n`` on
       Linux. Default: 128.
+    - ``warn_for_unclosed_files``: whether or not to issue a warning when
+      unclosed files are deallocated (default False). This is mostly useful
+      for debugging.
     - ``cmap_sequential``: colormap to use for nondivergent data plots.
       Default: ``viridis``. If string, must be matplotlib built-in colormap.
       Can also be a Colormap object (e.g. mpl.cm.magma)
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index cd66ad82356..c57f6720810 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -13,6 +13,7 @@
 import pytest
 
 from xarray.core import utils
+from xarray.core.options import set_options
 from xarray.core.indexing import ExplicitlyIndexed
 from xarray.testing import (assert_equal, assert_identical,  # noqa: F401
                             assert_allclose, assert_combined_tile_ids_equal)
@@ -88,12 +89,6 @@ def LooseVersion(vstring):
     not has_cftime_or_netCDF4, reason='requires cftime or netCDF4')
 if not has_pathlib:
     has_pathlib, requires_pathlib = _importorskip('pathlib2')
-if has_dask:
-    import dask
-    if LooseVersion(dask.__version__) < '0.18':
-        dask.set_options(get=dask.get)
-    else:
-        dask.config.set(scheduler='single-threaded')
 try:
     import_seaborn()
     has_seaborn = True
@@ -102,6 +97,17 @@ def LooseVersion(vstring):
 requires_seaborn = pytest.mark.skipif(not has_seaborn,
                                       reason='requires seaborn')
 
+# change some global options for tests
+set_options(warn_for_unclosed_files=True)
+
+if has_dask:
+    import dask
+    if LooseVersion(dask.__version__) < '0.18':
+        dask.set_options(get=dask.get)
+    else:
+        dask.config.set(scheduler='single-threaded')
+
+# pytest config
 try:
     _SKIP_FLAKY = not pytest.config.getoption("--run-flaky")
     _SKIP_NETWORK_TESTS = not pytest.config.getoption("--run-network-tests")
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 3bc26c90ec0..f4d9154eadb 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -299,18 +299,23 @@ def test_pickle(self):
         expected = Dataset({'foo': ('x', [42])})
         with self.roundtrip(
                 expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
-            raw_pickle = pickle.dumps(roundtripped)
-            # windows doesn't like opening the same file twice
-            roundtripped.close()
-            unpickled_ds = pickle.loads(raw_pickle)
-            assert_identical(expected, unpickled_ds)
+            with roundtripped:
+                # Windows doesn't like reopening an already open file
+                raw_pickle = pickle.dumps(roundtripped)
+            with pickle.loads(raw_pickle) as unpickled_ds:
+                assert_identical(expected, unpickled_ds)
 
+    @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
     def test_pickle_dataarray(self):
         expected = Dataset({'foo': ('x', [42])})
         with self.roundtrip(
                 expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
-            unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo']))
-            assert_identical(expected['foo'], unpickled_array)
+            with roundtripped:
+                raw_pickle = pickle.dumps(roundtripped['foo'])
+            # TODO: figure out how to explicitly close the file for the
+            # unpickled DataArray?
+            unpickled = pickle.loads(raw_pickle)
+            assert_identical(expected['foo'], unpickled)
 
     def test_dataset_caching(self):
         expected = Dataset({'foo': ('x', [5, 6, 7])})
@@ -435,13 +440,13 @@ def test_roundtrip_float64_data(self):
         assert_identical(expected, actual)
 
     def test_roundtrip_example_1_netcdf(self):
-        expected = open_example_dataset('example_1.nc')
-        with self.roundtrip(expected) as actual:
-            # we allow the attributes to differ since that
-            # will depend on the encoding used. For example,
-            # without CF encoding 'actual' will end up with
-            # a dtype attribute.
-            assert_equal(expected, actual)
+        with open_example_dataset('example_1.nc') as expected:
+            with self.roundtrip(expected) as actual:
+                # we allow the attributes to differ since that
+                # will depend on the encoding used. For example,
+                # without CF encoding 'actual' will end up with
+                # a dtype attribute.
+                assert_equal(expected, actual)
 
     def test_roundtrip_coordinates(self):
         original = Dataset({'foo': ('x', [0, 1])},
@@ -1274,6 +1279,7 @@ def test_autoclose_future_warning(self):
 
 @requires_netCDF4
 @requires_dask
+@pytest.mark.filterwarnings('ignore:deallocating CachingFileManager')
 class TestNetCDF4ViaDaskData(TestNetCDF4Data):
     @contextlib.contextmanager
     def roundtrip(self, data, save_kwargs={}, open_kwargs={},
@@ -1659,9 +1665,9 @@ def test_roundtrip_example_1_netcdf_gz(self):
 
     def test_netcdf3_endianness(self):
         # regression test for GH416
-        expected = open_example_dataset('bears.nc', engine='scipy')
-        for var in expected.variables.values():
-            assert var.dtype.isnative
+        with open_example_dataset('bears.nc', engine='scipy') as expected:
+            for var in expected.variables.values():
+                assert var.dtype.isnative
 
     @requires_netCDF4
     def test_nc4_scipy(self):
@@ -1979,13 +1985,13 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks,
         subds.to_netcdf(tmpfiles[ii], engine=writeengine)
 
     # check that calculation on opened datasets works properly
-    actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel,
-                            chunks=chunks)
+    with open_mfdataset(tmpfiles, engine=readengine, parallel=parallel,
+                        chunks=chunks) as actual:
 
-    # check that using open_mfdataset returns dask arrays for variables
-    assert isinstance(actual['foo'].data, dask_array_type)
+        # check that using open_mfdataset returns dask arrays for variables
+        assert isinstance(actual['foo'].data, dask_array_type)
 
-    assert_identical(original, actual)
+        assert_identical(original, actual)
 
 
 @requires_scipy_or_netCDF4
@@ -2032,20 +2038,17 @@ def gen_datasets_with_common_coord_and_time(self):
 
         return ds1, ds2
 
-    def test_open_mfdataset_does_same_as_concat(self):
-        options = ['all', 'minimal', 'different', ]
-
+    @pytest.mark.parametrize('opt', ['all', 'minimal', 'different'])
+    def test_open_mfdataset_does_same_as_concat(self, opt):
         with self.setup_files_and_datasets() as (files, [ds1, ds2]):
-            for opt in options:
-                with open_mfdataset(files, data_vars=opt) as ds:
-                    kwargs = dict(data_vars=opt, dim='t')
-                    ds_expect = xr.concat([ds1, ds2], **kwargs)
-                    assert_identical(ds, ds_expect)
-
-                with open_mfdataset(files, coords=opt) as ds:
-                    kwargs = dict(coords=opt, dim='t')
-                    ds_expect = xr.concat([ds1, ds2], **kwargs)
-                    assert_identical(ds, ds_expect)
+            with open_mfdataset(files, data_vars=opt) as ds:
+                kwargs = dict(data_vars=opt, dim='t')
+                ds_expect = xr.concat([ds1, ds2], **kwargs)
+                assert_identical(ds, ds_expect)
+            with open_mfdataset(files, coords=opt) as ds:
+                kwargs = dict(coords=opt, dim='t')
+                ds_expect = xr.concat([ds1, ds2], **kwargs)
+                assert_identical(ds, ds_expect)
 
     def test_common_coord_when_datavars_all(self):
         opt = 'all'
@@ -2160,12 +2163,10 @@ def test_open_mfdataset(self):
                 original.isel(x=slice(5, 10)).to_netcdf(tmp2)
                 with open_mfdataset([tmp1, tmp2]) as actual:
                     assert isinstance(actual.foo.variable.data, da.Array)
-                    assert actual.foo.variable.data.chunks == \
-                        ((5, 5),)
+                    assert actual.foo.variable.data.chunks == ((5, 5),)
                     assert_identical(original, actual)
                 with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual:
-                    assert actual.foo.variable.data.chunks == \
-                        ((3, 2, 3, 2),)
+                    assert actual.foo.variable.data.chunks == ((3, 2, 3, 2),)
 
         with raises_regex(IOError, 'no files to open'):
diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py
index 591c981cd45..3b618f35ea7 100644
--- a/xarray/tests/test_backends_file_manager.py
+++ b/xarray/tests/test_backends_file_manager.py
@@ -1,3 +1,5 @@
+import collections
+import gc
 import pickle
 import threading
 try:
@@ -9,6 +11,7 @@
 
 from xarray.backends.file_manager import CachingFileManager
 from xarray.backends.lru_cache import LRUCache
+from xarray.core.options import set_options
 
 
 @pytest.fixture(params=[1, 2, 3, None])
@@ -38,6 +41,103 @@ def test_file_manager_mock_write(file_cache):
     lock.__enter__.assert_has_calls([mock.call(), mock.call()])
 
 
+@pytest.mark.parametrize('expected_warning', [None, RuntimeWarning])
+def test_file_manager_autoclose(expected_warning):
+    mock_file = mock.Mock()
+    opener = mock.Mock(return_value=mock_file)
+    cache = {}
+
+    manager = CachingFileManager(opener, 'filename', cache=cache)
+    manager.acquire()
+    assert cache
+
+    with set_options(warn_for_unclosed_files=expected_warning is not None):
+        with pytest.warns(expected_warning):
+            del manager
+            gc.collect()
+
+    assert not cache
+    mock_file.close.assert_called_once_with()
+
+
+def test_file_manager_autoclose_while_locked():
+    opener = mock.Mock()
+    lock = threading.Lock()
+    cache = {}
+
+    manager = CachingFileManager(opener, 'filename', lock=lock, cache=cache)
+    manager.acquire()
+    assert cache
+
+    lock.acquire()
+
+    with set_options(warn_for_unclosed_files=False):
+        del manager
+        gc.collect()
+
+    # can't clear the cache while locked, but also don't block in __del__
+    assert cache
+
+
+def test_file_manager_repr():
+    opener = mock.Mock()
+    manager = CachingFileManager(opener, 'my-file')
+    assert 'my-file' in repr(manager)
+
+
+def test_file_manager_refcounts():
+    mock_file = mock.Mock()
+    opener = mock.Mock(spec=open, return_value=mock_file)
+    cache = {}
+    ref_counts = {}
+
+    manager = CachingFileManager(
+        opener, 'filename', cache=cache, ref_counts=ref_counts)
+    assert ref_counts[manager._key] == 1
+    manager.acquire()
+    assert cache
+
+    manager2 = CachingFileManager(
+        opener, 'filename', cache=cache, ref_counts=ref_counts)
+    assert cache
+    assert manager._key == manager2._key
+    assert ref_counts[manager._key] == 2
+
+    with set_options(warn_for_unclosed_files=False):
+        del manager
+        gc.collect()
+
+    assert cache
+    assert ref_counts[manager2._key] == 1
+    mock_file.close.assert_not_called()
+
+    with set_options(warn_for_unclosed_files=False):
+        del manager2
+        gc.collect()
+
+    assert not ref_counts
+    assert not cache
+
+
+def test_file_manager_replace_object():
+    opener = mock.Mock()
+    cache = {}
+    ref_counts = {}
+
+    manager = CachingFileManager(
+        opener, 'filename', cache=cache, ref_counts=ref_counts)
+    manager.acquire()
+    assert ref_counts[manager._key] == 1
+    assert cache
+
+    manager = CachingFileManager(
+        opener, 'filename', cache=cache, ref_counts=ref_counts)
+    assert ref_counts[manager._key] == 1
+    assert cache
+
+    manager.close()
+
+
 def test_file_manager_write_consecutive(tmpdir, file_cache):
     path1 = str(tmpdir.join('testing1.txt'))
     path2 = str(tmpdir.join('testing2.txt'))
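
Usage note (an illustrative sketch, not part of the patch above): with a build of xarray that includes these changes, the ``warn_for_unclosed_files`` option surfaces the RuntimeWarning emitted from ``CachingFileManager.__del__`` when a dataset is dropped without being closed. ``example.nc`` is a hypothetical placeholder path, and the explicit ``gc.collect()`` simply mirrors what the tests above do::

    import gc
    import warnings

    import xarray as xr

    # Opt in to the debugging aid added by this change set (off by default).
    xr.set_options(warn_for_unclosed_files=True)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        ds = xr.open_dataset('example.nc')  # placeholder path
        del ds  # drop the last reference without calling ds.close()
        gc.collect()  # make sure the manager's __del__ has run

    # Expect a RuntimeWarning whose message starts with
    # "deallocating CachingFileManager(...), but file is not already closed."
    print([str(w.message) for w in caught])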