This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 0fbe834

normalise NWP params. Closes #3
1 parent 39ca983 commit 0fbe834

File tree

6 files changed (+1088 additions, -374 deletions)

notebooks/benchmark_loading_speed.ipynb

Lines changed: 89 additions & 205 deletions
Large diffs are not rendered by default.

notebooks/testing_NWPDataSource.ipynb

Lines changed: 928 additions & 152 deletions
Large diffs are not rendered by default.

nowcasting_dataset/data_sources/nwp_data_source.py

Lines changed: 39 additions & 8 deletions
@@ -13,6 +13,31 @@
 _LOG = logging.getLogger('nowcasting_dataset')
 
 
+# nwp_ds.data - xr.DataArray(data=std.values, dims=('variable', ), coords=dict(variable=std['variable'].values))
+
+NWP_VARIABLE_NAMES = (
+    't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc')
+
+# Means computed with
+# nwp_ds = NWPDataSource(...)
+# nwp_ds.open()
+# mean = nwp_ds.data.isel(init_time=slice(0, 10)).mean(dim=['step', 'x', 'init_time', 'y']).compute()
+NWP_MEAN = xr.DataArray(
+    data=(
+        2.8041010e+02, 1.6854691e+01, 6.7529683e-05, 8.1832832e+01,
+        7.1233767e-03, 8.8566933e+00, 4.3474598e+04, 4.9820110e+01,
+        4.8095409e+01, 4.2833260e+01),
+    dims=('variable', ),
+    coords={'variable': NWP_VARIABLE_NAMES})
+
+NWP_STD = xr.DataArray(
+    data=(
+        2.5812180e+00, 4.1278820e+01, 2.7507244e-04, 9.0967312e+00,
+        1.4110464e-01, 4.3616886e+00, 2.3853148e+04, 3.8900299e+01,
+        4.2830105e+01, 4.2778091e+01),
+    dims=('variable', ),
+    coords={'variable': NWP_VARIABLE_NAMES})
+
 @dataclass
 class NWPDataSource(ZarrDataSource):
     """
@@ -38,9 +63,7 @@ class NWPDataSource(ZarrDataSource):
         mcc : Medium-level cloud cover in %.
         hcc : High-level cloud cover in %.
     """
-    channels: Optional[Iterable[str]] = (
-        't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc')
-    max_step: int = 3  #: Max forecast timesteps to load from NWPs.
+    channels: Optional[Iterable[str]] = NWP_VARIABLE_NAMES
     image_size_pixels: InitVar[int] = 2
     meters_per_pixel: InitVar[int] = 2_000
 
@@ -58,8 +81,8 @@ def open(self) -> None:
         # call open() _after_ creating separate processes.
         data = self._open_data()
         data = data[list(self.channels)].to_array()
-        #self._data = data.sel(
-        #    step=slice(pd.Timedelta(0), pd.Timedelta(hours=self.max_step + 1)))
+        data -= NWP_MEAN
+        data /= NWP_STD
         self._data = data
 
     def _open_data(self) -> xr.DataArray:
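The two added lines in open() lean on xarray's coordinate-aligned broadcasting: NWP_MEAN and NWP_STD are 1-D over 'variable', so the subtraction and division normalise each channel across all remaining dimensions. A self-contained sketch of that behaviour with made-up values and a reduced set of variables:

import numpy as np
import xarray as xr

variables = ['t', 'dswrf', 'prate']
data = xr.DataArray(
    np.array([[300.0, 280.0], [10.0, 20.0], [0.0, 1e-4]]),
    dims=('variable', 'step'),
    coords={'variable': variables})
mean = xr.DataArray([290.0, 15.0, 5e-5], dims=('variable',), coords={'variable': variables})
std = xr.DataArray([10.0, 5.0, 1e-4], dims=('variable',), coords={'variable': variables})

data -= mean          # aligned on the shared 'variable' coordinate
data /= std           # then broadcast over the remaining 'step' dimension
print(data.sel(variable='t').values)  # [ 1. -1.]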
@@ -100,8 +123,16 @@ def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray:
 
         # Get the most recent NWP initialisation time for each
         # target_time_hourly.
-        init_times = self.data.sel(
-            init_time=target_times_hourly, method='ffill').init_time.values
+        try:
+            init_times = self.data.sel(
+                init_time=target_times_hourly, method='ffill').init_time.values
+        except Exception as e:
+            is_increasing = utils.is_monotonically_increasing(self.data.init_time.astype(int))
+            is_unique = utils.is_unique(self.data.init_time)
+            _LOG.exception(
+                f'Exception! start_hourly={start_hourly}, t0_hourly={t0_hourly}, end_hourly={end_hourly}, '
+                f'target_times_hourly={target_times_hourly}, {e}, is_increasing={is_increasing}, is_unique={is_unique}')
+            raise
 
         # Find the NWP init time for just the 'future' portion of the example.
         init_time_future = init_times[target_times_hourly == t0_hourly]
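The new try/except is purely diagnostic: .sel(..., method='ffill') picks, for each hourly target time, the most recent init_time at or before it, and that lookup needs a monotonically increasing, unique init_time index — exactly what the logged is_increasing / is_unique flags probe. A small sketch of both the selection and the failure mode, using toy timestamps:

import pandas as pd
import xarray as xr

init_times = pd.date_range('2021-01-01 00:00', periods=4, freq='3h')
da = xr.DataArray(range(4), dims=('init_time',), coords={'init_time': init_times})

target_times = pd.DatetimeIndex(['2021-01-01 01:00', '2021-01-01 07:30'])
picked = da.sel(init_time=target_times, method='ffill').init_time.values
print(picked)  # the 00:00 and 06:00 runs, i.e. the most recent init before each target

shuffled = da.isel(init_time=[2, 0, 3, 1])  # deliberately non-monotonic index
try:
    shuffled.sel(init_time=target_times, method='ffill')
except Exception as exc:
    print(type(exc).__name__, exc)  # pandas refuses 'ffill' lookups on a non-monotonic index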
@@ -150,7 +181,7 @@ def datetime_index(self) -> pd.DatetimeIndex:
             nwp = self._open_data()
         else:
             nwp = self._data
-        target_times = nwp['init_time'] + nwp['step'][:self.max_step]
+        target_times = nwp['init_time'] + nwp['step'][:3]
         target_times = target_times.values.flatten()
         target_times = np.unique(target_times)
         target_times = np.sort(target_times)
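The final hunk hard-codes what used to be the max_step default: datetime_index() adds the first three step offsets to every init_time, then flattens, de-duplicates and sorts the result. A sketch of that outer sum with toy run times and hourly steps:

import numpy as np
import pandas as pd
import xarray as xr

init_time = xr.DataArray(
    pd.date_range('2021-01-01 00:00', periods=2, freq='3h'), dims=('init_time',))
step = xr.DataArray(pd.to_timedelta(np.arange(5), unit='h'), dims=('step',))

target_times = init_time + step[:3]      # broadcasts to an (init_time, step) grid of datetimes
target_times = target_times.values.flatten()
target_times = np.unique(target_times)   # drops any overlap between consecutive runs
target_times = np.sort(target_times)
print(pd.DatetimeIndex(target_times))    # 00:00 .. 05:00, hourly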

nowcasting_dataset/data_sources/pv_data_source.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def get_example(
                 " (but not at the identical location to) x_meters_center and"
                 " y_meters_center.")
 
-        selected_pv_power = self._get_timestep_with_cache(t0_dt)
+        selected_pv_power = self._get_cached_time_slice(t0_dt)
         pv_system_ids = selected_pv_power.columns.intersection(pv_system_ids)
        assert len(pv_system_ids) > 0
 
nowcasting_dataset/datamodule.py

Lines changed: 18 additions & 4 deletions
@@ -47,6 +47,8 @@ def __post_init__(self):
         # Plus 1 because neither history_len nor forecast_len include t0.
         self._total_seq_len = self.history_len + self.forecast_len + 1
         self.contiguous_dataset = None
+        if self.num_workers == 0:
+            self.prefetch_factor = 2  # Set to default when not using multiprocessing.
 
     def prepare_data(self) -> None:
         # Satellite data
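The __post_init__ guard above exists because torch.utils.data.DataLoader is picky about prefetch_factor: with num_workers=0 it rejects a non-default value (and newer PyTorch releases reject any explicit value at all), so the commit resets the attribute to the old default of 2. An alternative, version-agnostic sketch is to omit the argument entirely when there are no workers; the toy dataset and kwargs below are illustrative only:

import torch.utils.data

dataset = list(range(8))  # stand-in dataset, just so DataLoader has something to wrap
num_workers = 0

dataloader_kwargs = dict(batch_size=4, num_workers=num_workers)
if num_workers > 0:
    # prefetch_factor only applies to worker processes, so only pass it when they exist.
    dataloader_kwargs['prefetch_factor'] = 2

loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
print(next(iter(loader)))  # tensor([0, 1, 2, 3])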
@@ -141,14 +143,24 @@ def setup(self, stage='fit'):
         self.train_dataset = dataset.NowcastingDataset(
             t0_datetimes=self.train_t0_datetimes,
             data_sources=self.data_sources,
-            n_batches_per_epoch_per_worker=1024 // self.num_workers,
+            n_batches_per_epoch_per_worker=self._n_batches_per_epoch_per_worker(1024),
             **self._common_dataset_params())
         self.val_dataset = dataset.NowcastingDataset(
             t0_datetimes=self.val_t0_datetimes,
             data_sources=self.data_sources,
-            n_batches_per_epoch_per_worker=32 // self.num_workers,
+            n_batches_per_epoch_per_worker=self._n_batches_per_epoch_per_worker(32),
             **self._common_dataset_params())
-
+
+        if self.num_workers == 0:
+            self.train_dataset.per_worker_init(worker_id=0)
+            self.val_dataset.per_worker_init(worker_id=0)
+
+    def _n_batches_per_epoch_per_worker(self, n_batches_per_epoch: int) -> int:
+        if self.num_workers > 0:
+            return n_batches_per_epoch // self.num_workers
+        else:
+            return n_batches_per_epoch
+
     def _split_data(self):
         """Sets self.train_t0_datetimes and self.val_t0_datetimes."""
         self._check_has_prepared_data()
@@ -184,8 +196,10 @@ def contiguous_dataloader(self) -> torch.utils.data.DataLoader:
         self.contiguous_dataset = dataset.ContiguousNowcastingDataset(
             t0_datetimes=self.val_t0_datetimes,
             data_sources=data_sources,
-            n_batches_per_epoch_per_worker=32 // self.num_workers,
+            n_batches_per_epoch_per_worker=self._n_batches_per_epoch_per_worker(32),
             **self._common_dataset_params())
+        if self.num_workers == 0:
+            self.contiguous_dataset.per_worker_init(worker_id=0)
         return torch.utils.data.DataLoader(
             self.contiguous_dataset, **self._common_dataloader_params())
 
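Taken together, the datamodule changes guard the single-process case twice: batch counts are no longer divided by zero, and per_worker_init() is called manually because DataLoader only runs worker-init hooks inside spawned worker processes. A compact sketch of the division helper (the function is lifted from the diff but used standalone here):

def n_batches_per_epoch_per_worker(n_batches_per_epoch: int, num_workers: int) -> int:
    # Mirrors NowcastingDataModule._n_batches_per_epoch_per_worker: split the epoch
    # across workers, but treat num_workers == 0 as "one worker in the main process".
    if num_workers > 0:
        return n_batches_per_epoch // num_workers
    return n_batches_per_epoch

assert n_batches_per_epoch_per_worker(1024, 8) == 128
assert n_batches_per_epoch_per_worker(1024, 0) == 1024  # old code: ZeroDivisionError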

nowcasting_dataset/dataset.py

Lines changed: 13 additions & 4 deletions
@@ -7,6 +7,9 @@
 from dataclasses import dataclass
 import torch
 from concurrent import futures
+import logging
+
+_LOG = logging.getLogger('nowcasting_dataset')
 
 
 @dataclass
@@ -115,10 +118,16 @@ def _get_example(
 
         example = nowcasting_dataset.example.Example(t0_dt=t0_dt)
         for data_source in self.data_sources:
-            example_from_source = data_source.get_example(
-                t0_dt=t0_dt,
-                x_meters_center=x_meters_center,
-                y_meters_center=y_meters_center)
+            try:
+                example_from_source = data_source.get_example(
+                    t0_dt=t0_dt,
+                    x_meters_center=x_meters_center,
+                    y_meters_center=y_meters_center)
+            except Exception as e:
+                _LOG.exception(
+                    f'Exception! t0_dt={t0_dt}, x_meters_center={x_meters_center}, y_meters_center={y_meters_center}, {e}')
+                raise
+
             example.update(example_from_source)
         example = nowcasting_dataset.example.to_numpy(example)
         return example
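dataset.py adopts the same pattern as nwp_data_source.py: log the offending arguments together with the traceback, then re-raise so the failure still surfaces in the DataLoader worker. A self-contained sketch of that pattern, where the data-source call is a dummy that always fails:

import logging

logging.basicConfig(level=logging.INFO)
_LOG = logging.getLogger('nowcasting_dataset')

def get_example(t0_dt, x_meters_center, y_meters_center):
    raise ValueError('no data for this window')  # stand-in for a failing data source

def get_example_logged(t0_dt, x_meters_center, y_meters_center):
    try:
        return get_example(t0_dt, x_meters_center, y_meters_center)
    except Exception as e:
        # logger.exception() records the message at ERROR level plus the full
        # traceback; re-raising keeps the original error visible to the caller.
        _LOG.exception(
            f'Exception! t0_dt={t0_dt}, x_meters_center={x_meters_center}, '
            f'y_meters_center={y_meters_center}, {e}')
        raise

try:
    get_example_logged('2021-01-01 12:00', 0, 0)
except ValueError:
    pass  # still propagates after being logged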

0 commit comments
