
Commit 1fddecc

aligned with partial original implementation
1 parent aa95139 commit 1fddecc

File tree

5 files changed, +280 -59 lines changed


references/classification/presets.py

Lines changed: 4 additions & 2 deletions
@@ -13,14 +13,16 @@ def __init__(
         interpolation=InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
+        policy_magnitude=9,
         random_erase_prob=0.0,
+        center_crop=False,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = [transforms.CenterCrop(crop_size)] if center_crop else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation))
+                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=policy_magnitude))
             elif auto_augment_policy == "ta_wide":
                 trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
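
For orientation, a minimal sketch of how the updated preset can be constructed after this change; the bare `presets` import assumes the interpreter runs from references/classification, and the argument values are illustrative, not taken from the commit:

from torchvision.transforms import InterpolationMode

import presets  # references/classification/presets.py

# RandAugment at a custom magnitude; center_crop=False keeps the usual
# RandomResizedCrop training behavior.
train_transform = presets.ClassificationPresetTrain(
    crop_size=176,
    interpolation=InterpolationMode.BILINEAR,
    auto_augment_policy="ra",
    policy_magnitude=15,
    random_erase_prob=0.1,
    center_crop=False,
)
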
Lines changed: 122 additions & 0 deletions (new file)
@@ -0,0 +1,122 @@
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import train
+import submitit
+
+
+def parse_args():
+    train_parser = train.get_args_parser(add_help=False)
+    parser = argparse.ArgumentParser("Submitit for train", parents=[train_parser], add_help=True)
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
+    parser.add_argument("--timeout", default=60 * 24 * 30, type=int, help="Duration of the job, in minutes")
+    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+    parser.add_argument("--partition", default="train", type=str, help="the partition (default: train)")
+    return parser.parse_args()
+
+
+def get_shared_folder() -> Path:
+    user = os.getenv("USER")
+    path = "/data/checkpoints"
+    if Path(path).is_dir():
+        p = Path(f"{path}/{user}/experiments")
+        p.mkdir(exist_ok=True)
+        return p
+    raise RuntimeError("No shared folder available")
+
+
+def get_init_file_folder() -> Path:
+    user = os.getenv("USER")
+    path = "/shared"
+    if Path(path).is_dir():
+        p = Path(f"{path}/{user}")
+        p.mkdir(exist_ok=True)
+        return p
+    raise RuntimeError("No shared folder available")
+
+
+def get_init_file():
+    # The init file must not exist, but its parent dir must exist.
+    os.makedirs(str(get_init_file_folder()), exist_ok=True)
+    init_file = get_init_file_folder() / f"{uuid.uuid4().hex}_init"
+    if init_file.exists():
+        os.remove(str(init_file))
+    return init_file
+
+
+class Trainer(object):
+    def __init__(self, args):
+        self.args = args
+
+    def __call__(self):
+        import train
+
+        self._setup_gpu_args()
+        train.main(self.args)
+
+    def checkpoint(self):
+        import os
+        import submitit
+        from pathlib import Path
+
+        self.args.dist_url = get_init_file().as_uri()
+        checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
+        if os.path.exists(checkpoint_file):
+            self.args.resume = checkpoint_file
+        print("Requeuing ", self.args)
+        empty_trainer = type(self)(self.args)
+        return submitit.helpers.DelayedSubmission(empty_trainer)
+
+    def _setup_gpu_args(self):
+        import submitit
+        from pathlib import Path
+
+        job_env = submitit.JobEnvironment()
+        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
+        self.args.gpu = job_env.local_rank
+        self.args.rank = job_env.global_rank
+        self.args.world_size = job_env.num_tasks
+        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
+def main():
+    args = parse_args()
+    if args.job_dir == "":
+        args.job_dir = get_shared_folder() / "%j"
+
+    # Note that the folder will depend on the job_id, to easily track experiments
+    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=300)
+
+    # cluster setup is defined by environment variables
+    num_gpus_per_node = args.ngpus
+    nodes = args.nodes
+    timeout_min = args.timeout
+
+    executor.update_parameters(
+        # mem_gb=96 * num_gpus_per_node,  # 768GB per machine
+        gpus_per_node=num_gpus_per_node,
+        tasks_per_node=num_gpus_per_node,  # one task per GPU
+        cpus_per_task=12,  # 96 cpus per machine
+        nodes=nodes,
+        timeout_min=timeout_min,  # max is 60 * 72
+        slurm_partition=args.partition,
+        slurm_signal_delay_s=120,
+    )
+
+
+    executor.update_parameters(name="torchvision")
+
+    args.dist_url = get_init_file().as_uri()
+    args.output_dir = args.job_dir
+
+    trainer = Trainer(args)
+    job = executor.submit(trainer)
+
+    print("Submitted job_id:", job.job_id)
+
+
+if __name__ == "__main__":
+    main()
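
The core of the launcher above is submitit's checkpointing protocol: when a job is preempted or times out, submitit calls checkpoint() on the submitted callable and resubmits the DelayedSubmission it returns, which is how Trainer resumes from checkpoint.pth. A minimal, self-contained sketch of that round trip, with a toy task and an illustrative log folder standing in for the real training job:

import submitit


class Counter:
    # Toy checkpointable callable: requeues itself with its latest state.
    def __init__(self, start=0):
        self.count = start

    def __call__(self):
        for i in range(self.count, 1000):
            self.count = i  # stand-in for one training step

    def checkpoint(self):
        # Called by submitit on preemption/timeout; the returned
        # DelayedSubmission is resubmitted as a fresh job.
        return submitit.helpers.DelayedSubmission(Counter(start=self.count))


executor = submitit.AutoExecutor(folder="submitit_logs/%j")  # %j becomes the job id
executor.update_parameters(timeout_min=60, slurm_partition="train")
job = executor.submit(Counter())
print("Submitted job_id:", job.job_id)
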

references/classification/train.py

Lines changed: 36 additions & 11 deletions
@@ -1,5 +1,6 @@
 import datetime
 import os
+import random
 import time
 import warnings
 
@@ -15,7 +16,7 @@
 from torchvision.transforms.functional import InterpolationMode
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, scheduler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -43,6 +44,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args
             if args.clip_grad_norm is not None:
                 nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()
+
+        if scheduler is not None and args.lr_step_every_batch:
+            scheduler.step()
 
         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
@@ -113,7 +117,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    val_resize_size, val_crop_size, train_crop_size, center_crop, policy_magnitude = args.val_resize_size, args.val_crop_size, args.train_crop_size, args.train_center_crop, args.policy_magnitude
     interpolation = InterpolationMode(args.interpolation)
 
     print("Loading training data")
@@ -129,10 +133,12 @@ def load_data(traindir, valdir, args):
     dataset = torchvision.datasets.ImageFolder(
         traindir,
         presets.ClassificationPresetTrain(
+            center_crop=center_crop,
             crop_size=train_crop_size,
             interpolation=interpolation,
             auto_augment_policy=auto_augment_policy,
             random_erase_prob=random_erase_prob,
+            policy_magnitude=policy_magnitude,
         ),
     )
     if args.cache_dataset:
@@ -182,7 +188,12 @@
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
-
+
+    if args.seed is None:
+        # randomly choose a seed
+        args.seed = random.randint(0, 2 ** 32 - 1)
+    utils.set_seed(args.seed)
+
     utils.init_distributed_mode(args)
     print(args)
 
@@ -261,13 +272,21 @@
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    batches_per_epoch = len(data_loader)
+    warmup_iters = args.lr_warmup_epochs
+    total_iters = args.epochs
+
+    if args.lr_step_every_batch:
+        warmup_iters *= batches_per_epoch
+        total_iters *= batches_per_epoch
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     elif args.lr_scheduler == "cosineannealinglr":
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
+            optimizer, T_max=total_iters - warmup_iters, eta_min=args.lr_min
         )
     elif args.lr_scheduler == "exponentiallr":
         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
@@ -280,18 +299,18 @@
     if args.lr_warmup_epochs > 0:
         if args.lr_warmup_method == "linear":
             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
             )
         elif args.lr_warmup_method == "constant":
             warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
-                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
             )
         else:
             raise RuntimeError(
                 f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
             )
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
-            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
+            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
         )
     else:
         lr_scheduler = main_lr_scheduler
@@ -341,8 +360,9 @@
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
-        lr_scheduler.step()
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, lr_scheduler)
+        if not args.lr_step_every_batch:
+            lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
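
Net effect of the scheduler changes: with --lr-step-every-batch, warmup and cosine decay are measured in optimizer steps rather than epochs, and the scheduler is stepped inside train_one_epoch instead of once per epoch in the loop above. A self-contained sketch of the resulting schedule, where the model, learning rate, and the sizes standing in for len(data_loader) and the epoch counts are all made up:

import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# Illustrative sizes; in train.py these come from len(data_loader) and the args.
batches_per_epoch = 100
lr_warmup_epochs, epochs = 5, 90
lr_step_every_batch = True

warmup_iters = lr_warmup_epochs * batches_per_epoch if lr_step_every_batch else lr_warmup_epochs
total_iters = epochs * batches_per_epoch if lr_step_every_batch else epochs

warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_iters]
)

for step in range(total_iters):
    optimizer.step()
    scheduler.step()  # per-batch stepping, matching --lr-step-every-batch
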
@@ -371,7 +391,7 @@ def get_args_parser(add_help=True):
 
     parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
 
-    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
+    parser.add_argument("--data-path", default="/datasets01_ontap/imagenet_full_size/061417/", type=str, help="dataset path")
     parser.add_argument("--model", default="resnet18", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -425,6 +445,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
     parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+    parser.add_argument("--lr-step-every-batch", action="store_true", help="step the lr scheduler after every batch instead of every epoch", default=False)
     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
@@ -448,6 +469,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--policy-magnitude", default=9, type=int, help="magnitude of auto augment policy")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
@@ -486,13 +508,16 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--train-center-crop", action="store_true", help="use center crop instead of random crop for training (default: False)"
+    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-
+    parser.add_argument("--seed", default=None, type=int, help="the seed for randomness (default: None). A `None` value means a seed will be randomly generated")
     return parser
 
 
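Taken together, the new flags can be exercised as below; every value is a placeholder, and train is assumed importable (i.e. the interpreter runs from references/classification):

import train

# Parse the new flags exactly as the CLI would.
args = train.get_args_parser().parse_args(
    [
        "--model", "resnet50",
        "--auto-augment", "ra",
        "--policy-magnitude", "15",
        "--train-center-crop",
        "--lr-scheduler", "cosineannealinglr",
        "--lr-warmup-epochs", "5",
        "--lr-step-every-batch",
        "--seed", "42",
    ]
)
print(args.policy_magnitude, args.train_center_crop, args.lr_step_every_batch, args.seed)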

references/classification/utils.py

Lines changed: 12 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 import torch
 import torch.distributed as dist
+import numpy as np
+import random
 
 
 class SmoothedValue:
@@ -463,3 +465,13 @@ def _add_params(module, prefix=""):
         if len(params[key]) > 0:
             param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
     return param_groups
+
+def set_seed(seed: int):
+    """
+    Function for setting all the RNGs to the same seed
+    """
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
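
One caveat: set_seed covers the Python, NumPy, and torch RNGs (CPU and CUDA), but it does not make cuDNN deterministic, so GPU runs can still diverge slightly. A sketch of a stricter setup, if bitwise reproducibility is worth the speed cost; the extra two lines are standard PyTorch switches, not part of this commit:

import torch
import utils  # references/classification/utils.py

utils.set_seed(42)
# Not part of set_seed: trade cuDNN autotuning for determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False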
