This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit ef65c33

Speed up loading by only using HRV channel in dataset for now. #23
1 parent c6a265f · commit ef65c33

File tree: 4 files changed (+49, -33 lines)

notebooks/benchmark_loading_speed.ipynb

Lines changed: 23 additions & 14 deletions
@@ -25,6 +25,16 @@
     "import pytorch_lightning as pl"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0c3c31a-86ed-493b-ab00-625fa4edb302",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FILENAME = 'gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_just_hrv.zarr'"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -45,7 +55,8 @@
    "outputs": [],
    "source": [
     "sat_data_source = data_sources.SatelliteDataSource(\n",
-    "    #filename='gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep_quarter_geospatial.zarr',\n",
+    "    filename=FILENAME,\n",
+    "    consolidated=False,\n",
     "    image_size_pixels=128,\n",
     "    history_len=HISTORY_LEN,\n",
     "    forecast_len=FORECAST_LEN\n",
@@ -76,16 +87,6 @@
     "len(t0_datetimes)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "da2a60b4-f28f-4db2-99d4-879a321b905c",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "t0_datetimes[:5_000]"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -97,7 +98,7 @@
     "    batch_size=32,\n",
     "    n_samples_per_timestep=4,\n",
     "    data_sources=[sat_data_source],\n",
-    "    t0_datetimes=t0_datetimes[:5_000])"
+    "    t0_datetimes=t0_datetimes)"
    ]
   },
   {
@@ -178,7 +179,7 @@
     "\n",
     "    \n",
     "    def forward(self, x):\n",
-    "        images = x['sat_data'][:, self.history_len:, :, :, 0]\n",
+    "        images = x['sat_data'][:, self.history_len:, :, :]  # , 0]\n",
     "        images = normalise_images_in_model(images, self.device)\n",
     "        \n",
     "        # Pass data through the network :)\n",
@@ -251,7 +252,15 @@
    "execution_count": null,
    "id": "3eb05006-b2df-426c-a775-e41402abf7b0",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0: : 609it [00:39, 15.23it/s, loss=0.259, v_num=47]"
+     ]
+    }
+   ],
    "source": [
     "trainer.fit(model, train_dataloader=dataloader)"
    ]
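
One non-mechanical change above is dropping the trailing `, 0]` channel index in `forward()`. A minimal sketch of the shape assumption behind it (tensor sizes and the HISTORY_LEN value are illustrative, not taken from the notebook): because the rechunk script below selects the single HRV channel with a scalar `.sel(variable='HRV')`, the batched satellite data presumably arrives as (batch, time, y, x) with no channel axis, so no channel index is needed.

import torch

HISTORY_LEN = 2                                      # placeholder value, for the shape demo only
batch = {'sat_data': torch.randn(32, 6, 128, 128)}   # (batch, time, y, x): HRV only, no channel axis
images = batch['sat_data'][:, HISTORY_LEN:, :, :]    # keep only the timesteps after the history window
print(images.shape)                                  # torch.Size([32, 4, 128, 128])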

nowcasting_dataset/data_sources/satellite_data_source.py

Lines changed: 9 additions & 7 deletions
@@ -22,12 +22,13 @@ class SatelliteDataSource(DataSource):
       y is top-to-bottom.
       Access using public sat_data property.
     filename: Filename of the satellite data Zarr.
-    channels: List of satellite channels to load.
-    image_size: Instance of Square, which defines the size of each sample.
-      (Inherited from DataSource super-class).
+    consolidated: Whether or not the Zarr store is consolidated.
+    channels: List of satellite channels to load.  If None then don't filter by channels.
+    image_size_pixels: Size of the width and height of the image crop returned by get_sample().
     """
     filename: Union[str, Path] = consts.SAT_FILENAME
-    channels: Iterable[str] = ('HRV', )
+    consolidated: bool = True
+    channels: Optional[Iterable[str]] = None
     image_size_pixels: InitVar[int] = 128
     meters_per_pixel: InitVar[int] = 2_000

@@ -50,8 +51,9 @@ def open(self) -> None:
         # If we did that, then we couldn't copy SatelliteDataSource
         # instances into separate processes.  Instead,
         # call open() _after_ creating separate processes.
-        sat_data = self._open_sat_data()
-        self._sat_data = sat_data.sel(variable=list(self.channels))
+        self._sat_data = self._open_sat_data()
+        if self.channels is not None:
+            self._sat_data = self._sat_data.sel(variable=list(self.channels))

     def get_sample(
         self,
@@ -98,7 +100,7 @@ def geospatial_border(self) -> List[Tuple[Number, Number]]:
             [GEO_BORDER, -GEO_BORDER])]

     def _open_sat_data(self):
-        return open_sat_data(filename=self.filename)
+        return open_sat_data(filename=self.filename, consolidated=self.consolidated)


 def open_sat_data(
def open_sat_data(

nowcasting_dataset/utils.py

Lines changed: 3 additions & 6 deletions
@@ -1,19 +1,16 @@
 import numpy as np
 import pandas as pd
 from nowcasting_dataset.consts import Array
+import fsspec.asyn


 def set_fsspec_for_multiprocess() -> None:
     """Clear reference to the loop and thread.  This is necessary otherwise
     gcsfs hangs in the ML training loop.  Only required for fsspec >= 0.9.0
     See https://github.com/dask/gcsfs/issues/379#issuecomment-839929801
     TODO: Try deleting this two lines to make sure this is still relevant."""
-    import fsspec
-    try:
-        fsspec.asyn.iothread[0] = None
-        fsspec.asyn.loop[0] = None
-    except AttributeError:
-        pass
+    fsspec.asyn.iothread[0] = None
+    fsspec.asyn.loop[0] = None


 def is_monotonically_increasing(a: Array) -> bool:
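
This diff does not show where `set_fsspec_for_multiprocess()` is called. The sketch below is only one common pattern (the `worker_init_fn` hook is an assumption, not something this commit adds): clear fsspec's cached event loop and IO thread in each DataLoader worker so gcsfs rebuilds them per process instead of hanging on the one inherited from the parent.

from torch.utils.data import DataLoader
from nowcasting_dataset.utils import set_fsspec_for_multiprocess

def worker_init_fn(worker_id: int) -> None:
    # Each forked worker drops the fsspec loop/thread references it inherited.
    set_fsspec_for_multiprocess()

# Hypothetical wiring; `dataset` is whatever dataset feeds training:
# dataloader = DataLoader(dataset, batch_size=None, num_workers=4,
#                         worker_init_fn=worker_init_fn)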

scripts/rechunk_sat_data.py

Lines changed: 14 additions & 6 deletions
@@ -5,34 +5,38 @@
 import numcodecs
 import gcsfs
 import rechunker
-from dask.diagnostics import ProgressBar
+import zarr


 BUCKET = Path('solar-pv-nowcasting-data')
 SAT_PATH = BUCKET / 'satellite/EUMETSAT/SEVIRI_RSS/OSGB36/'
 SOURCE_SAT_FILENAME = 'gs://' + str(SAT_PATH / 'all_zarr_int16')
-TARGET_SAT_FILENAME = SAT_PATH / 'all_zarr_int16_single_timestep_quarter_geospatial.zarr'
+TARGET_SAT_FILENAME = SAT_PATH / 'all_zarr_int16_single_timestep_just_hrv.zarr'
 TEMP_STORE_FILENAME = SAT_PATH / 'temp.zarr'


 def main():
     source_sat_dataset = xr.open_zarr(SOURCE_SAT_FILENAME, consolidated=True)
-
+    #source_sat_dataset = source_sat_dataset.isel(time=slice(0, 3600))
+    source_sat_dataset = source_sat_dataset.sel(variable='HRV')
+
     gcs = gcsfs.GCSFileSystem()
     target_store = gcs.get_mapper(TARGET_SAT_FILENAME)
     temp_store = gcs.get_mapper(TEMP_STORE_FILENAME)

     target_chunks = {
         'stacked_eumetsat_data': {
             "time": 1,
-            "y": 704 // 2,
-            "x": 548 // 2,
-            "variable": 1}}
+            "y": 704,
+            "x": 548,
+            #"variable": 1
+        }}

     encoding = {
         'stacked_eumetsat_data': {
             'compressor': numcodecs.Blosc(cname="zstd", clevel=5)}}

+    print('Rechunking...')
     rechunk_plan = rechunker.rechunk(
         source=source_sat_dataset,
         target_chunks=target_chunks,
@@ -42,7 +46,11 @@ def main():
         temp_store=temp_store)

     rechunk_plan.execute()
+
+    print('Consolidating...')
+    zarr.convenience.consolidate_metadata(target_store)

+    print('Done!')

 if __name__ == '__main__':
     main()
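
The new `consolidate_metadata()` step exists so that readers can open the target store with `consolidated=True`, fetching a single consolidated metadata object instead of listing every array's metadata on GCS. A sketch of that read side (not part of this commit); the chunking should reflect the `target_chunks` above, one timestep per chunk covering the full 704 x 548 image.

import gcsfs
import xarray as xr

TARGET = ('gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/'
          'all_zarr_int16_single_timestep_just_hrv.zarr')

gcs = gcsfs.GCSFileSystem()
dataset = xr.open_zarr(gcs.get_mapper(TARGET), consolidated=True)
print(dataset['stacked_eumetsat_data'])   # inspect dims and chunking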
