Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 80b1df6

Browse files
vreisfacebook-github-bot
authored andcommitted
Move use_gpu from ClassyTrainer to ClassificationTask
Summary: This is the first in a series of diffs to eliminate the ClassyTrainer abstraction. The only reason Trainer existed was to support elastic training, but PET v0.2 does not require changing out training loop. The plan is to move all attributes from ClassyTrainer into ClassificationTask. Start by moving use_gpu to the task. Differential Revision: D20801017 fbshipit-source-id: 9fd3322a4503498a969c2bdfa7301c8c99a8f790
1 parent 6214d10 commit 80b1df6

12 files changed

+69
-149
lines changed

classy_train.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,13 @@ def main(args, config):
9393
# Configure hooks to do tensorboard logging, checkpoints and so on
9494
task.set_hooks(configure_hooks(args, config))
9595

96-
use_gpu = None
97-
if args.device is not None:
98-
use_gpu = args.device == "gpu"
99-
assert torch.cuda.is_available() or not use_gpu, "CUDA is unavailable"
100-
10196
# LocalTrainer is used for a single node. DistributedTrainer will setup
10297
# training to use PyTorch's DistributedDataParallel.
10398
trainer_class = {"none": LocalTrainer, "ddp": DistributedTrainer}[
10499
args.distributed_backend
105100
]
106101

107-
trainer = trainer_class(use_gpu=use_gpu, num_dataloader_workers=args.num_workers)
102+
trainer = trainer_class(num_dataloader_workers=args.num_workers)
108103

109104
logging.info(
110105
f"Starting training on rank {get_rank()} worker. "

classy_vision/generic/opts.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,6 @@ def add_generic_args(parser):
1818
parser.add_argument(
1919
"--config_file", type=str, help="path to config file for model", required=True
2020
)
21-
parser.add_argument(
22-
"--device",
23-
default=None,
24-
type=str,
25-
help="device to use: either 'cpu' or 'gpu'. If unspecified, will use GPU when available and CPU otherwise.",
26-
)
2721
parser.add_argument(
2822
"--num_workers",
2923
default=4,
@@ -145,13 +139,6 @@ def check_generic_args(args):
145139
# check types and values:
146140
assert is_pos_int(args.num_workers), "incorrect number of workers"
147141
assert is_pos_int(args.visdom_port), "incorrect visdom port"
148-
assert (
149-
args.device is None or args.device == "cpu" or args.device == "gpu"
150-
), "unknown device"
151-
152-
# check that CUDA is available:
153-
if args.device == "gpu":
154-
assert torch.cuda.is_available(), "CUDA required to train on GPUs"
155142

156143
# create checkpoint folder if it does not exist:
157144
if args.checkpoint_folder != "" and not os.path.exists(args.checkpoint_folder):

classy_vision/tasks/classification_task.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,16 @@ def __init__(self):
142142
self.perf_log = []
143143
self.last_batch = None
144144
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
145+
self.use_gpu = torch.cuda.is_available()
146+
147+
def set_use_gpu(self, use_gpu: bool):
148+
self.use_gpu = use_gpu
149+
150+
assert (
151+
not self.use_gpu or torch.cuda.is_available()
152+
), "CUDA required to train on GPUs"
153+
154+
return self
145155

146156
def set_checkpoint(self, checkpoint):
147157
"""Sets checkpoint on task.
@@ -359,6 +369,10 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
359369
.set_hooks(hooks)
360370
)
361371

372+
use_gpu = config.get("use_gpu")
373+
if use_gpu is not None:
374+
task.set_use_gpu(use_gpu)
375+
362376
for phase_type in phase_types:
363377
task.set_dataset(datasets[phase_type], phase_type)
364378

@@ -508,24 +522,19 @@ def build_dataloaders(
508522
for phase_type in self.datasets.keys()
509523
}
510524

511-
def prepare(
512-
self,
513-
num_dataloader_workers=0,
514-
pin_memory=False,
515-
use_gpu=False,
516-
dataloader_mp_context=None,
517-
):
525+
def prepare(self, num_dataloader_workers=0, dataloader_mp_context=None):
518526
"""Prepares task for training, populates all derived attributes
519527
520528
Args:
521529
num_dataloader_workers: Number of dataloading processes. If 0,
522530
dataloading is done on main process
523-
pin_memory: if true pin memory on GPU
524-
use_gpu: if true, load model, optimizer, loss, etc on GPU
525531
dataloader_mp_context: Determines how processes are spawned.
526532
Value must be one of None, "spawn", "fork", "forkserver".
527533
If None, then context is inherited from parent process
528534
"""
535+
536+
pin_memory = self.use_gpu and torch.cuda.device_count() > 1
537+
529538
self.phases = self._build_phases()
530539
self.dataloaders = self.build_dataloaders(
531540
num_workers=num_dataloader_workers,
@@ -539,7 +548,7 @@ def prepare(
539548
self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
540549

541550
# move the model and loss to the right device
542-
if use_gpu:
551+
if self.use_gpu:
543552
self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
544553
else:
545554
self.loss.cpu()
@@ -686,7 +695,7 @@ def set_classy_state(self, state):
686695
# Set up pytorch module in train vs eval mode, update optimizer.
687696
self._set_model_train_mode()
688697

689-
def eval_step(self, use_gpu):
698+
def eval_step(self):
690699
self.last_batch = None
691700

692701
# Process next sample
@@ -699,7 +708,7 @@ def eval_step(self, use_gpu):
699708

700709
# Copy sample to GPU
701710
target = sample["target"]
702-
if use_gpu:
711+
if self.use_gpu:
703712
for key, value in sample.items():
704713
sample[key] = recursive_copy_to_gpu(value, non_blocking=True)
705714

@@ -726,12 +735,8 @@ def check_inf_nan(self, loss):
726735
if loss == float("inf") or loss == float("-inf") or loss != loss:
727736
raise FloatingPointError(f"Loss is infinity or NaN: {loss}")
728737

729-
def train_step(self, use_gpu):
730-
"""Train step to be executed in train loop
731-
732-
Args:
733-
use_gpu: if true, execute training on GPU
734-
"""
738+
def train_step(self):
739+
"""Train step to be executed in train loop."""
735740

736741
self.last_batch = None
737742

@@ -745,7 +750,7 @@ def train_step(self, use_gpu):
745750

746751
# Copy sample to GPU
747752
target = sample["target"]
748-
if use_gpu:
753+
if self.use_gpu:
749754
for key, value in sample.items():
750755
sample[key] = recursive_copy_to_gpu(value, non_blocking=True)
751756

classy_vision/tasks/classy_task.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,7 @@ def set_classy_state(self, state):
8686

8787
@abstractmethod
8888
def prepare(
89-
self,
90-
num_dataloader_workers=0,
91-
pin_memory=False,
92-
use_gpu=False,
93-
dataloader_mp_context=None,
89+
self, num_dataloader_workers=0, pin_memory=False, dataloader_mp_context=None
9490
) -> None:
9591
"""
9692
Prepares the task for training.
@@ -102,19 +98,15 @@ def prepare(
10298
num_dataloader_workers: Number of workers to create for the dataloaders
10399
pin_memory: Whether the dataloaders should copy the Tensors into CUDA
104100
pinned memory (default False)
105-
use_gpu: True if training on GPUs, False otherwise
106101
"""
107102
pass
108103

109104
@abstractmethod
110-
def train_step(self, use_gpu) -> None:
105+
def train_step(self) -> None:
111106
"""
112107
Run a train step.
113108
114109
This corresponds to training over one batch of data from the dataloaders.
115-
116-
Args:
117-
use_gpu: True if training on GPUs, False otherwise
118110
"""
119111
pass
120112

@@ -155,24 +147,21 @@ def on_end(self):
155147
pass
156148

157149
@abstractmethod
158-
def eval_step(self, use_gpu) -> None:
150+
def eval_step(self) -> None:
159151
"""
160152
Run an evaluation step.
161153
162154
This corresponds to evaluating the model over one batch of data.
163-
164-
Args:
165-
use_gpu: True if training on GPUs, False otherwise
166155
"""
167156
pass
168157

169-
def step(self, use_gpu) -> None:
158+
def step(self) -> None:
170159
from classy_vision.hooks import ClassyHookFunctions
171160

172161
if self.train:
173-
self.train_step(use_gpu)
162+
self.train_step()
174163
else:
175-
self.eval_step(use_gpu)
164+
self.eval_step()
176165

177166
for hook in self.hooks:
178167
hook.on_step(self)

classy_vision/tasks/fine_tuning_task.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,12 @@ def _set_model_train_mode(self):
6767
self.base_model.train(phase["train"])
6868

6969
def prepare(
70-
self,
71-
num_dataloader_workers: int = 0,
72-
pin_memory: bool = False,
73-
use_gpu: bool = False,
74-
dataloader_mp_context=None,
70+
self, num_dataloader_workers: int = 0, dataloader_mp_context=None
7571
) -> None:
7672
assert (
7773
self.pretrained_checkpoint is not None
7874
), "Need a pretrained checkpoint for fine tuning"
79-
super().prepare(
80-
num_dataloader_workers, pin_memory, use_gpu, dataloader_mp_context
81-
)
75+
super().prepare(num_dataloader_workers, dataloader_mp_context)
8276
if self.checkpoint is None:
8377
# no checkpoint exists, load the model's state from the pretrained
8478
# checkpoint

classy_vision/trainer/classy_trainer.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,18 @@ class ClassyTrainer:
2727

2828
def __init__(
2929
self,
30-
use_gpu: Optional[bool] = None,
3130
num_dataloader_workers: int = 0,
3231
dataloader_mp_context: Optional[str] = None,
3332
):
3433
"""Constructor for ClassyTrainer.
3534
3635
Args:
37-
use_gpu: If true, then use GPUs for training.
38-
If None, then check if we have GPUs available, if we do
39-
then use GPU for training.
4036
num_dataloader_workers: Number of CPU processes doing dataloading
4137
per GPU. If 0, then dataloading is done on main thread.
4238
dataloader_mp_context: Determines how to launch
4339
new processes for dataloading. Must be one of "fork", "forkserver",
4440
"spawn". If None, process launching is inherited from parent.
4541
"""
46-
if use_gpu is None:
47-
use_gpu = torch.cuda.is_available()
48-
self.use_gpu = use_gpu
4942
self.num_dataloader_workers = num_dataloader_workers
5043
self.dataloader_mp_context = dataloader_mp_context
5144

@@ -57,11 +50,8 @@ def train(self, task: ClassyTask):
5750
everything that is needed for training
5851
"""
5952

60-
pin_memory = self.use_gpu and torch.cuda.device_count() > 1
6153
task.prepare(
6254
num_dataloader_workers=self.num_dataloader_workers,
63-
pin_memory=pin_memory,
64-
use_gpu=self.use_gpu,
6555
dataloader_mp_context=self.dataloader_mp_context,
6656
)
6757
assert isinstance(task, ClassyTask)
@@ -75,7 +65,7 @@ def train(self, task: ClassyTask):
7565
task.on_phase_start()
7666
while True:
7767
try:
78-
task.step(self.use_gpu)
68+
task.step()
7969
except StopIteration:
8070
break
8171
task.on_phase_end()

classy_vision/trainer/distributed_trainer.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,39 +56,19 @@ class DistributedTrainer(ClassyTrainer):
5656
"""Distributed trainer for using multiple training processes
5757
"""
5858

59-
def __init__(
60-
self,
61-
use_gpu: Optional[bool] = None,
62-
num_dataloader_workers: int = 0,
63-
dataloader_mp_context: Optional[str] = None,
64-
):
65-
"""Constructor for DistributedTrainer.
66-
67-
Args:
68-
use_gpu: If true, then use GPU 0 for training.
69-
If None, then check if we have GPUs available, if we do
70-
then use GPU for training.
71-
num_dataloader_workers: Number of CPU processes doing dataloading
72-
per GPU. If 0, then dataloading is done on main thread.
73-
dataloader_mp_context: Determines how to launch
74-
new processes for dataloading. Must be one of "fork", "forkserver",
75-
"spawn". If None, process launching is inherited from parent.
76-
"""
77-
super().__init__(
78-
use_gpu=use_gpu,
79-
num_dataloader_workers=num_dataloader_workers,
80-
dataloader_mp_context=dataloader_mp_context,
81-
)
59+
def train(self, task):
8260
_init_env_vars()
83-
_init_distributed(self.use_gpu)
61+
_init_distributed(task.use_gpu)
8462
logging.info(
8563
f"Done setting up distributed process_group with rank {get_rank()}"
8664
+ f", world_size {get_world_size()}"
8765
)
8866
local_rank = int(os.environ["LOCAL_RANK"])
89-
if self.use_gpu:
67+
if task.use_gpu:
9068
logging.info("Using GPU, CUDA device index: {}".format(local_rank))
9169
set_cuda_device_index(local_rank)
9270
else:
9371
logging.info("Using CPU")
9472
set_cpu_device()
73+
74+
super().train(task)

classy_vision/trainer/local_trainer.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,12 @@ class LocalTrainer(ClassyTrainer):
1616
"""Trainer to be used if you want want use only a single training process.
1717
"""
1818

19-
def __init__(
20-
self,
21-
use_gpu: Optional[bool] = None,
22-
num_dataloader_workers: int = 0,
23-
dataloader_mp_context: Optional[str] = None,
24-
):
25-
"""Constructor for LocalTrainer.
26-
27-
Args:
28-
use_gpu: If true, then use GPU 0 for training.
29-
If None, then check if we have GPUs available, if we do
30-
then use GPU for training.
31-
num_dataloader_workers: Number of CPU processes doing dataloading
32-
per GPU. If 0, then dataloading is done on main thread.
33-
dataloader_mp_context: Determines how to launch
34-
new processes for dataloading. Must be one of "fork", "forkserver",
35-
"spawn". If None, process launching is inherited from parent.
36-
"""
37-
super().__init__(
38-
use_gpu=use_gpu,
39-
num_dataloader_workers=num_dataloader_workers,
40-
dataloader_mp_context=dataloader_mp_context,
41-
)
42-
if self.use_gpu:
19+
def train(self, task):
20+
if task.use_gpu:
4321
logging.info("Using GPU, CUDA device index: {}".format(0))
4422
set_cuda_device_index(0)
4523
else:
4624
logging.info("Using CPU")
4725
set_cpu_device()
26+
27+
super().train(task)

test/generic_util_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def test_update_classy_state(self):
437437
task = build_task(config)
438438
task_2 = build_task(config)
439439
task_2.prepare()
440-
trainer = LocalTrainer(use_gpu=False)
440+
trainer = LocalTrainer()
441441
trainer.train(task)
442442
update_classy_state(task_2, task.get_classy_state(deep_copy=True))
443443
self._compare_states(task.get_classy_state(), task_2.get_classy_state())
@@ -449,13 +449,12 @@ def test_update_classy_model(self):
449449
"""
450450
config = get_fast_test_task_config()
451451
task = build_task(config)
452-
use_gpu = torch.cuda.is_available()
453-
trainer = LocalTrainer(use_gpu=use_gpu)
452+
trainer = LocalTrainer()
454453
trainer.train(task)
455454
for reset_heads in [False, True]:
456455
task_2 = build_task(config)
457456
# prepare task_2 for the right device
458-
task_2.prepare(use_gpu=use_gpu)
457+
task_2.prepare()
459458
update_classy_model(
460459
task_2.model, task.model.get_classy_state(deep_copy=True), reset_heads
461460
)

0 commit comments

Comments
 (0)