Skip to content

Commit 3cd2337

Browse files
rabernatshoyer
authored andcommitted
fix rasterio chunking with s3 datasets (#1817)
* fixes #1816 * new and refactored rasterio tests
1 parent e31cf43 commit 3cd2337

File tree

3 files changed

+87
-213
lines changed

3 files changed

+87
-213
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ Bug fixes
7575
:py:meth:`~Dataset.to_netcdf` (:issue:`1763`).
7676
By `Mike Neish <https://github.com/neishm>`_.
7777

78+
- Fixed chunking with non-file-based rasterio datasets (:issue:`1816`) and
79+
refactored rasterio test suite.
80+
By `Ryan Abernathey <https://github.com/rabernat>`_
7881
- Bug fix in open_dataset(engine='pydap') (:issue:`1775`)
7982
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
8083

xarray/backends/rasterio_.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,11 @@ def open_rasterio(filename, chunks=None, cache=None, lock=None):
222222
if chunks is not None:
223223
from dask.base import tokenize
224224
# augment the token with the file modification time
225-
mtime = os.path.getmtime(filename)
225+
try:
226+
mtime = os.path.getmtime(filename)
227+
except OSError:
228+
# the filename is probably an s3 bucket rather than a regular file
229+
mtime = None
226230
token = tokenize(filename, mtime, chunks)
227231
name_prefix = 'open_rasterio-%s' % token
228232
if lock is None:

xarray/tests/test_backends.py

Lines changed: 79 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,126 +2149,61 @@ class TestPyNioAutocloseTrue(TestPyNio):
21492149
autoclose = True
21502150

21512151

2152+
@requires_rasterio
2153+
@contextlib.contextmanager
2154+
def create_tmp_geotiff(nx=4, ny=3, nz=3,
2155+
transform_args=[5000, 80000, 1000, 2000.],
2156+
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
2157+
'proj': 'utm', 'zone': 18}):
2158+
# yields a temporary geotiff file and a corresponding expected DataArray
2159+
import rasterio
2160+
from rasterio.transform import from_origin
2161+
with create_tmp_file(suffix='.tif') as tmp_file:
2162+
# allow 2d or 3d shapes
2163+
if nz == 1:
2164+
data_shape = ny, nx
2165+
write_kwargs = {'indexes': 1}
2166+
else:
2167+
data_shape = nz, ny, nx
2168+
write_kwargs = {}
2169+
data = np.arange(nz*ny*nx,
2170+
dtype=rasterio.float32).reshape(*data_shape)
2171+
transform = from_origin(*transform_args)
2172+
with rasterio.open(
2173+
tmp_file, 'w',
2174+
driver='GTiff', height=ny, width=nx, count=nz,
2175+
crs=crs,
2176+
transform=transform,
2177+
dtype=rasterio.float32) as s:
2178+
s.write(data, **write_kwargs)
2179+
dx, dy = s.res[0], -s.res[1]
2180+
2181+
a, b, c, d = transform_args
2182+
data = data[np.newaxis, ...] if nz == 1 else data
2183+
expected = DataArray(data, dims=('band', 'y', 'x'),
2184+
coords={
2185+
'band': np.arange(nz)+1,
2186+
'y': -np.arange(ny) * d + b + dy/2,
2187+
'x': np.arange(nx) * c + a + dx/2,
2188+
})
2189+
yield tmp_file, expected
2190+
2191+
21522192
@requires_rasterio
21532193
class TestRasterio(TestCase):
21542194

21552195
@requires_scipy_or_netCDF4
21562196
def test_serialization(self):
2157-
import rasterio
2158-
from rasterio.transform import from_origin
2159-
2160-
# Create a geotiff file in utm proj
2161-
with create_tmp_file(suffix='.tif') as tmp_file:
2162-
# data
2163-
nx, ny, nz = 4, 3, 3
2164-
data = np.arange(nx * ny * nz,
2165-
dtype=rasterio.float32).reshape(nz, ny, nx)
2166-
transform = from_origin(5000, 80000, 1000, 2000.)
2167-
with rasterio.open(
2168-
tmp_file, 'w',
2169-
driver='GTiff', height=ny, width=nx, count=nz,
2170-
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
2171-
'proj': 'utm', 'zone': 18},
2172-
transform=transform,
2173-
dtype=rasterio.float32) as s:
2174-
s.write(data)
2175-
2197+
with create_tmp_geotiff() as (tmp_file, expected):
21762198
# Write it to a netcdf and read again (roundtrip)
21772199
with xr.open_rasterio(tmp_file) as rioda:
21782200
with create_tmp_file(suffix='.nc') as tmp_nc_file:
21792201
rioda.to_netcdf(tmp_nc_file)
21802202
with xr.open_dataarray(tmp_nc_file) as ncds:
21812203
assert_identical(rioda, ncds)
21822204

2183-
@requires_scipy_or_netCDF4
2184-
def test_nodata(self):
2185-
import rasterio
2186-
from rasterio.transform import from_origin
2187-
2188-
# Create a geotiff file in utm proj
2189-
with create_tmp_file(suffix='.tif') as tmp_file:
2190-
# data
2191-
nx, ny, nz = 4, 3, 3
2192-
data = np.arange(nx*ny*nz,
2193-
dtype=rasterio.float32).reshape(nz, ny, nx)
2194-
transform = from_origin(5000, 80000, 1000, 2000.)
2195-
with rasterio.open(
2196-
tmp_file, 'w',
2197-
driver='GTiff', height=ny, width=nx, count=nz,
2198-
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
2199-
'proj': 'utm', 'zone': 18},
2200-
transform=transform,
2201-
nodata=-9765,
2202-
dtype=rasterio.float32) as s:
2203-
s.write(data)
2204-
expected_nodatavals = [-9765, -9765, -9765]
2205-
with xr.open_rasterio(tmp_file) as rioda:
2206-
np.testing.assert_array_equal(rioda.attrs['nodatavals'],
2207-
expected_nodatavals)
2208-
with create_tmp_file(suffix='.nc') as tmp_nc_file:
2209-
rioda.to_netcdf(tmp_nc_file)
2210-
with xr.open_dataarray(tmp_nc_file) as ncds:
2211-
np.testing.assert_array_equal(ncds.attrs['nodatavals'],
2212-
expected_nodatavals)
2213-
2214-
@requires_scipy_or_netCDF4
2215-
def test_nodata_missing(self):
2216-
import rasterio
2217-
from rasterio.transform import from_origin
2218-
2219-
# Create a geotiff file in utm proj
2220-
with create_tmp_file(suffix='.tif') as tmp_file:
2221-
# data
2222-
nx, ny, nz = 4, 3, 3
2223-
data = np.arange(nx*ny*nz,
2224-
dtype=rasterio.float32).reshape(nz, ny, nx)
2225-
transform = from_origin(5000, 80000, 1000, 2000.)
2226-
with rasterio.open(
2227-
tmp_file, 'w',
2228-
driver='GTiff', height=ny, width=nx, count=nz,
2229-
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
2230-
'proj': 'utm', 'zone': 18},
2231-
transform=transform,
2232-
dtype=rasterio.float32) as s:
2233-
s.write(data)
2234-
2235-
expected_nodatavals = [np.nan, np.nan, np.nan]
2236-
with xr.open_rasterio(tmp_file) as rioda:
2237-
np.testing.assert_array_equal(rioda.attrs['nodatavals'],
2238-
expected_nodatavals)
2239-
with create_tmp_file(suffix='.nc') as tmp_nc_file:
2240-
rioda.to_netcdf(tmp_nc_file)
2241-
with xr.open_dataarray(tmp_nc_file) as ncds:
2242-
np.testing.assert_array_equal(ncds.attrs['nodatavals'],
2243-
expected_nodatavals)
2244-
22452205
def test_utm(self):
2246-
import rasterio
2247-
from rasterio.transform import from_origin
2248-
2249-
# Create a geotiff file in utm proj
2250-
with create_tmp_file(suffix='.tif') as tmp_file:
2251-
# data
2252-
nx, ny, nz = 4, 3, 3
2253-
data = np.arange(nx * ny * nz,
2254-
dtype=rasterio.float32).reshape(nz, ny, nx)
2255-
transform = from_origin(5000, 80000, 1000, 2000.)
2256-
with rasterio.open(
2257-
tmp_file, 'w',
2258-
driver='GTiff', height=ny, width=nx, count=nz,
2259-
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
2260-
'proj': 'utm', 'zone': 18},
2261-
transform=transform,
2262-
dtype=rasterio.float32) as s:
2263-
s.write(data)
2264-
dx, dy = s.res[0], -s.res[1]
2265-
2266-
# Tests
2267-
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
2268-
'band': [1, 2, 3],
2269-
'y': -np.arange(ny) * 2000 + 80000 + dy / 2,
2270-
'x': np.arange(nx) * 1000 + 5000 + dx / 2,
2271-
})
2206+
with create_tmp_geotiff() as (tmp_file, expected):
22722207
with xr.open_rasterio(tmp_file) as rioda:
22732208
assert_allclose(rioda, expected)
22742209
assert 'crs' in rioda.attrs
@@ -2281,32 +2216,9 @@ def test_utm(self):
22812216
assert isinstance(rioda.attrs['transform'], tuple)
22822217

22832218
def test_platecarree(self):
2284-
2285-
import rasterio
2286-
from rasterio.transform import from_origin
2287-
2288-
# Create a geotiff file in latlong proj
2289-
with create_tmp_file(suffix='.tif') as tmp_file:
2290-
# data
2291-
nx, ny = 8, 10
2292-
data = np.arange(80, dtype=rasterio.float32).reshape(ny, nx)
2293-
transform = from_origin(1, 2, 0.5, 2.)
2294-
with rasterio.open(
2295-
tmp_file, 'w',
2296-
driver='GTiff', height=ny, width=nx, count=1,
2297-
crs='+proj=latlong',
2298-
transform=transform,
2299-
dtype=rasterio.float32) as s:
2300-
s.write(data, indexes=1)
2301-
dx, dy = s.res[0], -s.res[1]
2302-
2303-
# Tests
2304-
expected = DataArray(data[np.newaxis, ...],
2305-
dims=('band', 'y', 'x'),
2306-
coords={'band': [1],
2307-
'y': -np.arange(ny) * 2 + 2 + dy / 2,
2308-
'x': np.arange(nx) * 0.5 + 1 + dx / 2,
2309-
})
2219+
with create_tmp_geotiff(8, 10, 1, transform_args=[1, 2, 0.5, 2.],
2220+
crs='+proj=latlong') \
2221+
as (tmp_file, expected):
23102222
with xr.open_rasterio(tmp_file) as rioda:
23112223
assert_allclose(rioda, expected)
23122224
assert 'crs' in rioda.attrs
@@ -2319,32 +2231,8 @@ def test_platecarree(self):
23192231
assert isinstance(rioda.attrs['transform'], tuple)
23202232

23212233
def test_indexing(self):
2322-
2323-
import rasterio
2324-
from rasterio.transform import from_origin
2325-
2326-
# Create a geotiff file in latlong proj
2327-
with create_tmp_file(suffix='.tif') as tmp_file:
2328-
# data
2329-
nx, ny, nz = 8, 10, 3
2330-
data = np.arange(nx * ny * nz,
2331-
dtype=rasterio.float32).reshape(nz, ny, nx)
2332-
transform = from_origin(1, 2, 0.5, 2.)
2333-
with rasterio.open(
2334-
tmp_file, 'w',
2335-
driver='GTiff', height=ny, width=nx, count=nz,
2336-
crs='+proj=latlong',
2337-
transform=transform,
2338-
dtype=rasterio.float32) as s:
2339-
s.write(data)
2340-
dx, dy = s.res[0], -s.res[1]
2341-
2342-
# ref
2343-
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
2344-
'x': (np.arange(nx) * 0.5 + 1) + dx / 2,
2345-
'y': (-np.arange(ny) * 2 + 2) + dy / 2,
2346-
'band': [1, 2, 3]})
2347-
2234+
with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
2235+
crs='+proj=latlong') as (tmp_file, expected):
23482236
with xr.open_rasterio(tmp_file, cache=False) as actual:
23492237

23502238
# tests
@@ -2411,33 +2299,8 @@ def test_indexing(self):
24112299
assert_allclose(ac, ex)
24122300

24132301
def test_caching(self):
2414-
2415-
import rasterio
2416-
from rasterio.transform import from_origin
2417-
2418-
# Create a geotiff file in latlong proj
2419-
with create_tmp_file(suffix='.tif') as tmp_file:
2420-
# data
2421-
nx, ny, nz = 8, 10, 3
2422-
data = np.arange(nx * ny * nz,
2423-
dtype=rasterio.float32).reshape(nz, ny, nx)
2424-
transform = from_origin(1, 2, 0.5, 2.)
2425-
with rasterio.open(
2426-
tmp_file, 'w',
2427-
driver='GTiff', height=ny, width=nx, count=nz,
2428-
crs='+proj=latlong',
2429-
transform=transform,
2430-
dtype=rasterio.float32) as s:
2431-
s.write(data)
2432-
dx, dy = s.res[0], -s.res[1]
2433-
2434-
# ref
2435-
expected = DataArray(
2436-
data, dims=('band', 'y', 'x'), coords={
2437-
'x': (np.arange(nx) * 0.5 + 1) + dx / 2,
2438-
'y': (-np.arange(ny) * 2 + 2) + dy / 2,
2439-
'band': [1, 2, 3]})
2440-
2302+
with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
2303+
crs='+proj=latlong') as (tmp_file, expected):
24412304
# Cache is the default
24422305
with xr.open_rasterio(tmp_file) as actual:
24432306

@@ -2456,39 +2319,15 @@ def test_caching(self):
24562319

24572320
@requires_dask
24582321
def test_chunks(self):
2459-
2460-
import rasterio
2461-
from rasterio.transform import from_origin
2462-
2463-
# Create a geotiff file in latlong proj
2464-
with create_tmp_file(suffix='.tif') as tmp_file:
2465-
# data
2466-
nx, ny, nz = 8, 10, 3
2467-
data = np.arange(nx * ny * nz,
2468-
dtype=rasterio.float32).reshape(nz, ny, nx)
2469-
transform = from_origin(1, 2, 0.5, 2.)
2470-
with rasterio.open(
2471-
tmp_file, 'w',
2472-
driver='GTiff', height=ny, width=nx, count=nz,
2473-
crs='+proj=latlong',
2474-
transform=transform,
2475-
dtype=rasterio.float32) as s:
2476-
s.write(data)
2477-
dx, dy = s.res[0], -s.res[1]
2478-
2322+
with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
2323+
crs='+proj=latlong') as (tmp_file, expected):
24792324
# Chunk at open time
24802325
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:
24812326

24822327
import dask.array as da
24832328
self.assertIsInstance(actual.data, da.Array)
24842329
assert 'open_rasterio' in actual.data.name
24852330

2486-
# ref
2487-
expected = DataArray(data, dims=('band', 'y', 'x'), coords={
2488-
'x': np.arange(nx) * 0.5 + 1 + dx / 2,
2489-
'y': -np.arange(ny) * 2 + 2 + dy / 2,
2490-
'band': [1, 2, 3]})
2491-
24922331
# do some arithmetic
24932332
ac = actual.mean()
24942333
ex = expected.mean()
@@ -2503,6 +2342,7 @@ def test_ENVI_tags(self):
25032342
from rasterio.transform import from_origin
25042343

25052344
# Create an ENVI file with some tags in the ENVI namespace
2345+
# this test uses a custom driver, so we can't use create_tmp_geotiff
25062346
with create_tmp_file(suffix='.dat') as tmp_file:
25072347
# data
25082348
nx, ny, nz = 4, 3, 3
@@ -2545,6 +2385,33 @@ def test_ENVI_tags(self):
25452385
assert isinstance(rioda.attrs['map_info'], basestring)
25462386
assert isinstance(rioda.attrs['samples'], basestring)
25472387

2388+
def test_no_mftime(self):
2389+
# rasterio can accept "filename" urguments that are actually urls,
2390+
# including paths to remote files.
2391+
# In issue #1816, we found that these caused dask to break, because
2392+
# the modification time was used to determine the dask token. This
2393+
# tests ensure we can still chunk such files when reading with
2394+
# rasterio.
2395+
with create_tmp_geotiff(8, 10, 3, transform_args=[1, 2, 0.5, 2.],
2396+
crs='+proj=latlong') as (tmp_file, expected):
2397+
with mock.patch('os.path.getmtime', side_effect=OSError):
2398+
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:
2399+
import dask.array as da
2400+
self.assertIsInstance(actual.data, da.Array)
2401+
assert_allclose(actual, expected)
2402+
2403+
@network
2404+
def test_http_url(self):
2405+
# more examples urls here
2406+
# http://download.osgeo.org/geotiff/samples/
2407+
url = 'http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif'
2408+
with xr.open_rasterio(url) as actual:
2409+
assert actual.shape == (1, 512, 512)
2410+
# make sure chunking works
2411+
with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual:
2412+
import dask.array as da
2413+
self.assertIsInstance(actual.data, da.Array)
2414+
25482415

25492416
class TestEncodingInvalid(TestCase):
25502417

0 commit comments

Comments
 (0)