11
11
from torch import nn
12
12
13
13
14
+ try :
15
+ from torchvision .prototype import models as PM
16
+ except ImportError :
17
+ PM = None
18
+
19
+
14
20
def get_dataset (dir_path , name , image_set , transform ):
15
21
def sbd (* args , ** kwargs ):
16
22
return torchvision .datasets .SBDataset (* args , mode = "segmentation" , ** kwargs )
@@ -26,11 +32,15 @@ def sbd(*args, **kwargs):
26
32
return ds , num_classes
27
33
28
34
29
def get_transform(train, args):
    """Build the transform pipeline for a dataset split.

    Args:
        train (bool): True for the training split, False for evaluation.
        args: parsed CLI namespace; only ``args.weights`` (prototype weights
            enum name or None) and ``args.model`` (model builder name) are read.

    Returns:
        A callable transform: the hard-coded training/eval presets, or — when
        ``--weights`` is given — the transforms bundled with that weights enum.

    Raises:
        ImportError: if ``--weights`` is requested but the prototype module
            (``torchvision.prototype``) could not be imported.
    """
    if train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
    if not args.weights:
        return presets.SegmentationPresetEval(base_size=520)
    # Weight-specific eval transforms need the prototype API. Guard here too
    # (not just in main): get_transform runs first, and without this check a
    # missing nightly surfaces as an opaque AttributeError on None.
    if PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    fn = PM.segmentation.__dict__[args.model]
    weights = PM._api.get_weight(fn, args.weights)
    return weights.transforms()
34
44
35
45
36
46
def criterion (inputs , target ):
@@ -90,8 +100,8 @@ def main(args):
90
100
91
101
device = torch .device (args .device )
92
102
93
- dataset , num_classes = get_dataset (args .data_path , args .dataset , "train" , get_transform (train = True ))
94
- dataset_test , _ = get_dataset (args .data_path , args .dataset , "val" , get_transform (train = False ))
103
+ dataset , num_classes = get_dataset (args .data_path , args .dataset , "train" , get_transform (True , args ))
104
+ dataset_test , _ = get_dataset (args .data_path , args .dataset , "val" , get_transform (False , args ))
95
105
96
106
if args .distributed :
97
107
train_sampler = torch .utils .data .distributed .DistributedSampler (dataset )
@@ -113,9 +123,18 @@ def main(args):
113
123
dataset_test , batch_size = 1 , sampler = test_sampler , num_workers = args .workers , collate_fn = utils .collate_fn
114
124
)
115
125
116
- model = torchvision .models .segmentation .__dict__ [args .model ](
117
- num_classes = num_classes , aux_loss = args .aux_loss , pretrained = args .pretrained
118
- )
126
+ if not args .weights :
127
+ model = torchvision .models .segmentation .__dict__ [args .model ](
128
+ pretrained = args .pretrained ,
129
+ num_classes = num_classes ,
130
+ aux_loss = args .aux_loss ,
131
+ )
132
+ else :
133
+ if PM is None :
134
+ raise ImportError ("The prototype module couldn't be found. Please install the latest torchvision nightly." )
135
+ model = PM .segmentation .__dict__ [args .model ](
136
+ weights = args .weights , num_classes = num_classes , aux_loss = args .aux_loss
137
+ )
119
138
model .to (device )
120
139
if args .distributed :
121
140
model = torch .nn .SyncBatchNorm .convert_sync_batchnorm (model )
@@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
247
266
parser .add_argument ("--world-size" , default = 1 , type = int , help = "number of distributed processes" )
248
267
parser .add_argument ("--dist-url" , default = "env://" , type = str , help = "url used to set up distributed training" )
249
268
269
+ # Prototype models only
270
+ parser .add_argument ("--weights" , default = None , type = str , help = "the weights enum name to load" )
271
+
250
272
return parser
251
273
252
274
0 commit comments