@@ -3,12 +3,17 @@
 import logging
 from dataclasses import InitVar, dataclass
 from numbers import Number
+from pathlib import Path
 from typing import Iterable, List, Tuple
 
 import pandas as pd
 import xarray as xr
 
+import nowcasting_dataset.filesystem.utils as nd_fs_utils
+
+# nowcasting_dataset imports
import nowcasting_dataset.time as nd_time
+import nowcasting_dataset.utils as nd_utils
 from nowcasting_dataset import square
 from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
 from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset
@@ -99,8 +104,7 @@ def sample_period_minutes(self) -> int:
         """
         This is the default sample period in minutes.
 
-        This functions may be overwritten if
-        the sample period of the data source is not 5 minutes.
+        This function may be overridden if the sample period of the data source is not 5 minutes.
         """
         logging.debug(
             "Getting sample_period_minutes default of 5 minutes. "
@@ -112,13 +116,79 @@ def open(self):
         """Open the data source, if necessary.
 
         Called from each worker process. Useful for data sources where the
-        underlying data source cannot be forked (like Zarr on GCP!).
+        underlying data source cannot be forked (like Zarr).
 
-        Data sources which can be forked safely should call open()
-        from __init__().
+        Data sources which can be forked safely should call open() from __init__().
         """
         pass
 
+    def create_batches(
+        self,
+        spatial_and_temporal_locations_of_each_example: pd.DataFrame,
+        idx_of_first_batch: int,
+        batch_size: int,
+        dst_path: Path,
+        temp_path: Path,
+        upload_every_n_batches: int,
+    ) -> None:
+        """Create multiple batches and save them to disk.
+
+        Args:
+            spatial_and_temporal_locations_of_each_example: A DataFrame where each row specifies
+                the spatial and temporal location of an example. The number of rows must be
+                an exact multiple of `batch_size`.
+                Columns are: t0_datetime_UTC, x_center_OSGB, y_center_OSGB.
+            idx_of_first_batch: The batch number of the first batch to create.
+            batch_size: The number of examples per batch.
+            dst_path: The final destination path for the batches. Must exist.
+            temp_path: The local temporary path. Only used when dst_path is a cloud storage
+                bucket: batches are first written to temp_path on the VM's local disk, then
+                uploaded to dst_path every upload_every_n_batches batches. Must exist.
+                Will be emptied.
+            upload_every_n_batches: Upload the contents of temp_path to dst_path after this
+                number of batches have been created. If 0 then write directly to dst_path.
+        """
+        # Sanity checks:
+        assert idx_of_first_batch >= 0
+        assert batch_size > 0
+        assert len(spatial_and_temporal_locations_of_each_example) % batch_size == 0
+        assert upload_every_n_batches >= 0
+
+        # Figure out where to write batches to:
+        save_batches_locally_and_upload = upload_every_n_batches > 0
+        if save_batches_locally_and_upload:
+            nd_fs_utils.delete_all_files_in_temp_path(temp_path)
+        path_to_write_to = temp_path if save_batches_locally_and_upload else dst_path
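+        # Writing batches to local disk first and uploading in chunks avoids a
+        # slow round-trip to cloud storage for every single batch.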
+
+        # Loop round each batch:
+        examples_for_batch = spatial_and_temporal_locations_of_each_example.iloc[:batch_size]
+        n_batches_processed = 0
+        while not examples_for_batch.empty:
+            # Generate batch.
+            batch = self.get_batch(
+                t0_datetimes=examples_for_batch.t0_datetime_UTC,
+                x_locations=examples_for_batch.x_center_OSGB,
+                y_locations=examples_for_batch.y_center_OSGB,
+            )
+
+            # Save batch to disk.
+            batch_idx = idx_of_first_batch + n_batches_processed
+            netcdf_filename = path_to_write_to / nd_utils.get_netcdf_filename(batch_idx)
+            batch.to_netcdf(netcdf_filename)
+
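+            # Note that the first batch (n_batches_processed == 0) never triggers
+            # an upload on its own; it is uploaded with a later chunk or by the
+            # final flush after the loop.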
+            # Upload if necessary.
+            if (
+                save_batches_locally_and_upload
+                and n_batches_processed > 0
+                and n_batches_processed % upload_every_n_batches == 0
+            ):
+                nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
+            n_batches_processed += 1
+
+            # Select the examples for the next batch. When this slice is empty,
+            # all examples have been processed and the `while` loop ends.
+            start_example_idx = n_batches_processed * batch_size
+            end_example_idx = start_example_idx + batch_size
+            examples_for_batch = spatial_and_temporal_locations_of_each_example.iloc[
+                start_example_idx:end_example_idx
+            ]
+
+        # Upload last few batches, if necessary:
+        if save_batches_locally_and_upload:
+            nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
     def get_batch(
         self,
         t0_datetimes: pd.DatetimeIndex,
@@ -141,14 +211,9 @@ def get_batch(
         zipped = zip(t0_datetimes, x_locations, y_locations)
         for t0_datetime, x_location, y_location in zipped:
             output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location)
-
             examples.append(output)
 
-            # could add option here, to save each data source using
-            # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
-            # 2. DataSourceOutput.save_netcdf(), save to netcdf
-
-        # get the name of the cls, this could be one of the data sources like Sun
+        # Get the DataSource class, this could be one of the data sources like Sun
         cls = examples[0].__class__
 
         # join the examples together, and cast them to the cls, so that validation can occur
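For context, a minimal sketch of how `create_batches` might be called on a concrete DataSource subclass. The `data_source` instance, the example locations, and both paths are hypothetical illustrations, not part of this commit:

    from pathlib import Path

    import pandas as pd

    # Hypothetical locations DataFrame: two examples, with the required
    # columns t0_datetime_UTC, x_center_OSGB, y_center_OSGB.
    locations = pd.DataFrame(
        {
            "t0_datetime_UTC": pd.to_datetime(["2021-01-01 12:00", "2021-01-01 12:30"]),
            "x_center_OSGB": [123_456.0, 234_567.0],
            "y_center_OSGB": [654_321.0, 765_432.0],
        }
    )

    # `data_source` is assumed to be any concrete DataSource subclass instance.
    data_source.create_batches(
        spatial_and_temporal_locations_of_each_example=locations,
        idx_of_first_batch=0,
        batch_size=2,  # len(locations) must be an exact multiple of batch_size.
        dst_path=Path("/data/prepared/satellite"),  # hypothetical local path
        temp_path=Path("/tmp/nowcasting"),
        upload_every_n_batches=0,  # 0 => write directly to dst_path; use > 0
                                   # when dst_path is a cloud storage bucket.
    )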