This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 72e39c8

Finally, a full complete draft of #213. Not yet tested
1 parent 04d4fbb commit 72e39c8

3 files changed: +86 -32 lines changed

nowcasting_dataset/data_sources/data_source.py

Lines changed: 76 additions & 11 deletions
@@ -3,12 +3,17 @@
 import logging
 from dataclasses import InitVar, dataclass
 from numbers import Number
+from pathlib import Path
 from typing import Iterable, List, Tuple
 
 import pandas as pd
 import xarray as xr
 
+import nowcasting_dataset.filesystem.utils as nd_fs_utils
+
+# nowcasting_dataset imports
 import nowcasting_dataset.time as nd_time
+import nowcasting_dataset.utils as nd_utils
 from nowcasting_dataset import square
 from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
 from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset
@@ -99,8 +104,7 @@ def sample_period_minutes(self) -> int:
         """
         This is the default sample period in minutes.
 
-        This functions may be overwritten if
-        the sample period of the data source is not 5 minutes.
+        This functions may be overwritten if the sample period of the data source is not 5 minutes.
         """
         logging.debug(
             "Getting sample_period_minutes default of 5 minutes. "
@@ -112,13 +116,79 @@ def open(self):
         """Open the data source, if necessary.
 
         Called from each worker process. Useful for data sources where the
-        underlying data source cannot be forked (like Zarr on GCP!).
+        underlying data source cannot be forked (like Zarr).
 
-        Data sources which can be forked safely should call open()
-        from __init__().
+        Data sources which can be forked safely should call open() from __init__().
         """
         pass
 
+    def create_batches(
+        self,
+        spatial_and_temporal_locations_of_each_example: pd.DataFrame,
+        idx_of_first_batch: int,
+        batch_size: int,
+        dst_path: Path,
+        temp_path: Path,
+        upload_every_n_batches: int,
+    ) -> None:
+        """Create multiple batches and save them to disk.
+
+        Args:
+            spatial_and_temporal_locations_of_each_example: A DataFrame where each row specifies
+                the spatial and temporal location of an example. The number of rows must be
+                an exact multiple of `batch_size`.
+                Columns are: t0_datetime_UTC, x_center_OSGB, y_center_OSGB.
+            idx_of_first_batch: The batch number of the first batch to create.
+            batch_size: The number of examples per batch.
+            dst_path: The final destination path for the batches. Must exist.
+            temp_path: The local temporary path. This is only required when dst_path is a
+                cloud storage bucket, so files must first be created on the VM's local disk in temp_path
+                and then uploaded to dst_path every upload_every_n_batches. Must exist. Will be emptied.
+            upload_every_n_batches: Upload the contents of temp_path to dst_path after this number
+                of batches have been created. If 0 then will write directly to dst_path.
+        """
+        # Sanity checks:
+        assert idx_of_first_batch >= 0
+        assert batch_size > 0
+        assert len(spatial_and_temporal_locations_of_each_example) % batch_size == 0
+        assert upload_every_n_batches >= 0
+
+        # Figure out where to write batches to:
+        save_batches_locally_and_upload = upload_every_n_batches > 0
+        if save_batches_locally_and_upload:
+            nd_fs_utils.delete_all_files_in_temp_path(temp_path)
+        path_to_write_to = temp_path if save_batches_locally_and_upload else dst_path
+
+        # Loop round each batch:
+        examples_for_batch = spatial_and_temporal_locations_of_each_example.iloc[:batch_size]
+        n_batches_processed = 0
+        while not examples_for_batch.empty:
+            # Generate batch.
+            batch = self.get_batch(
+                t0_datetimes=examples_for_batch.t0_datetime_UTC,
+                x_locations=examples_for_batch.x_center_OSGB,
+                y_locations=examples_for_batch.y_center_OSGB,
+            )
+
+            # Save batch to disk.
+            batch_idx = idx_of_first_batch + n_batches_processed
+            netcdf_filename = path_to_write_to / nd_utils.get_netcdf_filename(batch_idx)
+            batch.to_netcdf(netcdf_filename)
+
+            # Upload if necessary.
+            if (
+                save_batches_locally_and_upload
+                and n_batches_processed > 0
+                and n_batches_processed % upload_every_n_batches == 0
+            ):
+                nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
+            n_batches_processed += 1
+
+        # Upload last few batches, if necessary:
+        if save_batches_locally_and_upload:
+            nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
     def get_batch(
         self,
         t0_datetimes: pd.DatetimeIndex,
@@ -141,14 +211,9 @@ def get_batch(
         zipped = zip(t0_datetimes, x_locations, y_locations)
         for t0_datetime, x_location, y_location in zipped:
             output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location)
-
             examples.append(output)
 
-        # could add option here, to save each data source using
-        # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
-        # 2. DataSourceOutput.save_netcdf(), save to netcdf
-
-        # get the name of the cls, this could be one of the data sources like Sun
+        # Get the DataSource class, this could be one of the data sources like Sun
         cls = examples[0].__class__
 
         # join the examples together, and cast them to the cls, so that validation can occur
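
For context, here is a minimal sketch of how the new create_batches method might be driven, assuming an already-constructed concrete DataSource subclass; the column names and keyword arguments come from the docstring in the diff above, while the instance, paths and values are purely illustrative:

    from pathlib import Path

    import pandas as pd

    # Two example locations = one batch when batch_size=2. Column names are those
    # required by create_batches: t0_datetime_UTC, x_center_OSGB, y_center_OSGB.
    locations = pd.DataFrame(
        {
            "t0_datetime_UTC": pd.to_datetime(["2021-01-01 12:00", "2021-01-01 12:30"]),
            "x_center_OSGB": [123_456.0, 234_567.0],
            "y_center_OSGB": [654_321.0, 765_432.0],
        }
    )

    data_source = ...  # an instance of a concrete DataSource subclass (construction omitted)
    data_source.create_batches(
        spatial_and_temporal_locations_of_each_example=locations,
        idx_of_first_batch=0,
        batch_size=2,  # len(locations) must be an exact multiple of this
        dst_path=Path("prepared_batches/satellite"),  # must already exist
        temp_path=Path("temp/satellite"),  # must already exist; will be emptied
        upload_every_n_batches=0,  # 0 = write straight to dst_path, no staging/upload step
    )

With upload_every_n_batches=0 the batches are written directly to dst_path; a non-zero value would stage them in temp_path and periodically call nd_fs_utils.upload_and_delete_local_files, as in the loop added above.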

nowcasting_dataset/manager.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def create_batches(self, overwrite_batches: bool) -> None:
             for worker_id, (data_source_name, data_source) in enumerate(
                 self.data_sources.items()
             ):
-                # Get indexes of first batch and example; and subset locations_for_split.
+                # Get indexes of first batch and example. And subset locations_for_split.
                 idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
                 idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
                 locations = locations_for_split.loc[idx_of_first_example:]

nowcasting_dataset/utils.py

Lines changed: 9 additions & 20 deletions
@@ -1,11 +1,9 @@
 """ utils functions """
-import hashlib
 import logging
 import os
 import re
 import tempfile
 from functools import wraps
-from pathlib import Path
 from typing import Optional
 
 import fsspec.asyn
@@ -35,6 +33,7 @@ def set_fsspec_for_multiprocess() -> None:
     fsspec.asyn.loop[0] = None
 
 
+# TODO: Issue #170. Is this this function still used?
 def is_monotonically_increasing(a: Array) -> bool:
     """ Check the array is monotonically increasing """
     # TODO: Can probably replace with pd.Index.is_monotonic_increasing()
@@ -46,12 +45,14 @@ def is_monotonically_increasing(a: Array) -> bool:
     return np.all(np.diff(a) > 0)
 
 
+# TODO: Issue #170. Is this this function still used?
 def is_unique(a: Array) -> bool:
     """ Check array has unique values """
     # TODO: Can probably replace with pd.Index.is_unique()
     return len(a) == len(np.unique(a))
 
 
+# TODO: Issue #170. Is this this function still used?
 def scale_to_0_to_1(a: Array) -> Array:
     """Scale to the range [0, 1]."""
     a = a - a.min()
@@ -61,6 +62,7 @@ def scale_to_0_to_1(a: Array) -> Array:
     return a
 
 
+# TODO: Issue #170. Is this this function still used?
 def sin_and_cos(df: pd.DataFrame) -> pd.DataFrame:
     """
     For every column in df, creates cols for sin and cos of that col.
@@ -94,26 +96,13 @@ def sin_and_cos(df: pd.DataFrame) -> pd.DataFrame:
     return output_df
 
 
-def get_netcdf_filename(batch_idx: int, add_hash: bool = False) -> Path:
-    """Generate full filename, excluding path.
-
-    Filename includes the first 6 digits of the MD5 hash of the filename,
-    as recommended by Google Cloud in order to distribute data across
-    multiple back-end servers.
-
-    Add option to turn on and off hashing
-
-    """
-    filename = f"{batch_idx}.nc"
-    # In the future we could hash the configuration file, and use this to
-    # make sure we are saving and loading the same thing.
-    if add_hash:
-        hash_of_filename = hashlib.md5(filename.encode()).hexdigest()
-        filename = f"{hash_of_filename[0:6]}_{filename}"
-
-    return filename
+def get_netcdf_filename(batch_idx: int) -> str:
+    """Generate full filename, excluding path."""
+    assert 0 <= batch_idx < 1e6
+    return f"{batch_idx:06d}.nc"
 
 
+# TODO: Issue #170. Is this this function still used?
 def to_numpy(value):
     """ Change generic data to numpy"""
     if isinstance(value, xr.DataArray):
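
For reference, the rewritten get_netcdf_filename zero-pads the batch index instead of hashing the filename; the calls below are illustrative and not part of the commit:

    from nowcasting_dataset.utils import get_netcdf_filename

    get_netcdf_filename(1)          # "000001.nc"
    get_netcdf_filename(54321)      # "054321.nc"
    get_netcdf_filename(1_000_000)  # fails the `assert 0 <= batch_idx < 1e6` sanity check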
