Skip to content

Commit cc80626

Browse files
committed
some progress with indexing
1 parent 1472aaf commit cc80626

File tree

3 files changed

+201
-76
lines changed

3 files changed

+201
-76
lines changed

xarray/backends/rasterio_.py

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,64 @@
1515

1616
_rio_varname = 'raster'
1717

18+
_error_mess = 'The kind of indexing operation you are trying to do is not ' \
19+
'valid on RasterIO files. Try to load your data with ds.load()' \
20+
' first'
1821

1922
class RasterioArrayWrapper(NDArrayMixin):
2023
def __init__(self, ds):
21-
self._ds = ds
22-
self.array = ds.read()
24+
self.ds = ds
25+
self._shape = self.ds.count, self.ds.height, self.ds.width
2326

2427
@property
2528
def dtype(self):
26-
return np.dtype(self._ds.dtypes[0])
29+
return np.dtype(self.ds.dtypes[0]) \
30+
31+
@property
32+
def shape(self):
33+
return self._shape
2734

2835
def __getitem__(self, key):
29-
if key == () and self.ndim == 0:
30-
return self.array.get_value()
31-
return self.array[key]
36+
37+
# make our job a bit easier
38+
key = indexing.canonicalize_indexer(key, len(self.shape))
39+
40+
# bands cannot be windowed but they can be listed
41+
bands, n = key[0], self.shape[0]
42+
if isinstance(bands, slice):
43+
start = bands.start if bands.start is not None else 0
44+
stop = bands.stop if bands.stop is not None else n
45+
if bands.step is not None and bands.step != 1:
46+
raise IndexError(_error_mess)
47+
bands = np.arange(start, stop)
48+
# be sure we give out a list
49+
bands = (np.asarray(bands) + 1).tolist()
50+
51+
# but other dims can
52+
window = []
53+
for k, n in zip(key[1:], self.shape[1:]):
54+
if isinstance(k, slice):
55+
start = k.start if k.start is not None else 0
56+
stop = k.stop if k.stop is not None else n
57+
if k.step is not None and k.step != 1:
58+
raise IndexError(_error_mess)
59+
else:
60+
k = np.asarray(k).flatten()
61+
start = k[0]
62+
stop = k[-1] + 1
63+
if (stop - start) != len(k):
64+
raise IndexError(_error_mess)
65+
window.append((start, stop))
66+
67+
return self.ds.read(bands, window=window)
3268

3369

3470
class RasterioDataStore(AbstractDataStore):
3571
"""Store for accessing datasets via Rasterio
3672
"""
3773
def __init__(self, filename, mode='r'):
3874

39-
# TODO: is the rasterio.Env() really necessary, and if yes where?
75+
# TODO: is the rasterio.Env() really necessary, and if yes: when?
4076
with rasterio.Env():
4177
self.ds = rasterio.open(filename, mode=mode)
4278

@@ -45,37 +81,37 @@ def __init__(self, filename, mode='r'):
4581
dx, dy = self.ds.res[0], -self.ds.res[1]
4682
x0 = self.ds.bounds.right if dx < 0 else self.ds.bounds.left
4783
y0 = self.ds.bounds.top if dy < 0 else self.ds.bounds.bottom
48-
y = np.linspace(start=y0, num=ny, stop=(y0 + (ny-1) * dy))
49-
x = np.linspace(start=x0, num=nx, stop=(x0 + (nx-1) * dx))
84+
x = np.linspace(start=x0, num=nx, stop=(x0 + (nx - 1) * dx))
85+
y = np.linspace(start=y0, num=ny, stop=(y0 + (ny - 1) * dy))
5086

51-
self.coords = OrderedDict()
52-
self.coords['y'] = Variable(('y', ), y)
53-
self.coords['x'] = Variable(('x', ), x)
87+
self._vars = OrderedDict()
88+
self._vars['y'] = Variable(('y',), y)
89+
self._vars['x'] = Variable(('x',), x)
5490

5591
# Get dims
5692
if self.ds.count >= 1:
5793
self.dims = ('band', 'y', 'x')
58-
self.coords['band'] = Variable(('band', ),
59-
np.atleast_1d(self.ds.indexes))
94+
self._vars['band'] = Variable(('band',),
95+
np.atleast_1d(self.ds.indexes))
6096
else:
61-
raise ValueError('unknown dims')
97+
raise ValueError('Unknown dims')
6298

6399
self._attrs = OrderedDict()
64100
with suppress(AttributeError):
65-
for attr_name in ['crs', 'transform', 'proj']:
101+
for attr_name in ['crs']:
66102
self._attrs[attr_name] = getattr(self.ds, attr_name)
67103

104+
# Get data
105+
self._vars[_rio_varname] = self.open_store_variable(_rio_varname)
106+
68107
def open_store_variable(self, var):
69108
if var != _rio_varname:
70109
raise ValueError('Rasterio variables are named %s' % _rio_varname)
71110
data = indexing.LazilyIndexedArray(RasterioArrayWrapper(self.ds))
72111
return Variable(self.dims, data, self._attrs)
73112

74113
def get_variables(self):
75-
# Get lat lon coordinates
76-
vars = _try_to_get_latlon_coords(self.coords, self._attrs)
77-
vars[_rio_varname] = self.open_store_variable(_rio_varname)
78-
return FrozenOrderedDict(vars)
114+
return FrozenOrderedDict(self._vars)
79115

80116
def get_attrs(self):
81117
return Frozen(self._attrs)
@@ -85,29 +121,3 @@ def get_dimensions(self):
85121

86122
def close(self):
87123
self.ds.close()
88-
89-
90-
def _try_to_get_latlon_coords(coords, attrs):
91-
from rasterio.warp import transform
92-
93-
if 'crs' in attrs:
94-
proj = attrs['crs']
95-
# TODO: if the proj is already PlateCarree, making 2D coordinates
96-
# is not the best thing to do here.
97-
ny, nx = len(coords['y']), len(coords['x'])
98-
x, y = np.meshgrid(coords['x'], coords['y'])
99-
# Rasterio works with 1D arrays
100-
xc, yc = transform(proj, {'init': 'EPSG:4326'},
101-
x.flatten(), y.flatten())
102-
xc = np.asarray(xc).reshape((ny, nx))
103-
yc = np.asarray(yc).reshape((ny, nx))
104-
dims = ('y', 'x')
105-
coords['lon'] = Variable(dims, xc,
106-
attrs={'units': 'degrees_east',
107-
'long_name': 'longitude',
108-
'standard_name': 'longitude'})
109-
coords['lat'] = Variable(dims, yc,
110-
attrs={'units': 'degrees_north',
111-
'long_name': 'latitude',
112-
'standard_name': 'latitude'})
113-
return coords

xarray/core/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,3 +487,38 @@ def ensure_us_time_resolution(val):
487487
elif np.issubdtype(val.dtype, np.timedelta64):
488488
val = val.astype('timedelta64[us]')
489489
return val
490+
491+
492+
def get_latlon_coords_from_crs(ds, crs=None):
493+
"""Currently very specific function, but coul dbe generalized."""
494+
495+
from .. import DataArray
496+
497+
try:
498+
from rasterio.warp import transform
499+
except ImportError:
500+
raise ImportError('add_latlon_coords_from_crs needs RasterIO.')
501+
502+
if crs is None:
503+
if 'crs' in ds.attrs:
504+
crs = ds.attrs['crs']
505+
else:
506+
raise ValueError('crs not found')
507+
508+
ny, nx = len(ds['y']), len(ds['x'])
509+
x, y = np.meshgrid(ds['x'], ds['y'])
510+
# Rasterio works with 1D arrays
511+
xc, yc = transform(crs, {'init': 'EPSG:4326'},
512+
x.flatten(), y.flatten())
513+
xc = np.asarray(xc).reshape((ny, nx))
514+
yc = np.asarray(yc).reshape((ny, nx))
515+
dims = ('y', 'x')
516+
ds['lon'] = DataArray(xc, dims=dims,
517+
attrs={'units': 'degrees_east',
518+
'long_name': 'longitude',
519+
'standard_name': 'longitude'})
520+
ds['lat'] = DataArray(yc, dims=dims,
521+
attrs={'units': 'degrees_north',
522+
'long_name': 'latitude',
523+
'standard_name': 'latitude'})
524+
return ds

xarray/tests/test_backends.py

Lines changed: 110 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,21 +1216,6 @@ def test_weakrefs(self):
12161216
@requires_rasterio
12171217
class TestRasterIO(CFEncodedDataTest, Only32BitTypes, TestCase):
12181218

1219-
# def setUp(self):
1220-
#
1221-
#
1222-
# name = 'test_latlong.tif'
1223-
#
1224-
#
1225-
# name = 'test_utm.tif'
1226-
# transform = from_origin(-300, 200, 1000, 1000)
1227-
# with rasterio.open(
1228-
# name, 'w',
1229-
# driver='GTiff', height=3, width=4, count=1,
1230-
# transform=transform,
1231-
# dtype=rasterio.float32) as s:
1232-
# s.write(a, indexes=1)
1233-
12341219
def test_write_store(self):
12351220
# RasterIO is read-only for now
12361221
pass
@@ -1239,71 +1224,166 @@ def test_orthogonal_indexing(self):
12391224
# RasterIO also does not support list-like indexing
12401225
pass
12411226

1242-
def test_latlong_coords(self):
1227+
def test_latlong_basics(self):
12431228

12441229
import rasterio
12451230
from rasterio.transform import from_origin
1231+
from ..core.utils import get_latlon_coords_from_crs
12461232

12471233
# Create a geotiff file in latlong proj
1248-
data = np.arange(12, dtype=rasterio.float32).reshape(3, 4)
12491234
with create_tmp_file(suffix='.tif') as tmp_file:
1250-
transform = from_origin(1, 2, 0.5, 1.)
1235+
# data
1236+
nx, ny = 8, 10
1237+
data = np.arange(80, dtype=rasterio.float32).reshape(ny, nx)
1238+
transform = from_origin(1, 2, 0.5, 2.)
12511239
with rasterio.open(
12521240
tmp_file, 'w',
1253-
driver='GTiff', height=3, width=4, count=1,
1241+
driver='GTiff', height=ny, width=nx, count=1,
12541242
crs='+proj=latlong',
12551243
transform=transform,
12561244
dtype=rasterio.float32) as s:
12571245
s.write(data, indexes=1)
12581246
actual = xr.open_dataset(tmp_file, engine='rasterio')
12591247

1248+
# ref
12601249
expected = Dataset()
1261-
expected['x'] = ('x', [1, 1.5, 2, 2.5])
1262-
expected['y'] = ('y', [2., 1, 0])
1250+
expected['x'] = ('x', np.arange(nx)*0.5 + 1)
1251+
expected['y'] = ('y', -np.arange(ny)*2 + 2)
12631252
expected['band'] = ('band', [1])
12641253
expected['raster'] = (('band', 'y', 'x'), data[np.newaxis, ...])
12651254
lon, lat = np.meshgrid(expected['x'], expected['y'])
12661255
expected['lon'] = (('y', 'x'), lon)
12671256
expected['lat'] = (('y', 'x'), lat)
12681257

1258+
# tests
12691259
assert_allclose(actual.y, expected.y)
12701260
assert_allclose(actual.x, expected.x)
12711261
assert_allclose(actual.raster, expected.raster)
1262+
1263+
actual = get_latlon_coords_from_crs(actual)
12721264
assert_allclose(actual.lon, expected.lon)
12731265
assert_allclose(actual.lat, expected.lat)
12741266

1275-
print(actual)
1267+
assert 'crs' in actual.attrs
12761268

1277-
def test_utm_coords(self):
1269+
def test_utm_basics(self):
12781270

12791271
import rasterio
12801272
from rasterio.transform import from_origin
1273+
from ..core.utils import get_latlon_coords_from_crs
12811274

12821275
# Create a geotiff file in utm proj
1283-
data = np.arange(24, dtype=rasterio.float32).reshape(2, 3, 4)
12841276
with create_tmp_file(suffix='.tif') as tmp_file:
1285-
transform = from_origin(-3000, 1000, 1000, 1000)
1277+
# data
1278+
nx, ny, nz = 4, 3, 3
1279+
data = np.arange(nx*ny*nz,
1280+
dtype=rasterio.float32).reshape(nz, ny, nx)
1281+
transform = from_origin(5000, 80000, 1000, 2000.)
12861282
with rasterio.open(
12871283
tmp_file, 'w',
1288-
driver='GTiff', height=3, width=4, count=2,
1284+
driver='GTiff', height=ny, width=nx, count=nz,
12891285
crs={'units': 'm', 'no_defs': True, 'ellps': 'WGS84',
12901286
'proj': 'utm', 'zone': 18},
12911287
transform=transform,
12921288
dtype=rasterio.float32) as s:
12931289
s.write(data)
12941290
actual = xr.open_dataset(tmp_file, engine='rasterio')
12951291

1292+
# ref
12961293
expected = Dataset()
1297-
expected['x'] = ('x', [-3000., -2000, -1000, 0])
1298-
expected['y'] = ('y', [1000., 0, -1000])
1299-
expected['band'] = ('band', [1, 2])
1294+
expected['x'] = ('x', np.arange(nx)*1000 + 5000)
1295+
expected['y'] = ('y', -np.arange(ny)*2000 + 80000)
1296+
expected['band'] = ('band', [1, 2, 3])
13001297
expected['raster'] = (('band', 'y', 'x'), data)
13011298

1299+
# data obtained independently with pyproj
1300+
lon = np.array(
1301+
[[-79.44429834, -79.43533803, -79.42637762, -79.4174171],
1302+
[-79.44428102, -79.43532075, -79.42636037, -79.41739988],
1303+
[-79.44426413, -79.4353039, -79.42634355, -79.4173831]])
1304+
lat = np.array(
1305+
[[0.72159393, 0.72160275, 0.72161156, 0.72162034],
1306+
[0.70355411, 0.70356271, 0.70357129, 0.70357986],
1307+
[0.68551428, 0.68552266, 0.68553103, 0.68553937]])
1308+
expected['lon'] = (('y', 'x'), lon)
1309+
expected['lat'] = (('y', 'x'), lat)
1310+
1311+
# tests
13021312
assert_allclose(actual.y, expected.y)
13031313
assert_allclose(actual.x, expected.x)
13041314
assert_allclose(actual.raster, expected.raster)
13051315

1306-
print(actual)
1316+
actual = get_latlon_coords_from_crs(actual)
1317+
assert_allclose(actual.lon, expected.lon)
1318+
assert_allclose(actual.lat, expected.lat)
1319+
1320+
assert 'crs' in actual.attrs
1321+
1322+
def test_indexing(self):
1323+
1324+
import rasterio
1325+
from rasterio.transform import from_origin
1326+
1327+
# Create a geotiff file in latlong proj
1328+
with create_tmp_file(suffix='.tif') as tmp_file:
1329+
# data
1330+
nx, ny, nz = 8, 10, 3
1331+
data = np.arange(nx*ny*nz,
1332+
dtype=rasterio.float32).reshape(nz, ny, nx)
1333+
transform = from_origin(1, 2, 0.5, 2.)
1334+
with rasterio.open(
1335+
tmp_file, 'w',
1336+
driver='GTiff', height=ny, width=nx, count=nz,
1337+
crs='+proj=latlong',
1338+
transform=transform,
1339+
dtype=rasterio.float32) as s:
1340+
s.write(data)
1341+
actual = xr.open_dataset(tmp_file, engine='rasterio')
1342+
1343+
# ref
1344+
expected = Dataset()
1345+
expected['x'] = ('x', np.arange(nx)*0.5 + 1)
1346+
expected['y'] = ('y', -np.arange(ny)*2 + 2)
1347+
expected['band'] = ('band', [1, 2, 3])
1348+
expected['raster'] = (('band', 'y', 'x'), data)
1349+
1350+
# tests
1351+
_ex = expected.isel(band=1)
1352+
_ac = actual.isel(band=1)
1353+
assert_allclose(_ac.y, _ex.y)
1354+
assert_allclose(_ac.x, _ex.x)
1355+
assert_allclose(_ac.band, _ex.band)
1356+
assert_allclose(_ac.raster, _ex.raster)
1357+
1358+
_ex = expected.isel(x=slice(2, 5), y=slice(5, 7))
1359+
_ac = actual.isel(x=slice(2, 5), y=slice(5, 7))
1360+
assert_allclose(_ac.y, _ex.y)
1361+
assert_allclose(_ac.x, _ex.x)
1362+
assert_allclose(_ac.raster, _ex.raster)
1363+
1364+
_ex = expected.isel(band=slice(1, 2), x=slice(2, 5), y=slice(5, 7))
1365+
_ac = actual.isel(band=slice(1, 2), x=slice(2, 5), y=slice(5, 7))
1366+
assert_allclose(_ac.y, _ex.y)
1367+
assert_allclose(_ac.x, _ex.x)
1368+
assert_allclose(_ac.raster, _ex.raster)
1369+
1370+
_ex = expected.isel(x=1, y=2)
1371+
_ac = actual.isel(x=1, y=2)
1372+
assert_allclose(_ac.y, _ex.y)
1373+
assert_allclose(_ac.x, _ex.x)
1374+
# TODO: this doesnt work properly because of the shape
1375+
# assert_allclose(_ac.raster, _ex.raster)
1376+
np.testing.assert_allclose(_ac.raster.values.flatten(),
1377+
_ex.raster.values)
1378+
1379+
_ex = expected.isel(band=0, x=1, y=2)
1380+
_ac = actual.isel(band=0, x=1, y=2)
1381+
assert_allclose(_ac.y, _ex.y)
1382+
assert_allclose(_ac.x, _ex.x)
1383+
# TODO: this doesnt work properly because of the shape
1384+
# assert_allclose(_ac.raster, _ex.raster)
1385+
np.testing.assert_allclose(_ac.raster.values.flatten(),
1386+
_ex.raster.values)
13071387

13081388

13091389
class TestEncodingInvalid(TestCase):

0 commit comments

Comments
 (0)