Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

speed up loading, and speed up tests #245

Merged
Merged 2 commits on Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions nowcasting_dataset/dataset/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
from pathlib import Path
from typing import Optional, Union
from concurrent import futures

import xarray as xr
from pydantic import BaseModel, Field
def save_netcdf(self, batch_i: int, path: Path):
    """Save the batch to NetCDF, one file per data source.

    The individual ``save_netcdf`` calls are I/O bound, so they are run
    concurrently in a thread pool to speed up saving.

    Args:
        batch_i: the batch id, used to name the files.
        path: the path where it will be saved. This can be local or in the cloud.

    Raises:
        Whatever any ``data_source.save_netcdf`` call raises: the futures are
        explicitly waited on, so worker-thread exceptions are re-raised here
        rather than being silently swallowed.
    """
    with futures.ThreadPoolExecutor() as executor:
        # Submit one save task per present data source.
        future_tasks = []
        for data_source in self.data_sources:
            if data_source is not None:
                future_tasks.append(
                    executor.submit(
                        data_source.save_netcdf,
                        batch_i=batch_i,
                        path=path,
                    )
                )

        # Wait on every future. Calling .result() is essential: without it,
        # an exception raised inside a worker thread would be discarded.
        for future in future_tasks:
            future.result()

@staticmethod
def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int):
    """Load batch from netcdf files, one file per data source.

    The files are read concurrently in a thread pool, since loading is
    I/O bound.

    Args:
        local_netcdf_path: directory holding one sub-directory per data
            source, each containing ``{batch_idx}.nc`` files.
        batch_idx: index of the batch to load.
    """
    data_sources_names = Example.__fields__.keys()

    batch_dict = {}

    # Set up futures executor and submit one load task per data source.
    with futures.ThreadPoolExecutor() as executor:
        future_examples_per_source = []

        for data_source_name in data_sources_names:

            local_netcdf_filename = os.path.join(
                local_netcdf_path, data_source_name, f"{batch_idx}.nc"
            )

            # Not every data source is present in every batch. Keep the
            # pre-concurrency behaviour of storing None for a missing file
            # instead of letting the worker raise FileNotFoundError.
            if not os.path.exists(local_netcdf_filename):
                future_examples_per_source.append([data_source_name, None])
                continue

            # Submit task.
            future_examples = executor.submit(
                xr.load_dataset,
                filename_or_obj=local_netcdf_filename,
            )
            future_examples_per_source.append([data_source_name, future_examples])

    # Collect results from each thread; .result() re-raises any exception
    # raised while loading, so failures are not silently swallowed.
    for data_source_name, future_examples in future_examples_per_source:
        xr_dataset = None if future_examples is None else future_examples.result()

        batch_dict[data_source_name] = xr_dataset

Expand Down
20 changes: 16 additions & 4 deletions tests/dataset/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,30 @@

def test_model():
    """Batch.fake should build a Batch from an explicit configuration."""
    config = Configuration()
    config.process.batch_size = 4

    _ = Batch.fake(configuration=config)


def test_model_save_to_netcdf():
    """Saving a fake Batch writes one NetCDF file per data source."""
    config = Configuration()
    config.process.batch_size = 4

    with tempfile.TemporaryDirectory() as dirpath:
        fake_batch = Batch.fake(configuration=config)
        fake_batch.save_netcdf(path=dirpath, batch_i=0)

        # Spot-check one data source's file.
        assert os.path.exists(f"{dirpath}/satellite/0.nc")


def test_model_load_from_netcdf():

con = Configuration()
con.process.batch_size = 4

with tempfile.TemporaryDirectory() as dirpath:
Batch.fake().save_netcdf(path=dirpath, batch_i=0)
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)

batch = Batch.load_netcdf(batch_idx=0, local_netcdf_path=dirpath)

Expand All @@ -33,7 +42,10 @@ def test_model_load_from_netcdf():

def test_batch_to_batch_ml():
    """A fake Batch converts cleanly to a BatchML."""
    config = Configuration()
    config.process.batch_size = 4

    fake_batch = Batch.fake(configuration=config)
    _ = BatchML.from_batch(batch=fake_batch)


def test_fake_dataset():
Expand Down