diff --git a/conftest.py b/conftest.py index e503609c..b66a63b7 100644 --- a/conftest.py +++ b/conftest.py @@ -50,7 +50,7 @@ def sat_filename(use_cloud_data: bool) -> Path: def sat_data_source(sat_filename: Path): return SatelliteDataSource( image_size_pixels=pytest.IMAGE_SIZE_PIXELS, - filename=sat_filename, + zarr_path=sat_filename, history_minutes=0, forecast_minutes=5, channels=("HRV",), @@ -60,7 +60,6 @@ def sat_data_source(sat_filename: Path): @pytest.fixture def general_data_source(): - return MetadataDataSource(history_minutes=0, forecast_minutes=5, object_at_center="GSP") @@ -69,7 +68,7 @@ def gsp_data_source(): return GSPDataSource( image_size_pixels=16, meters_per_pixel=2000, - filename=Path(__file__).parent.absolute() / "tests" / "data" / "gsp" / "test.zarr", + zarr_path=Path(__file__).parent.absolute() / "tests" / "data" / "gsp" / "test.zarr", history_minutes=0, forecast_minutes=30, ) diff --git a/nowcasting_dataset/config/gcp.yaml b/nowcasting_dataset/config/gcp.yaml index 7e2d447f..dd3a486b 100644 --- a/nowcasting_dataset/config/gcp.yaml +++ b/nowcasting_dataset/config/gcp.yaml @@ -27,12 +27,12 @@ input_data: pv: forecast_minutes: 60 history_minutes: 30 - solar_pv_data_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc - solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv + pv_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc + pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv satellite: forecast_minutes: 60 history_minutes: 30 - sat_channels: + satellite_channels: - HRV - IR_016 - IR_039 diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 95d7ed81..3d22f4a6 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -1,8 +1,15 @@ -""" Configuration model for the dataset +""" Configuration model for the dataset. 
All paths must include the protocol prefix. For local files, it's sufficient to just start with a '/'. For aws, start with 's3://', for gcp start with 'gs://'. + +This file is mostly about _configuring_ the DataSources. + +Separate Pydantic models in +`nowcasting_dataset/data_sources/<data_source_name>/<data_source_name>_model.py` +are used to validate the values of the data itself. + """ from datetime import datetime from typing import Optional @@ -20,6 +27,10 @@ ) +IMAGE_SIZE_PIXELS_FIELD = Field(64, description="The number of pixels of the region of interest.") +METERS_PER_PIXEL_FIELD = Field(2000, description="The number of meters per pixel.") + + class General(BaseModel): """General pydantic model""" @@ -71,11 +82,11 @@ def seq_length_5_minutes(self): class PV(DataSourceMixin): """PV configuration model""" - solar_pv_data_filename: str = Field( + pv_filename: str = Field( "gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc", description=("The NetCDF file holding the solar PV power timeseries."), ) - solar_pv_metadata_filename: str = Field( + pv_metadata_filename: str = Field( "gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv", description="The CSV file describing each PV system.", ) @@ -84,6 +95,8 @@ class PV(DataSourceMixin): description="The number of PV systems samples per example. " "If there are less in the ROI then the data is padded with zeros. 
", ) + pv_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD + pv_meters_per_pixel: int = METERS_PER_PIXEL_FIELD class Satellite(DataSourceMixin): @@ -93,12 +106,11 @@ class Satellite(DataSourceMixin): "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr", description="The path which holds the satellite zarr.", ) - - sat_channels: tuple = Field( + satellite_channels: tuple = Field( SAT_VARIABLE_NAMES, description="the satellite channels that are used" ) - - satellite_image_size_pixels: int = Field(64, description="the size of the satellite images") + satellite_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD + satellite_meters_per_pixel: int = METERS_PER_PIXEL_FIELD class NWP(DataSourceMixin): @@ -108,10 +120,9 @@ class NWP(DataSourceMixin): "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr", description="The path which holds the NWP zarr.", ) - nwp_channels: tuple = Field(NWP_VARIABLE_NAMES, description="the channels used in the nwp data") - - nwp_image_size_pixels: int = Field(64, description="the size of the nwp images") + nwp_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD + nwp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD class GSP(DataSourceMixin): @@ -123,6 +134,8 @@ class GSP(DataSourceMixin): description="The number of GSP samples per example. " "If there are less in the ROI then the data is padded with zeros. 
", ) + gsp_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD + gsp_meters_per_pixel: int = METERS_PER_PIXEL_FIELD @validator("history_minutes") def history_minutes_divide_by_30(cls, v): @@ -144,6 +157,8 @@ class Topographic(DataSourceMixin): "gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif", description="Path to the GeoTIFF Topographic data source", ) + topographic_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD + topographic_meters_per_pixel: int = METERS_PER_PIXEL_FIELD class Sun(DataSourceMixin): @@ -179,6 +194,12 @@ class InputData(BaseModel): description="how many historic minutes are used. " "This sets the default for all the data sources if they are not set.", ) + data_source_which_defines_geospatial_locations: str = Field( + "gsp", + description=( + "The name of the DataSource which will define the geospatial position of each example." + ), + ) @property def default_seq_length_5_minutes(self): @@ -267,8 +288,8 @@ def set_base_path(self, base_path: str): """Append base_path to all paths. 
Mostly used for testing.""" base_path = Pathy(base_path) path_attrs = [ - "pv.solar_pv_data_filename", - "pv.solar_pv_metadata_filename", + "pv.pv_filename", + "pv.pv_metadata_filename", "satellite.satellite_zarr_path", "nwp.nwp_zarr_path", "gsp.gsp_zarr_path", diff --git a/nowcasting_dataset/config/on_premises.yaml b/nowcasting_dataset/config/on_premises.yaml index bae87205..226254e7 100644 --- a/nowcasting_dataset/config/on_premises.yaml +++ b/nowcasting_dataset/config/on_premises.yaml @@ -4,8 +4,11 @@ general: input_data: default_forecast_minutes: 120 default_history_minutes: 30 + #---------------------- GSP ------------------- gsp: gsp_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/GSP/v2/pv_gsp.zarr + + #---------------------- NWP ------------------- nwp: nwp_channels: - t @@ -20,11 +23,15 @@ input_data: - hcc nwp_image_size_pixels: 64 nwp_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/NWP/UK_Met_Office/UKV/zarr/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr + + #---------------------- PV ------------------- pv: - solar_pv_data_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_timeseries_batch.nc - solar_pv_metadata_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_metadata.csv + pv_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_timeseries_batch.nc + pv_metadata_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/PVOutput.org/UK_PV_metadata.csv + + #---------------------- Satellite ------------- satellite: - sat_channels: + satellite_channels: - HRV - IR_016 - IR_039 @@ -39,10 +46,15 @@ input_data: - WV_073 satellite_image_size_pixels: 64 satellite_zarr_path: 
/mnt/storage_a/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/satellite/EUMETSAT/SEVIRI_RSS/zarr/all_zarr_int16_single_timestep.zarr + + # ------------------------- Sun ------------------------ sun: sun_zarr_path: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Sun/v0/sun.zarr + + # ------------------------- Topographic ---------------- topographic: topographic_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Topographic/europe_dem_1km_osgb.tif + output_data: filepath: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v8/ process: diff --git a/nowcasting_dataset/data_sources/__init__.py b/nowcasting_dataset/data_sources/__init__.py index f4cf7556..9f82670f 100644 --- a/nowcasting_dataset/data_sources/__init__.py +++ b/nowcasting_dataset/data_sources/__init__.py @@ -4,6 +4,8 @@ from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource from nowcasting_dataset.data_sources.pv.pv_data_source import PVDataSource from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource +from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource +from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.data_sources.topographic.topographic_data_source import ( TopographicDataSource, ) diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index 0b2cac13..0288fae7 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -250,7 +250,11 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): super().__post_init__(image_size_pixels, meters_per_pixel) self._data = None if self.n_timesteps_per_batch is None: - raise ValueError("n_timesteps_per_batch must be set!") + # Using hacky default for now. 
The whole concept of n_timesteps_per_batch + # will be removed when #213 is completed. + # TODO: Remove n_timesteps_per_batch when #213 is completed! + self.n_timesteps_per_batch = 16 + logger.warning("n_timesteps_per_batch is not set! Using default!") @property def data(self): diff --git a/nowcasting_dataset/data_sources/data_source_list.py b/nowcasting_dataset/data_sources/data_source_list.py index f5e579ad..a5b9f55a 100644 --- a/nowcasting_dataset/data_sources/data_source_list.py +++ b/nowcasting_dataset/data_sources/data_source_list.py @@ -5,7 +5,9 @@ import logging import nowcasting_dataset.time as nd_time -from nowcasting_dataset.dataset.split.split import SplitMethod, split_data, SplitName +import nowcasting_dataset.utils as nd_utils +from nowcasting_dataset.config import model +from nowcasting_dataset import data_sources logger = logging.getLogger(__name__) @@ -13,9 +15,66 @@ class DataSourceList(list): """Hold a list of DataSource objects. - The first DataSource in the list is used to compute the geospatial locations of each example. + Attrs: + data_source_which_defines_geospatial_locations: The DataSource used to compute the + geospatial locations of each example. """ + @classmethod + def from_config(cls, config_for_all_data_sources: model.InputData): + """Create a DataSource List from an InputData configuration object. + + For each key in each DataSource's configuration object, the string `<data_source_name>_` + is removed from the key before passing to the DataSource constructor. This allows us to + have verbose field names in the configuration YAML files, whilst also using standard + constructor arguments for DataSources. 
+ """ + data_source_name_to_class = { + "pv": data_sources.PVDataSource, + "satellite": data_sources.SatelliteDataSource, + "nwp": data_sources.NWPDataSource, + "gsp": data_sources.GSPDataSource, + "topographic": data_sources.TopographicDataSource, + "sun": data_sources.SunDataSource, + } + data_source_list = cls([]) + for data_source_name, data_source_class in data_source_name_to_class.items(): + logger.debug(f"Creating {data_source_name} DataSource object.") + config_for_data_source = getattr(config_for_all_data_sources, data_source_name) + if config_for_data_source is None: + logger.info(f"No configuration found for {data_source_name}.") + continue + config_for_data_source = config_for_data_source.dict() + + # Strip `<data_source_name>_` from the config option field names. + config_for_data_source = nd_utils.remove_regex_pattern_from_keys( + config_for_data_source, pattern_to_remove=f"^{data_source_name}_" + ) + + try: + data_source = data_source_class(**config_for_data_source) + except Exception: + logger.exception(f"Exception whilst instantiating {data_source_name}!") + raise + data_source_list.append(data_source) + if ( + data_source_name + == config_for_all_data_sources.data_source_which_defines_geospatial_locations + ): + data_source_list.data_source_which_defines_geospatial_locations = data_source + logger.info( + f"DataSource {data_source_name} set as" + " data_source_which_defines_geospatial_locations" + ) + + try: + _ = data_source_list.data_source_which_defines_geospatial_locations + except AttributeError: + logger.warning( + "No DataSource configured as data_source_which_defines_geospatial_locations!" + ) + return data_source_list + def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex: """ Compute the intersection of the t0 datetimes available across all DataSources. 
@@ -71,9 +130,18 @@ def sample_spatial_and_temporal_locations_for_examples( Each row of each the DataFrame specifies the position of each example, using columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. """ - data_source_which_defines_geo_position = self[0] + # This code is for backwards-compatibility with code which expects the first DataSource + # in the list to be used to define which DataSource defines the spatial location. + # TODO: Remove this try block after implementing issue #213. + try: + data_source_which_defines_geospatial_locations = ( + self.data_source_which_defines_geospatial_locations + ) + except AttributeError: + data_source_which_defines_geospatial_locations = self[0] + shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples) - x_locations, y_locations = data_source_which_defines_geo_position.get_locations( + x_locations, y_locations = data_source_which_defines_geospatial_locations.get_locations( shuffled_t0_datetimes ) return pd.DataFrame( diff --git a/nowcasting_dataset/data_sources/fake.py b/nowcasting_dataset/data_sources/fake.py index b4795aff..4025e4a7 100644 --- a/nowcasting_dataset/data_sources/fake.py +++ b/nowcasting_dataset/data_sources/fake.py @@ -112,7 +112,7 @@ def satellite_fake( batch_size=32, seq_length_5=19, satellite_image_size_pixels=64, - number_sat_channels=7, + number_satellite_channels=7, ) -> Satellite: """ Create fake data """ # make batch of arrays @@ -120,7 +120,7 @@ def satellite_fake( create_image_array( seq_length_5=seq_length_5, image_size_pixels=satellite_image_size_pixels, - number_channels=number_sat_channels, + number_channels=number_satellite_channels, ) for _ in range(batch_size) ] diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index caf680e0..87fb00af 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -43,8 +43,8 @@ class 
GSPDataSource(ImageDataSource): The region of interest is defined by `image_size_pixels` and `meters_per_pixel`. """ - # filename of where the gsp data is stored - filename: Union[str, Path] + # zarr_path of where the gsp data is stored + zarr_path: Union[str, Path] # start datetime, this can be None start_dt: Optional[datetime] = None # end datetime, this can be None @@ -87,7 +87,7 @@ def load(self): # load gsp data from file / gcp self.gsp_power = load_solar_gsp_data( - self.filename, start_dt=self.start_dt, end_dt=self.end_dt + self.zarr_path, start_dt=self.start_dt, end_dt=self.end_dt ) # drop any gsp below 20 MW (or set threshold). This is to get rid of any small GSP where @@ -386,7 +386,7 @@ def drop_gsp_by_threshold(gsp_power: pd.DataFrame, meta_data: pd.DataFrame, thre def load_solar_gsp_data( - filename: Union[str, Path], + zarr_path: Union[str, Path], start_dt: Optional[datetime] = None, end_dt: Optional[datetime] = None, ) -> pd.DataFrame: @@ -394,17 +394,17 @@ def load_solar_gsp_data( Load solar PV GSP data Args: - filename: filename of file to be loaded, can put 'gs://' files in here too + zarr_path: zarr_path of file to be loaded, can put 'gs://' files in here too start_dt: the start datetime, which to trim the data to end_dt: the end datetime, which to trim the data to Returns: dataframe of pv data """ - logger.debug(f"Loading Solar GSP Data from GCS {filename} from {start_dt} to {end_dt}") + logger.debug(f"Loading Solar GSP Data from GCS {zarr_path} from {start_dt} to {end_dt}") # Open data - it may be quicker to open byte file first, but decided just to keep it # like this at the moment. 
- gsp_power = xr.open_dataset(filename, engine="zarr") + gsp_power = xr.open_dataset(zarr_path, engine="zarr") gsp_power = gsp_power.sel(datetime_gmt=slice(start_dt, end_dt)) # make normalized data diff --git a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py index 50f12361..eb596801 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -13,20 +13,16 @@ from nowcasting_dataset.data_sources.data_source import ZarrDataSource from nowcasting_dataset.data_sources.nwp.nwp_model import NWP from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset +from nowcasting_dataset.consts import NWP_VARIABLE_NAMES _LOG = logging.getLogger(__name__) -from nowcasting_dataset.consts import NWP_VARIABLE_NAMES - @dataclass class NWPDataSource(ZarrDataSource): """ NWP Data Source (Numerical Weather Predictions) - Args (for init): - filename: The base path in which we find '2018_1-6', etc. - Attributes: _data: xr.DataArray of Numerical Weather Predictions, opened by open(). x is left-to-right. @@ -47,7 +43,7 @@ class NWPDataSource(ZarrDataSource): hcc : High-level cloud cover in %. 
""" - filename: str = None + zarr_path: str = None channels: Optional[Iterable[str]] = NWP_VARIABLE_NAMES image_size_pixels: InitVar[int] = 2 meters_per_pixel: InitVar[int] = 2_000 @@ -146,7 +142,7 @@ def get_batch( return NWP(output) def _open_data(self) -> xr.DataArray: - return open_nwp(self.filename, consolidated=self.consolidated) + return open_nwp(self.zarr_path, consolidated=self.consolidated) def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: """ @@ -212,20 +208,22 @@ def datetime_index(self) -> pd.DatetimeIndex: return resampler.ffill(limit=11).dropna().index -def open_nwp(filename: str, consolidated: bool) -> xr.Dataset: +def open_nwp(zarr_path: str, consolidated: bool) -> xr.Dataset: """ Open The NWP data Args: - filename: filename must start with 'gs://' if it's on GCP. + zarr_path: zarr_path must start with 'gs://' if it's on GCP. consolidated: consolidate the zarr file? Returns: nwp data """ - _LOG.debug("Opening NWP data: %s", filename) + _LOG.debug("Opening NWP data: %s", zarr_path) utils.set_fsspec_for_multiprocess() - nwp = xr.open_dataset(filename, engine="zarr", consolidated=consolidated, mode="r", chunks=None) + nwp = xr.open_dataset( + zarr_path, engine="zarr", consolidated=consolidated, mode="r", chunks=None + ) # Sanity check. # TODO: Replace this with diff --git a/nowcasting_dataset/data_sources/pv/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py index 54894cba..6d49cfac 100644 --- a/nowcasting_dataset/data_sources/pv/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -29,7 +29,11 @@ @dataclass class PVDataSource(ImageDataSource): - """ PV Data Source """ + """PV Data Source. + + This inherits from ImageDataSource so PVDataSource can select a geospatial region of interest + defined by image_size_pixels and meters_per_pixel. 
+ """ filename: Union[str, Path] metadata_filename: Union[str, Path] @@ -313,6 +317,7 @@ def datetime_index(self) -> pd.DatetimeIndex: return self.pv_power.index +# TODO: Enable this function to load from any compute environment. See issue #286. def load_solar_pv_data_from_gcs( filename: Union[str, Path], start_dt: Optional[datetime.datetime] = None, diff --git a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py index 3bc36869..ecb69c4f 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -25,10 +25,10 @@ class SatelliteDataSource(ZarrDataSource): """ Satellite Data Source - filename: Must start with 'gs://' if on GCP. + zarr_path: Must start with 'gs://' if on GCP. """ - filename: str = None + zarr_path: str = None channels: Optional[Iterable[str]] = SAT_VARIABLE_NAMES image_size_pixels: InitVar[int] = 128 meters_per_pixel: InitVar[int] = 2_000 @@ -58,7 +58,7 @@ def open(self) -> None: self._data = self._data.sel(variable=list(self.channels)) def _open_data(self) -> xr.DataArray: - return open_sat_data(filename=self.filename, consolidated=self.consolidated) + return open_sat_data(zarr_path=self.zarr_path, consolidated=self.consolidated) def get_batch( self, @@ -156,17 +156,17 @@ def datetime_index(self, remove_night: bool = True) -> pd.DatetimeIndex: return datetime_index -def open_sat_data(filename: str, consolidated: bool) -> xr.DataArray: +def open_sat_data(zarr_path: str, consolidated: bool) -> xr.DataArray: """Lazily opens the Zarr store. Adds 1 minute to the 'time' coordinates, so the timestamps are at 00, 05, ..., 55 past the hour. Args: - filename: Cloud URL or local path. If GCP URL, must start with 'gs://' + zarr_path: Cloud URL or local path. If GCP URL, must start with 'gs://' consolidated: Whether or not the Zarr metadata is consolidated. 
""" - _LOG.debug("Opening satellite data: %s", filename) + _LOG.debug("Opening satellite data: %s", zarr_path) # We load using chunks=None so xarray *doesn't* use Dask to # load the Zarr chunks from disk. Using Dask to load the data @@ -174,7 +174,7 @@ def open_sat_data(filename: str, consolidated: bool) -> xr.DataArray: # about a million chunks. # See https://github.com/openclimatefix/nowcasting_dataset/issues/23 dataset = xr.open_dataset( - filename, engine="zarr", consolidated=consolidated, mode="r", chunks=None + zarr_path, engine="zarr", consolidated=consolidated, mode="r", chunks=None ) data_array = dataset["stacked_eumetsat_data"] diff --git a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py index 015b366d..ba01b561 100644 --- a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py +++ b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py @@ -96,14 +96,14 @@ def get_azimuth_and_elevation( return azimuth.round(2), elevation.round(2) -def save_to_zarr(azimuth: pd.DataFrame, elevation: pd.DataFrame, filename: Union[str, Path]): +def save_to_zarr(azimuth: pd.DataFrame, elevation: pd.DataFrame, zarr_path: Union[str, Path]): """ Save azimuth and elevation to zarr file Args: azimuth: data to be saved elevation: data to be saved - filename: the file name where it should be save, can be local of gcs + zarr_path: the file name where it should be save, can be local of gcs """ # change pandas dataframe to xr Dataset @@ -121,11 +121,11 @@ def save_to_zarr(azimuth: pd.DataFrame, elevation: pd.DataFrame, filename: Union } # save to file - merged_ds.to_zarr(filename, mode="w", encoding=encoding) + merged_ds.to_zarr(zarr_path, mode="w", encoding=encoding) def load_from_zarr( - filename: Union[str, Path], + zarr_path: Union[str, Path], start_dt: Optional[datetime.datetime] = None, end_dt: Optional[datetime.datetime] = None, ) -> (pd.DataFrame, pd.DataFrame): @@ -133,7 +133,7 @@ def load_from_zarr( 
Load sun data Args: - filename: the filename to be loaded, can be local or gcs + zarr_path: the zarr_path to be loaded, can be local or gcs start_dt: optional start datetime. Both start and end need to be set to be used. end_dt: optional end datetime. Both start and end need to be set to be used. @@ -148,7 +148,7 @@ def load_from_zarr( # in the first 'with' block, and delete the second 'with' block. # But that takes 1 minute to load the data, where as loading into memory # first and then loading from memory takes 23 seconds! - sun = xr.open_dataset(filename, engine="zarr") + sun = xr.open_dataset(zarr_path, engine="zarr") if (start_dt is not None) and (end_dt is not None): sun = sun.sel(datetime_gmt=slice(start_dt, end_dt)) diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index af05839b..c193fcf6 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -19,7 +19,7 @@ class SunDataSource(DataSource): """Add azimuth and elevation angles of the sun.""" - filename: Union[str, Path] + zarr_path: Union[str, Path] start_dt: Optional[datetime] = None end_dt: Optional[datetime] = None @@ -79,7 +79,7 @@ def get_example( def _load(self): self.azimuth, self.elevation = load_from_zarr( - filename=self.filename, start_dt=self.start_dt, end_dt=self.end_dt + zarr_path=self.zarr_path, start_dt=self.start_dt, end_dt=self.end_dt ) def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: diff --git a/nowcasting_dataset/dataset/batch.py b/nowcasting_dataset/dataset/batch.py index 78d1096c..7bb814f1 100644 --- a/nowcasting_dataset/dataset/batch.py +++ b/nowcasting_dataset/dataset/batch.py @@ -99,7 +99,9 @@ def fake(configuration: Configuration): batch_size=batch_size, seq_length_5=configuration.input_data.satellite.seq_length_5_minutes, satellite_image_size_pixels=satellite_image_size_pixels, - 
number_sat_channels=len(configuration.input_data.satellite.sat_channels), + number_satellite_channels=len( + configuration.input_data.satellite.satellite_channels + ), ), nwp=nwp_fake( batch_size=batch_size, diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index af8bcef3..de119910 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -112,7 +112,7 @@ def prepare_data(self) -> None: n_timesteps_per_batch = self.batch_size // self.n_samples_per_timestep self.sat_data_source = data_sources.SatelliteDataSource( - filename=self.sat_filename, + zarr_path=self.sat_filename, image_size_pixels=self.satellite_image_size_pixels, meters_per_pixel=self.meters_per_pixel, history_minutes=self.history_minutes, @@ -144,7 +144,7 @@ def prepare_data(self) -> None: if self.gsp_filename is not None: self.gsp_data_source = GSPDataSource( - filename=self.gsp_filename, + zarr_path=self.gsp_filename, start_dt=sat_datetimes[0], end_dt=sat_datetimes[-1], history_minutes=self.history_minutes, @@ -161,7 +161,7 @@ def prepare_data(self) -> None: # NWP data if self.nwp_base_path is not None: self.nwp_data_source = data_sources.NWPDataSource( - filename=self.nwp_base_path, + zarr_path=self.nwp_base_path, image_size_pixels=self.nwp_image_size_pixels, meters_per_pixel=self.meters_per_pixel, history_minutes=self.history_minutes, @@ -187,7 +187,7 @@ def prepare_data(self) -> None: # Sun data if self.sun_filename is not None: self.sun_data_source = SunDataSource( - filename=self.sun_filename, + zarr_path=self.sun_filename, history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes, ) diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py index ed874a11..3e2c59e4 100644 --- a/nowcasting_dataset/dataset/datasets.py +++ b/nowcasting_dataset/dataset/datasets.py @@ -23,6 +23,7 @@ NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making 
batches """ +# TODO: Can we get rid of SAT_MEAN and SAT_STD? See issue #231 SAT_MEAN = xr.DataArray( data=[ 93.23458, diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index 4b6666e1..34ea2512 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional +import re +import os import fsspec.asyn import gcsfs import numpy as np @@ -12,8 +14,9 @@ import torch import xarray as xr +import nowcasting_dataset from nowcasting_dataset.consts import Array - +from nowcasting_dataset.config import load, model logger = logging.getLogger(__name__) @@ -178,3 +181,28 @@ def __enter__(self): def __exit__(self, type, value, traceback): """ Close temporary file """ self.temp_file.close() + + +def remove_regex_pattern_from_keys(d: dict, pattern_to_remove: str, **regex_compile_kwargs) -> dict: + """Remove `pattern_to_remove` from all keys in `d`. + + Return a new dict with the same values as `d`, but where the key names + have had `pattern_to_remove` removed. 
+ """ + new_dict = {} + regex = re.compile(pattern_to_remove, **regex_compile_kwargs) + for old_key, value in d.items(): + new_key = regex.sub(string=old_key, repl="") + new_dict[new_key] = value + return new_dict + + +def get_config_with_test_paths(config_filename: str) -> model.Configuration: + """Sets the base paths to point to the testing data in this repository.""" + local_path = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "../") + + # load configuration, this can be changed to a different filename as needed + filename = os.path.join(local_path, "tests", "config", config_filename) + config = load.load_yaml_configuration(filename) + config.set_base_path(local_path) + return config diff --git a/scripts/prepare_ml_data.py b/scripts/prepare_ml_data.py index bdbdd7f6..86bfb1aa 100755 --- a/scripts/prepare_ml_data.py +++ b/scripts/prepare_ml_data.py @@ -53,8 +53,8 @@ config = set_git_commit(config) # Solar PV data -PV_DATA_FILENAME = config.input_data.pv.solar_pv_data_filename -PV_METADATA_FILENAME = config.input_data.pv.solar_pv_metadata_filename +PV_DATA_FILENAME = config.input_data.pv.pv_filename +PV_METADATA_FILENAME = config.input_data.pv.pv_metadata_filename # Satellite data SAT_ZARR_PATH = config.input_data.satellite.satellite_zarr_path diff --git a/tests/config/nwp_size_test.yaml b/tests/config/nwp_size_test.yaml index 50b9e4d9..176a08a5 100644 --- a/tests/config/nwp_size_test.yaml +++ b/tests/config/nwp_size_test.yaml @@ -11,10 +11,10 @@ input_data: nwp_image_size_pixels: 64 nwp_zarr_path: tests/data/nwp_data/test.zarr pv: - solar_pv_data_filename: tests/data/pv_data/test.nc - solar_pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv + pv_filename: tests/data/pv_data/test.nc + pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv satellite: - sat_channels: + satellite_channels: - HRV satellite_image_size_pixels: 64 satellite_zarr_path: tests/data/sat_data.zarr diff --git a/tests/config/test.yaml b/tests/config/test.yaml 
index 105cfc00..6afee7ac 100644 --- a/tests/config/test.yaml +++ b/tests/config/test.yaml @@ -11,10 +11,10 @@ input_data: nwp_image_size_pixels: 2 nwp_zarr_path: tests/data/nwp_data/test.zarr pv: - solar_pv_data_filename: tests/data/pv_data/test.nc - solar_pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv + pv_filename: tests/data/pv_data/test.nc + pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv satellite: - sat_channels: + satellite_channels: - HRV satellite_image_size_pixels: 64 satellite_zarr_path: tests/data/sat_data.zarr diff --git a/tests/data_sources/gsp/test_gsp_data_source.py b/tests/data_sources/gsp/test_gsp_data_source.py index ae2dfb67..4dd23878 100644 --- a/tests/data_sources/gsp/test_gsp_data_source.py +++ b/tests/data_sources/gsp/test_gsp_data_source.py @@ -10,7 +10,7 @@ def test_gsp_pv_data_source_init(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." gsp = GSPDataSource( - filename=f"{local_path}/tests/data/gsp/test.zarr", + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", start_dt=datetime(2019, 1, 1), end_dt=datetime(2019, 1, 2), history_minutes=30, @@ -24,7 +24,7 @@ def test_gsp_pv_data_source_get_locations(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." gsp = GSPDataSource( - filename=f"{local_path}/tests/data/gsp/test.zarr", + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", start_dt=datetime(2019, 1, 1), end_dt=datetime(2019, 1, 2), history_minutes=30, @@ -52,7 +52,7 @@ def test_gsp_pv_data_source_get_example(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." gsp = GSPDataSource( - filename=f"{local_path}/tests/data/gsp/test.zarr", + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", start_dt=datetime(2019, 1, 1), end_dt=datetime(2019, 1, 2), history_minutes=30, @@ -76,7 +76,7 @@ def test_gsp_pv_data_source_get_batch(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." 
gsp = GSPDataSource( - filename=f"{local_path}/tests/data/gsp/test.zarr", + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", start_dt=datetime(2019, 1, 1), end_dt=datetime(2019, 1, 2), history_minutes=30, diff --git a/tests/data_sources/sun/test_load.py b/tests/data_sources/sun/test_load.py index 894b6bad..4a845269 100644 --- a/tests/data_sources/sun/test_load.py +++ b/tests/data_sources/sun/test_load.py @@ -43,14 +43,14 @@ def test_save(): ) with tempfile.TemporaryDirectory() as fp: - save_to_zarr(azimuth=azimuth, elevation=elevation, filename=fp) + save_to_zarr(azimuth=azimuth, elevation=elevation, zarr_path=fp) def test_load(test_data_folder): - filename = test_data_folder + "/sun/test.zarr" + zarr_path = test_data_folder + "/sun/test.zarr" - azimuth, elevation = load_from_zarr(filename=filename) + azimuth, elevation = load_from_zarr(zarr_path=zarr_path) assert type(azimuth) == pd.DataFrame assert type(elevation) == pd.DataFrame diff --git a/tests/data_sources/sun/test_sun_data_source.py b/tests/data_sources/sun/test_sun_data_source.py index 1f5a417d..e84db918 100644 --- a/tests/data_sources/sun/test_sun_data_source.py +++ b/tests/data_sources/sun/test_sun_data_source.py @@ -3,15 +3,15 @@ def test_init(test_data_folder): - filename = test_data_folder + "/sun/test.zarr" + zarr_path = test_data_folder + "/sun/test.zarr" - _ = SunDataSource(filename=filename, history_minutes=30, forecast_minutes=60) + _ = SunDataSource(zarr_path=zarr_path, history_minutes=30, forecast_minutes=60) def test_get_example(test_data_folder): - filename = test_data_folder + "/sun/test.zarr" + zarr_path = test_data_folder + "/sun/test.zarr" - sun_data_source = SunDataSource(filename=filename, history_minutes=30, forecast_minutes=60) + sun_data_source = SunDataSource(zarr_path=zarr_path, history_minutes=30, forecast_minutes=60) x = 256895.63164759654 y = 666180.3018829626 @@ -24,9 +24,9 @@ def test_get_example(test_data_folder): def test_get_example_different_year(test_data_folder): - 
filename = test_data_folder + "/sun/test.zarr" + zarr_path = test_data_folder + "/sun/test.zarr" - sun_data_source = SunDataSource(filename=filename, history_minutes=30, forecast_minutes=60) + sun_data_source = SunDataSource(zarr_path=zarr_path, history_minutes=30, forecast_minutes=60) x = 256895.63164759654 y = 666180.3018829626 diff --git a/tests/data_sources/test_data_source_list.py b/tests/data_sources/test_data_source_list.py index 089a4b79..f3f8be58 100644 --- a/tests/data_sources/test_data_source_list.py +++ b/tests/data_sources/test_data_source_list.py @@ -3,13 +3,14 @@ import os from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.data_sources.data_source_list import DataSourceList +import nowcasting_dataset.utils as nd_utils def test_sample_spatial_and_temporal_locations_for_examples(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." gsp = GSPDataSource( - filename=f"{local_path}/tests/data/gsp/test.zarr", + zarr_path=f"{local_path}/tests/data/gsp/test.zarr", start_dt=datetime(2019, 1, 1), end_dt=datetime(2019, 1, 2), history_minutes=30, @@ -26,3 +27,12 @@ def test_sample_spatial_and_temporal_locations_for_examples(): assert locations.columns.to_list() == ["t0_datetime_UTC", "x_center_OSGB", "y_center_OSGB"] assert len(locations) == 10 + + +def test_from_config(): + config = nd_utils.get_config_with_test_paths("test.yaml") + data_source_list = DataSourceList.from_config(config.input_data) + assert len(data_source_list) == 6 + assert isinstance( + data_source_list.data_source_which_defines_geospatial_locations, GSPDataSource + ) diff --git a/tests/data_sources/test_datasource_output.py b/tests/data_sources/test_datasource_output.py index 66db06fb..1bc21b80 100644 --- a/tests/data_sources/test_datasource_output.py +++ b/tests/data_sources/test_datasource_output.py @@ -51,7 +51,7 @@ def test_pv(): def test_satellite(): s = satellite_fake( - batch_size=4, seq_length_5=13, 
satellite_image_size_pixels=64, number_sat_channels=7 + batch_size=4, seq_length_5=13, satellite_image_size_pixels=64, number_satellite_channels=7 ) assert s.x is not None diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index 631b6857..2a9fcc70 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -8,12 +8,12 @@ PATH = os.path.dirname(nowcasting_dataset.__file__) # Solar PV data (test data) -NWP_FILENAME = f"{PATH}/../tests/data/nwp_data/test.zarr" +NWP_ZARR_PATH = f"{PATH}/../tests/data/nwp_data/test.zarr" def test_nwp_data_source_init(): _ = NWPDataSource( - filename=NWP_FILENAME, + zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, n_timesteps_per_batch=8, @@ -22,7 +22,7 @@ def test_nwp_data_source_init(): def test_nwp_data_source_open(): nwp = NWPDataSource( - filename=NWP_FILENAME, + zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, n_timesteps_per_batch=8, @@ -34,7 +34,7 @@ def test_nwp_data_source_open(): def test_nwp_data_source_batch(): nwp = NWPDataSource( - filename=NWP_FILENAME, + zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, n_timesteps_per_batch=8, @@ -54,7 +54,7 @@ def test_nwp_data_source_batch(): def test_nwp_get_contiguous_time_periods(): nwp = NWPDataSource( - filename=NWP_FILENAME, + zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, n_timesteps_per_batch=8, @@ -70,7 +70,7 @@ def test_nwp_get_contiguous_time_periods(): def test_nwp_get_contiguous_t0_time_periods(): nwp = NWPDataSource( - filename=NWP_FILENAME, + zarr_path=NWP_ZARR_PATH, history_minutes=30, forecast_minutes=60, n_timesteps_per_batch=8, diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 73682051..f73ba096 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -6,13 +6,11 @@ import pandas as pd import pytest -import nowcasting_dataset -from nowcasting_dataset.config.load 
import load_yaml_configuration - from nowcasting_dataset.dataset import datamodule from nowcasting_dataset.dataset.datamodule import NowcastingDataModule from nowcasting_dataset.dataset.split.split import SplitMethod from nowcasting_dataset.dataset.batch import Batch +import nowcasting_dataset.utils as nd_utils logging.basicConfig(format="%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s") _LOG = logging.getLogger("nowcasting_dataset") @@ -56,22 +54,11 @@ def test_setup(nowcasting_datamodule: datamodule.NowcastingDataModule): nowcasting_datamodule.setup() -def _get_config_with_test_paths(config_filename: str): - """Sets the base paths to point to the testing data in this repository.""" - local_path = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "../") - - # load configuration, this can be changed to a different filename as needed - filename = os.path.join(local_path, "tests", "config", config_filename) - config = load_yaml_configuration(filename) - config.set_base_path(local_path) - return config - - @pytest.mark.parametrize("config_filename", ["test.yaml", "nwp_size_test.yaml"]) def test_data_module(config_filename): # load configuration, this can be changed to a different filename as needed - config = _get_config_with_test_paths(config_filename) + config = nd_utils.get_config_with_test_paths(config_filename) data_module = NowcastingDataModule( batch_size=config.process.batch_size, @@ -80,9 +67,9 @@ def test_data_module(config_filename): satellite_image_size_pixels=config.input_data.satellite.satellite_image_size_pixels, nwp_image_size_pixels=config.input_data.nwp.nwp_image_size_pixels, nwp_channels=config.input_data.nwp.nwp_channels[0:1], - sat_channels=config.input_data.satellite.sat_channels, # reduced for test data - pv_power_filename=config.input_data.pv.solar_pv_data_filename, - pv_metadata_filename=config.input_data.pv.solar_pv_metadata_filename, + sat_channels=config.input_data.satellite.satellite_channels, # reduced for test data 
+ pv_power_filename=config.input_data.pv.pv_filename, + pv_metadata_filename=config.input_data.pv.pv_metadata_filename, sat_filename=config.input_data.satellite.satellite_zarr_path, nwp_base_path=config.input_data.nwp.nwp_zarr_path, gsp_filename=config.input_data.gsp.gsp_zarr_path, @@ -124,7 +111,7 @@ def test_data_module(config_filename): def test_batch_to_batch_to_dataset(): - config = _get_config_with_test_paths("test.yaml") + config = nd_utils.get_config_with_test_paths("test.yaml") data_module = NowcastingDataModule( batch_size=config.process.batch_size, @@ -133,9 +120,9 @@ def test_batch_to_batch_to_dataset(): satellite_image_size_pixels=config.input_data.satellite.satellite_image_size_pixels, nwp_image_size_pixels=config.input_data.nwp.nwp_image_size_pixels, nwp_channels=config.input_data.nwp.nwp_channels[0:1], - sat_channels=config.input_data.satellite.sat_channels, # reduced for test data - pv_power_filename=config.input_data.pv.solar_pv_data_filename, - pv_metadata_filename=config.input_data.pv.solar_pv_metadata_filename, + sat_channels=config.input_data.satellite.satellite_channels, # reduced for test data + pv_power_filename=config.input_data.pv.pv_filename, + pv_metadata_filename=config.input_data.pv.pv_metadata_filename, sat_filename=config.input_data.satellite.satellite_zarr_path, nwp_base_path=config.input_data.nwp.nwp_zarr_path, gsp_filename=config.input_data.gsp.gsp_zarr_path, diff --git a/tests/test_utils.py b/tests/test_utils.py index 994d49da..eb3a9564 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -32,3 +32,20 @@ def test_sin_and_cos(): def test_get_netcdf_filename(): assert utils.get_netcdf_filename(10) == "10.nc" assert utils.get_netcdf_filename(10, add_hash=True) == "77eb6f_10.nc" + + +def test_remove_regex_pattern_from_keys(): + d = { + "satellite_zarr_path": "/a/b/c/foo.zarr", + "bar": "baz", + "satellite_channels": ["HRV"], + "n_satellite_per_batch": 4, + } + correct = { + "zarr_path": "/a/b/c/foo.zarr", + "bar": "baz", + 
"channels": ["HRV"], + "n_satellite_per_batch": 4, + } + new_dict = utils.remove_regex_pattern_from_keys(d, pattern_to_remove=r"^satellite_") + assert new_dict == correct