Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Align DataSource constructor args with config YAML field names #285

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sat_filename(use_cloud_data: bool) -> Path:
def sat_data_source(sat_filename: Path):
return SatelliteDataSource(
image_size_pixels=pytest.IMAGE_SIZE_PIXELS,
filename=sat_filename,
zarr_path=sat_filename,
history_minutes=0,
forecast_minutes=5,
channels=("HRV",),
Expand All @@ -60,7 +60,6 @@ def sat_data_source(sat_filename: Path):

@pytest.fixture
def general_data_source():

return MetadataDataSource(history_minutes=0, forecast_minutes=5, object_at_center="GSP")


Expand All @@ -69,7 +68,7 @@ def gsp_data_source():
return GSPDataSource(
image_size_pixels=16,
meters_per_pixel=2000,
filename=Path(__file__).parent.absolute() / "tests" / "data" / "gsp" / "test.zarr",
zarr_path=Path(__file__).parent.absolute() / "tests" / "data" / "gsp" / "test.zarr",
history_minutes=0,
forecast_minutes=30,
)
Expand Down
6 changes: 3 additions & 3 deletions nowcasting_dataset/config/gcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ input_data:
pv:
forecast_minutes: 60
history_minutes: 30
solar_pv_data_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc
solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
pv_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc
pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
satellite:
forecast_minutes: 60
history_minutes: 30
sat_channels:
satellite_channels:
- HRV
- IR_016
- IR_039
Expand Down
45 changes: 33 additions & 12 deletions nowcasting_dataset/config/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
""" Configuration model for the dataset
""" Configuration model for the dataset.

All paths must include the protocol prefix. For local files,
it's sufficient to just start with a '/'. For aws, start with 's3://',
for gcp start with 'gs://'.

This file is mostly about _configuring_ the DataSources.

Separate Pydantic models in
`nowcasting_dataset/data_sources/<data_source_name>/<data_source_name>_model.py`
are used to validate the values of the data itself.

"""
from datetime import datetime
from typing import Optional
Expand All @@ -20,6 +27,10 @@
)


IMAGE_SIZE_PIXELS_FIELD = Field(64, description="The number of pixels of the region of interest.")
METERS_PER_PIXEL_FIELD = Field(2000, description="The number of meters per pixel.")


class General(BaseModel):
"""General pydantic model"""

Expand Down Expand Up @@ -71,11 +82,11 @@ def seq_length_5_minutes(self):
class PV(DataSourceMixin):
"""PV configuration model"""

solar_pv_data_filename: str = Field(
pv_filename: str = Field(
"gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc",
description=("The NetCDF file holding the solar PV power timeseries."),
)
solar_pv_metadata_filename: str = Field(
pv_metadata_filename: str = Field(
"gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv",
description="The CSV file describing each PV system.",
)
Expand All @@ -84,6 +95,8 @@ class PV(DataSourceMixin):
description="The number of PV systems samples per example. "
"If there are less in the ROI then the data is padded with zeros. ",
)
pv_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD
pv_meters_per_pixel: int = METERS_PER_PIXEL_FIELD


class Satellite(DataSourceMixin):
Expand All @@ -93,12 +106,11 @@ class Satellite(DataSourceMixin):
"gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr",
description="The path which holds the satellite zarr.",
)

sat_channels: tuple = Field(
satellite_channels: tuple = Field(
SAT_VARIABLE_NAMES, description="the satellite channels that are used"
)

satellite_image_size_pixels: int = Field(64, description="the size of the satellite images")
satellite_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD
satellite_meters_per_pixel: int = METERS_PER_PIXEL_FIELD


class NWP(DataSourceMixin):
Expand All @@ -108,10 +120,9 @@ class NWP(DataSourceMixin):
"gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr",
description="The path which holds the NWP zarr.",
)

nwp_channels: tuple = Field(NWP_VARIABLE_NAMES, description="the channels used in the nwp data")

nwp_image_size_pixels: int = Field(64, description="the size of the nwp images")
nwp_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD
nwp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD


class GSP(DataSourceMixin):
Expand All @@ -123,6 +134,8 @@ class GSP(DataSourceMixin):
description="The number of GSP samples per example. "
"If there are less in the ROI then the data is padded with zeros. ",
)
gsp_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD
gsp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD

@validator("history_minutes")
def history_minutes_divide_by_30(cls, v):
Expand All @@ -144,6 +157,8 @@ class Topographic(DataSourceMixin):
"gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif",
description="Path to the GeoTIFF Topographic data source",
)
topographic_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD
topographic_meters_per_pixel: int = METERS_PER_PIXEL_FIELD


class Sun(DataSourceMixin):
Expand Down Expand Up @@ -179,6 +194,12 @@ class InputData(BaseModel):
description="how many historic minutes are used. "
"This sets the default for all the data sources if they are not set.",
)
data_source_which_defines_geospatial_locations: str = Field(
"gsp",
description=(
"The name of the DataSource which will define the geospatial position of each example."
),
)

@property
def default_seq_length_5_minutes(self):
Expand Down Expand Up @@ -267,8 +288,8 @@ def set_base_path(self, base_path: str):
"""Append base_path to all paths. Mostly used for testing."""
base_path = Pathy(base_path)
path_attrs = [
"pv.solar_pv_data_filename",
"pv.solar_pv_metadata_filename",
"pv.pv_filename",
"pv.pv_metadata_filename",
"satellite.satellite_zarr_path",
"nwp.nwp_zarr_path",
"gsp.gsp_zarr_path",
Expand Down
18 changes: 15 additions & 3 deletions nowcasting_dataset/config/on_premises.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ general:
input_data:
default_forecast_minutes: 120
default_history_minutes: 30
#---------------------- GSP -------------------
gsp:
gsp_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/GSP/v2/pv_gsp.zarr

#---------------------- NWP -------------------
nwp:
nwp_channels:
- t
Expand All @@ -20,11 +23,15 @@ input_data:
- hcc
nwp_image_size_pixels: 64
nwp_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/NWP/UK_Met_Office/UKV/zarr/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr

#---------------------- PV -------------------
pv:
solar_pv_data_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_timeseries_batch.nc
solar_pv_metadata_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_metadata.csv
pv_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_timeseries_batch.nc
pv_metadata_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_metadata.csv

#---------------------- Satellite -------------
satellite:
sat_channels:
satellite_channels:
- HRV
- IR_016
- IR_039
Expand All @@ -39,10 +46,15 @@ input_data:
- WV_073
satellite_image_size_pixels: 64
satellite_zarr_path: /mnt/storage_a/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/satellite/EUMETSAT/SEVIRI_RSS/zarr/all_zarr_int16_single_timestep.zarr

# ------------------------- Sun ------------------------
sun:
sun_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Sun/v0/sun.zarr

# ------------------------- Topographic ----------------
topographic:
topographic_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Topographic/europe_dem_1km_osgb.tif

output_data:
filepath: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v8/
process:
Expand Down
2 changes: 2 additions & 0 deletions nowcasting_dataset/data_sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource
from nowcasting_dataset.data_sources.pv.pv_data_source import PVDataSource
from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource
from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource
from nowcasting_dataset.data_sources.topographic.topographic_data_source import (
TopographicDataSource,
)
6 changes: 5 additions & 1 deletion nowcasting_dataset/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int):
super().__post_init__(image_size_pixels, meters_per_pixel)
self._data = None
if self.n_timesteps_per_batch is None:
raise ValueError("n_timesteps_per_batch must be set!")
# Using hacky default for now. The whole concept of n_timesteps_per_batch
# will be removed when #213 is completed.
# TODO: Remove n_timesteps_per_batch when #213 is completed!
self.n_timesteps_per_batch = 16
logger.warning("n_timesteps_per_batch is not set! Using default!")

@property
def data(self):
Expand Down
76 changes: 72 additions & 4 deletions nowcasting_dataset/data_sources/data_source_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,76 @@
import logging

import nowcasting_dataset.time as nd_time
from nowcasting_dataset.dataset.split.split import SplitMethod, split_data, SplitName
import nowcasting_dataset.utils as nd_utils
from nowcasting_dataset.config import model
from nowcasting_dataset import data_sources

logger = logging.getLogger(__name__)


class DataSourceList(list):
"""Hold a list of DataSource objects.

The first DataSource in the list is used to compute the geospatial locations of each example.
Attrs:
data_source_which_defines_geospatial_locations: The DataSource used to compute the
geospatial locations of each example.
"""

@classmethod
def from_config(cls, config_for_all_data_sources: model.InputData):
"""Create a DataSource List from an InputData configuration object.

For each key in each DataSource's configuration object, the string `<data_source_name>_`
is removed from the key before passing to the DataSource constructor. This allows us to
have verbose field names in the configuration YAML files, whilst also using standard
constructor arguments for DataSources.
"""
data_source_name_to_class = {
"pv": data_sources.PVDataSource,
"satellite": data_sources.SatelliteDataSource,
"nwp": data_sources.NWPDataSource,
"gsp": data_sources.GSPDataSource,
"topographic": data_sources.TopographicDataSource,
"sun": data_sources.SunDataSource,
}
data_source_list = cls([])
for data_source_name, data_source_class in data_source_name_to_class.items():
logger.debug(f"Creating {data_source_name} DataSource object.")
config_for_data_source = getattr(config_for_all_data_sources, data_source_name)
if config_for_data_source is None:
logger.info(f"No configuration found for {data_source_name}.")
continue
config_for_data_source = config_for_data_source.dict()

# Strip `<data_source_name>_` from the config option field names.
config_for_data_source = nd_utils.remove_regex_pattern_from_keys(
config_for_data_source, pattern_to_remove=f"^{data_source_name}_"
)

try:
data_source = data_source_class(**config_for_data_source)
except Exception:
logger.exception(f"Exception whilst instantiating {data_source_name}!")
raise
data_source_list.append(data_source)
if (
data_source_name
== config_for_all_data_sources.data_source_which_defines_geospatial_locations
):
data_source_list.data_source_which_defines_geospatial_locations = data_source
logger.info(
f"DataSource {data_source_name} set as"
" data_source_which_defines_geospatial_locations"
)

try:
_ = data_source_list.data_source_which_defines_geospatial_locations
except AttributeError:
logger.warning(
"No DataSource configured as data_source_which_defines_geospatial_locations!"
)
return data_source_list

def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex:
"""
Compute the intersection of the t0 datetimes available across all DataSources.
Expand Down Expand Up @@ -71,9 +130,18 @@ def sample_spatial_and_temporal_locations_for_examples(
Each row of each the DataFrame specifies the position of each example, using
columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
"""
data_source_which_defines_geo_position = self[0]
# This code is for backwards-compatibility with code which expects the first DataSource
# in the list to be used to define which DataSource defines the spatial location.
# TODO: Remove this try block after implementing issue #213.
try:
data_source_which_defines_geospatial_locations = (
self.data_source_which_defines_geospatial_locations
)
except AttributeError:
data_source_which_defines_geospatial_locations = self[0]

shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
x_locations, y_locations = data_source_which_defines_geo_position.get_locations(
x_locations, y_locations = data_source_which_defines_geospatial_locations.get_locations(
shuffled_t0_datetimes
)
return pd.DataFrame(
Expand Down
4 changes: 2 additions & 2 deletions nowcasting_dataset/data_sources/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def satellite_fake(
batch_size=32,
seq_length_5=19,
satellite_image_size_pixels=64,
number_sat_channels=7,
number_satellite_channels=7,
) -> Satellite:
""" Create fake data """
# make batch of arrays
xr_arrays = [
create_image_array(
seq_length_5=seq_length_5,
image_size_pixels=satellite_image_size_pixels,
number_channels=number_sat_channels,
number_channels=number_satellite_channels,
)
for _ in range(batch_size)
]
Expand Down
Loading