Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
27 changes: 25 additions & 2 deletions nowcasting_dataset/config/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pydantic import BaseModel, Field, validator

from pydantic import BaseModel, Field
from typing import Optional
from nowcasting_dataset.data_sources.nwp_data_source import NWP_VARIABLE_NAMES
from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES


from datetime import datetime
import git

Expand Down Expand Up @@ -52,8 +55,10 @@ class OutputData(BaseModel):
class Process(BaseModel):
seed: int = Field(1234, description="Random seed, so experiments can be repeatable")
batch_size: int = Field(32, description="the batch size of the data")
forecast_minutes: int = Field(60, description="how many minutes to forecast in the future")
history_minutes: int = Field(30, description="how many historic minutes are used")
forecast_minutes: int = Field(
60, ge=0, description="how many minutes to forecast in the future"
)
history_minutes: int = Field(30, ge=0, description="how many historic minutes are used")
satellite_image_size_pixels: int = Field(64, description="the size of the satellite images")
nwp_image_size_pixels: int = Field(2, description="the size of the nwp images")

Expand All @@ -65,6 +70,24 @@ class Process(BaseModel):
precision: int = Field(16, description="what precision to use")
val_check_interval: int = Field(1000, description="TODO")

@property
def seq_len_30_minutes(self):
    """Sequence length at 30-minute resolution.

    Computed as (history_minutes + forecast_minutes) / 30, plus one,
    truncated to an int.
    """
    total_minutes = self.history_minutes + self.forecast_minutes
    n_steps = total_minutes / 30 + 1
    return int(n_steps)

@property
def seq_len_5_minutes(self):
    """Sequence length at 5-minute resolution.

    Computed as (history_minutes + forecast_minutes) / 5, plus one,
    truncated to an int.
    """
    total_minutes = self.history_minutes + self.forecast_minutes
    n_steps = total_minutes / 5 + 1
    return int(n_steps)

@validator("history_minutes")
def history_minutes_divide_by_30(cls, v):
    """Check that history_minutes is a multiple of 30.

    Args:
        v: the candidate value for history_minutes

    Returns: the value unchanged when it is valid

    Raises:
        ValueError: if v is not divisible by 30
    """
    # Raise explicitly instead of using `assert`: assert statements are removed when
    # Python runs with -O, which would silently switch this check off.
    # A multiple of 30 is also a multiple of 5.
    if v % 30 != 0:
        raise ValueError(f"history_minutes must be divisible by 30, but got {v}")
    return v

@validator("forecast_minutes")
def forecast_minutes_divide_by_30(cls, v):
    """Check that forecast_minutes is a multiple of 30.

    Args:
        v: the candidate value for forecast_minutes

    Returns: the value unchanged when it is valid

    Raises:
        ValueError: if v is not divisible by 30
    """
    # Raise explicitly instead of using `assert`: assert statements are removed when
    # Python runs with -O, which would silently switch this check off.
    # A multiple of 30 is also a multiple of 5.
    if v % 30 != 0:
        raise ValueError(f"forecast_minutes must be divisible by 30, but got {v}")
    return v


class Configuration(BaseModel):

Expand Down
36 changes: 36 additions & 0 deletions nowcasting_dataset/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,40 @@
NWP_Y_COORDS = "nwp_y_coords"
X_METERS_CENTER = "x_meters_center"
Y_METERS_CENTER = "y_meters_center"
# Names of the NWP variables. Kept here so that both the NWP data source and the
# configuration model can import the same tuple.
NWP_VARIABLE_NAMES = ("t", "dswrf", "prate", "r", "sde", "si10", "vis", "lcc", "mcc", "hcc")
# Names of the satellite variables. Kept here so that both the satellite data source
# and the configuration model can import the same tuple.
SAT_VARIABLE_NAMES = (
"HRV",
"IR_016",
"IR_039",
"IR_087",
"IR_097",
"IR_108",
"IR_120",
"IR_134",
"VIS006",
"VIS008",
"WV_062",
"WV_073",
)

DEFAULT_REQUIRED_KEYS = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

NWP_DATA,
NWP_X_COORDS,
NWP_Y_COORDS,
SATELLITE_DATA,
SATELLITE_X_COORDS,
SATELLITE_Y_COORDS,
PV_YIELD,
PV_SYSTEM_ID,
PV_SYSTEM_ROW_NUMBER,
PV_SYSTEM_X_COORDS,
PV_SYSTEM_Y_COORDS,
X_METERS_CENTER,
Y_METERS_CENTER,
GSP_ID,
GSP_YIELD,
GSP_X_COORDS,
GSP_Y_COORDS,
GSP_DATETIME_INDEX,
] + list(DATETIME_FEATURE_NAMES)
T0_DT = "t0_dt"
4 changes: 2 additions & 2 deletions nowcasting_dataset/data_sources/nwp_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from numbers import Number
from concurrent import futures

_LOG = logging.getLogger("nowcasting_dataset")
_LOG = logging.getLogger(__name__)

NWP_VARIABLE_NAMES = ("t", "dswrf", "prate", "r", "sde", "si10", "vis", "lcc", "mcc", "hcc")
from nowcasting_dataset.consts import NWP_VARIABLE_NAMES

# Means computed with
# nwp_ds = NWPDataSource(...)
Expand Down
16 changes: 2 additions & 14 deletions nowcasting_dataset/data_sources/satellite_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,8 @@
_LOG = logging.getLogger("nowcasting_dataset")


SAT_VARIABLE_NAMES = (
"HRV",
"IR_016",
"IR_039",
"IR_087",
"IR_097",
"IR_108",
"IR_120",
"IR_134",
"VIS006",
"VIS008",
"WV_062",
"WV_073",
)
from nowcasting_dataset.consts import SAT_VARIABLE_NAMES

# Means computed with
# nwp_ds = NWPDataSource(...)
# nwp_ds.open()
Expand Down
4 changes: 4 additions & 0 deletions nowcasting_dataset/dataset/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches
The main thing in here is a Typed Dictionary. This is used to store one element of data used for one step in the ML models.
There is also a validation function. See this file for documentation about exactly what data is available in each ML
training Example.

## validate.py

Contains a class that can validate the prepared ML dataset.
74 changes: 19 additions & 55 deletions nowcasting_dataset/dataset/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import datetime
import pandas as pd
from numbers import Number
from typing import List, Tuple, Iterable, Callable, Union, Optional
import nowcasting_dataset.consts
from typing import List, Tuple, Callable, Union, Optional
from nowcasting_dataset import data_sources
from dataclasses import dataclass
from concurrent import futures
import logging
import gcsfs
import boto3
import os
Expand All @@ -22,34 +19,24 @@
from nowcasting_dataset.cloud.aws import aws_download_to_local

from nowcasting_dataset.consts import (
GSP_ID,
GSP_YIELD,
GSP_X_COORDS,
GSP_Y_COORDS,
GSP_DATETIME_INDEX,
SATELLITE_X_COORDS,
SATELLITE_Y_COORDS,
SATELLITE_DATA,
NWP_DATA,
NWP_X_COORDS,
NWP_Y_COORDS,
PV_SYSTEM_X_COORDS,
PV_SYSTEM_Y_COORDS,
PV_YIELD,
PV_AZIMUTH_ANGLE,
PV_ELEVATION_ANGLE,
PV_SYSTEM_ID,
PV_SYSTEM_ROW_NUMBER,
Y_METERS_CENTER,
X_METERS_CENTER,
SATELLITE_DATETIME_INDEX,
NWP_TARGET_TIME,
PV_DATETIME_INDEX,
DATETIME_FEATURE_NAMES,
DEFAULT_REQUIRED_KEYS,
T0_DT,
)
from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES
import logging

logger = logging.getLogger(__name__)

"""
This file contains the following classes
Expand Down Expand Up @@ -110,27 +97,7 @@ def __init__(
tmp_path: str,
configuration: Configuration,
cloud: str = "gcp",
required_keys: Union[Tuple[str], List[str]] = [
NWP_DATA,
NWP_X_COORDS,
NWP_Y_COORDS,
SATELLITE_DATA,
SATELLITE_X_COORDS,
SATELLITE_Y_COORDS,
PV_YIELD,
PV_SYSTEM_ID,
PV_SYSTEM_ROW_NUMBER,
PV_SYSTEM_X_COORDS,
PV_SYSTEM_Y_COORDS,
X_METERS_CENTER,
Y_METERS_CENTER,
GSP_ID,
GSP_YIELD,
GSP_X_COORDS,
GSP_Y_COORDS,
GSP_DATETIME_INDEX,
]
+ list(DATETIME_FEATURE_NAMES),
required_keys: Union[Tuple[str], List[str]] = None,
history_minutes: Optional[int] = None,
forecast_minutes: Optional[int] = None,
):
Expand All @@ -153,7 +120,6 @@ def __init__(
self.src_path = src_path
self.tmp_path = tmp_path
self.cloud = cloud
self.required_keys = list(required_keys)
self.history_minutes = history_minutes
self.forecast_minutes = forecast_minutes
self.configuration = configuration
Expand All @@ -174,6 +140,10 @@ def __init__(
# Index into either sat_datetime_index or nwp_target_time indicating the current time,
self.current_timestep_5_index = int(configuration.process.history_minutes // 5) + 1

if required_keys is None:
required_keys = DEFAULT_REQUIRED_KEYS
self.required_keys = list(required_keys)

# setup cloud connections as None
self.gcs = None
self.s3_resource = None
Expand Down Expand Up @@ -203,6 +173,7 @@ def __getitem__(self, batch_idx: int) -> example.Example:
NamedDict where each value is a numpy array. The size of this
array's first dimension is the batch size.
"""
logger.debug(f"Getting batch {batch_idx}")
if not 0 <= batch_idx < self.n_batches:
raise IndexError(
"batch_idx must be in the range" f" [0, {self.n_batches}), not {batch_idx}!"
Expand Down Expand Up @@ -230,22 +201,15 @@ def __getitem__(self, batch_idx: int) -> example.Example:
if self.cloud != "local":
os.remove(local_netcdf_filename)

batch = example.Example(
sat_datetime_index=netcdf_batch.sat_time_coords,
nwp_target_time=netcdf_batch.nwp_time_coords,
)
for key in self.required_keys:
try:
batch[key] = netcdf_batch[key]
except KeyError:
pass

sat_data = batch[SATELLITE_DATA]
if sat_data.dtype == np.int16:
sat_data = sat_data.astype(np.float32)
sat_data = sat_data - SAT_MEAN
sat_data /= SAT_STD
batch[SATELLITE_DATA] = sat_data
batch = example.xr_to_example(batch_xr=netcdf_batch, required_keys=self.required_keys)

if SATELLITE_DATA in self.required_keys:
sat_data = batch[SATELLITE_DATA]
if sat_data.dtype == np.int16:
sat_data = sat_data.astype(np.float32)
sat_data = sat_data - SAT_MEAN
sat_data /= SAT_STD
batch[SATELLITE_DATA] = sat_data

if self.select_subset_data:
batch = subselect_data(
Expand Down
103 changes: 30 additions & 73 deletions nowcasting_dataset/dataset/example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TypedDict
from typing import TypedDict, List
import pandas as pd

from nowcasting_dataset.consts import *
import numpy as np
from numbers import Number
Expand Down Expand Up @@ -77,16 +78,41 @@ class Example(TypedDict):
# : Includes central GSP, which will always be the first entry. This will be a numpy array of values.
gsp_yield: Array #: shape = [batch_size, ] seq_length, n_gsp_systems_per_example
# GSP identification.
gsp_id: Array #: shape = [batch_size, ] n_pv_systems_per_example
gsp_id: Array #: shape = [batch_size, ] n_gsp_per_example
#: GSP geographical location (in OSGB coords).
gsp_x_coords: Array #: shape = [batch_size, ] n_pv_systems_per_example
gsp_y_coords: Array #: shape = [batch_size, ] n_pv_systems_per_example
gsp_x_coords: Array #: shape = [batch_size, ] n_gsp_per_example
gsp_y_coords: Array #: shape = [batch_size, ] n_gsp_per_example
gsp_datetime_index: Array #: shape = [batch_size, ] seq_length

# if the centroid type is a GSP, or a PV system
object_at_center: str #: shape = [batch_size, ]


def xr_to_example(batch_xr: xr.core.dataset.Dataset, required_keys: List[str]) -> Example:
    """
    Convert a batch held in an xarray dataset into an Example

    Args:
        batch_xr: batch data in xarray format
        required_keys: the keys that are needed

    Returns: Example holding the satellite and NWP time coordinates of batch_xr,
        plus every required key that could be looked up in batch_xr

    """

    example = Example(
        sat_datetime_index=batch_xr.sat_time_coords,
        nwp_target_time=batch_xr.nwp_time_coords,
    )

    for name in required_keys:
        try:
            value = batch_xr[name]
        except KeyError:
            # Keys that cannot be found in the dataset are skipped, not treated as errors.
            continue
        example[name] = value

    return example


def to_numpy(example: Example) -> Example:
for key, value in example.items():
if isinstance(value, xr.DataArray):
Expand All @@ -104,72 +130,3 @@ def to_numpy(example: Example) -> Example:

example[key] = value
return example


def validate_example(
    data: Example,
    seq_len_30_minutes: int,
    seq_len_5_minutes: int,
    sat_image_size: int = 64,
    n_sat_channels: int = 1,
    nwp_image_size: int = 0,
    n_nwp_channels: int = 1,
    n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE,
    n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE,
):
    """
    Validate the size and shape of the data

    Args:
        data: Typed dictionary of the data
        seq_len_30_minutes: the length of the sequence for 30 minutely data
        seq_len_5_minutes: the length of the sequence for 5 minutely data
        sat_image_size: the satellite image size
        n_sat_channels: the number of satellite channels
        nwp_image_size: the nwp image size
        n_nwp_channels: the number of nwp channels
        n_pv_systems_per_example: the number pv systems with nan padding
        n_gsp_per_example: the number gsp systems with nan padding

    Raises:
        AssertionError: if an entry does not have the expected length, shape, type or value

    Note: the checks are plain assert statements, so they are skipped when Python runs with -O.
    """

    # GSP data, checked against the 30 minutely sequence length
    assert len(data[GSP_ID]) == n_gsp_per_example
    n_gsp_system_id = len(data[GSP_ID])
    assert data[GSP_YIELD].shape == (seq_len_30_minutes, n_gsp_system_id)
    assert len(data[GSP_X_COORDS]) == n_gsp_system_id
    assert len(data[GSP_Y_COORDS]) == n_gsp_system_id
    assert len(data[GSP_DATETIME_INDEX]) == seq_len_30_minutes

    # centre of the example
    assert data[OBJECT_AT_CENTER] == "gsp"
    assert isinstance(data["x_meters_center"], np.float64)
    assert isinstance(data["y_meters_center"], np.float64)

    # PV data: the arrays are nan padded up to n_pv_systems_per_example, so the number of
    # real systems is the number of non-nan ids
    n_pv_systems = len(data[PV_SYSTEM_ID][~np.isnan(data[PV_SYSTEM_ID])])

    assert len(data[PV_SYSTEM_ID]) == n_pv_systems_per_example
    assert data[PV_YIELD].shape == (seq_len_5_minutes, n_pv_systems_per_example)
    assert len(data[PV_SYSTEM_X_COORDS]) == n_pv_systems_per_example
    assert len(data[PV_SYSTEM_Y_COORDS]) == n_pv_systems_per_example
    assert len(data[PV_SYSTEM_ROW_NUMBER][~np.isnan(data[PV_SYSTEM_ROW_NUMBER])]) == n_pv_systems

    # The sun angles are optional, so each one is only checked when its own key is present
    if PV_AZIMUTH_ANGLE in data:
        assert data[PV_AZIMUTH_ANGLE].shape == (seq_len_5_minutes, n_pv_systems_per_example)
    if PV_ELEVATION_ANGLE in data:
        assert data[PV_ELEVATION_ANGLE].shape == (seq_len_5_minutes, n_pv_systems_per_example)

    # satellite data
    assert data["sat_data"].shape == (
        seq_len_5_minutes,
        sat_image_size,
        sat_image_size,
        n_sat_channels,
    )
    assert len(data["sat_x_coords"]) == sat_image_size
    assert len(data["sat_y_coords"]) == sat_image_size
    assert len(data["sat_datetime_index"]) == seq_len_5_minutes

    # nwp data
    assert data["nwp"].shape == (n_nwp_channels, seq_len_5_minutes, nwp_image_size, nwp_image_size)
    assert len(data["nwp_x_coords"]) == nwp_image_size
    assert len(data["nwp_y_coords"]) == nwp_image_size
    assert len(data["nwp_target_time"]) == seq_len_5_minutes

    # datetime features
    for feature in DATETIME_FEATURE_NAMES:
        assert len(data[feature]) == seq_len_5_minutes
Loading