Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 37b5b9c

Browse files
Merge pull request #245 from openclimatefix/issue/244-mulit-process
speed up loading, and speed up tests
2 parents 33128b6 + e392332 commit 37b5b9c

File tree

2 files changed: 48 additions and 17 deletions.

nowcasting_dataset/dataset/batch.py

Lines changed: 32 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import os
66
from pathlib import Path
77
from typing import Optional, Union
8+
from concurrent import futures
89

910
import xarray as xr
1011
from pydantic import BaseModel, Field
@@ -130,26 +131,44 @@ def save_netcdf(self, batch_i: int, path: Path):
130131
path: the path where it will be saved. This can be local or in the cloud.
131132
132133
"""
133-
for data_source in self.data_sources:
134-
if data_source is not None:
135-
data_source.save_netcdf(batch_i=batch_i, path=path)
134+
135+
with futures.ThreadPoolExecutor() as executor:
136+
# Submit tasks to the executor.
137+
for data_source in self.data_sources:
138+
if data_source is not None:
139+
_ = executor.submit(
140+
data_source.save_netcdf,
141+
batch_i=batch_i,
142+
path=path,
143+
)
136144

137145
@staticmethod
138146
def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int):
139147
"""Load batch from netcdf file"""
140148
data_sources_names = Example.__fields__.keys()
141149

142-
# collect data sources
150+
# set up futures executor
143151
batch_dict = {}
144-
for data_source_name in data_sources_names:
145-
146-
local_netcdf_filename = os.path.join(
147-
local_netcdf_path, data_source_name, f"{batch_idx}.nc"
148-
)
149-
if os.path.exists(local_netcdf_filename):
150-
xr_dataset = xr.load_dataset(local_netcdf_filename)
151-
else:
152-
xr_dataset = None
152+
with futures.ThreadPoolExecutor() as executor:
153+
future_examples_per_source = []
154+
155+
# loop over data sources
156+
for data_source_name in data_sources_names:
157+
158+
local_netcdf_filename = os.path.join(
159+
local_netcdf_path, data_source_name, f"{batch_idx}.nc"
160+
)
161+
162+
# submit task
163+
future_examples = executor.submit(
164+
xr.load_dataset,
165+
filename_or_obj=local_netcdf_filename,
166+
)
167+
future_examples_per_source.append([data_source_name, future_examples])
168+
169+
# Collect results from each thread.
170+
for data_source_name, future_examples in future_examples_per_source:
171+
xr_dataset = future_examples.result()
153172

154173
batch_dict[data_source_name] = xr_dataset
155174

tests/dataset/test_batch.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,21 +10,30 @@
1010

1111
def test_model():
1212

13-
_ = Batch.fake()
13+
con = Configuration()
14+
con.process.batch_size = 4
15+
16+
_ = Batch.fake(configuration=con)
1417

1518

1619
def test_model_save_to_netcdf():
1720

21+
con = Configuration()
22+
con.process.batch_size = 4
23+
1824
with tempfile.TemporaryDirectory() as dirpath:
19-
Batch.fake().save_netcdf(path=dirpath, batch_i=0)
25+
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
2026

2127
assert os.path.exists(f"{dirpath}/satellite/0.nc")
2228

2329

2430
def test_model_load_from_netcdf():
2531

32+
con = Configuration()
33+
con.process.batch_size = 4
34+
2635
with tempfile.TemporaryDirectory() as dirpath:
27-
Batch.fake().save_netcdf(path=dirpath, batch_i=0)
36+
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
2837

2938
batch = Batch.load_netcdf(batch_idx=0, local_netcdf_path=dirpath)
3039

@@ -33,7 +42,10 @@ def test_model_load_from_netcdf():
3342

3443
def test_batch_to_batch_ml():
3544

36-
_ = BatchML.from_batch(batch=Batch.fake())
45+
con = Configuration()
46+
con.process.batch_size = 4
47+
48+
_ = BatchML.from_batch(batch=Batch.fake(configuration=con))
3749

3850

3951
def test_fake_dataset():

0 commit comments

Comments
 (0)