Commit 674e814

Add weight averaging and storing methods in references utils (#3352)
* Adding the average_checkpoints() method.
* Adding the store_model_weights() method.
1 parent 03fec9c commit 674e814

1 file changed (+126, -1 lines)

references/classification/utils.py

Lines changed: 126 additions & 1 deletion
@@ -1,5 +1,7 @@
-from collections import defaultdict, deque
+from collections import defaultdict, deque, OrderedDict
+import copy
 import datetime
+import hashlib
 import time
 import torch
 import torch.distributed as dist
@@ -252,3 +254,126 @@ def init_distributed_mode(args):
     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                          world_size=args.world_size, rank=args.rank)
     setup_for_distributed(args.rank == 0)
+
+
+def average_checkpoints(inputs):
+    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
+    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+
+    Args:
+      inputs (List[str]): An iterable of string paths of checkpoints to load from.
+    Returns:
+      A dict of string keys mapping to various values. The 'model' key
+      from the returned dict should correspond to an OrderedDict mapping
+      string parameter names to torch Tensors.
+    """
+    params_dict = OrderedDict()
+    params_keys = None
+    new_state = None
+    num_models = len(inputs)
+    for fpath in inputs:
+        with open(fpath, "rb") as f:
+            state = torch.load(
+                f,
+                map_location=(
+                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
+                ),
+            )
+        # Copies over the settings from the first checkpoint
+        if new_state is None:
+            new_state = state
+        model_params = state["model"]
+        model_params_keys = list(model_params.keys())
+        if params_keys is None:
+            params_keys = model_params_keys
+        elif params_keys != model_params_keys:
+            raise KeyError(
+                "For checkpoint {}, expected list of params: {}, "
+                "but found: {}".format(f, params_keys, model_params_keys)
+            )
+        for k in params_keys:
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            if k not in params_dict:
+                params_dict[k] = p.clone()
+                # NOTE: clone() is needed in case of p is a shared parameter
+            else:
+                params_dict[k] += p
+    averaged_params = OrderedDict()
+    for k, v in params_dict.items():
+        averaged_params[k] = v
+        if averaged_params[k].is_floating_point():
+            averaged_params[k].div_(num_models)
+        else:
+            averaged_params[k] //= num_models
+    new_state["model"] = averaged_params
+    return new_state
+
+
+def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
+    """
+    This method can be used to prepare weights files for new models. It receives as
+    input a model architecture and a checkpoint from the training script and produces
+    a file with the weights ready for release.
+
+    Examples:
+        from torchvision import models as M
+
+        # Classification
+        model = M.mobilenet_v3_large(pretrained=False)
+        print(store_model_weights(model, './class.pth'))
+
+        # Quantized Classification
+        model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
+        model.fuse_model()
+        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
+        _ = torch.quantization.prepare_qat(model, inplace=True)
+        print(store_model_weights(model, './qat.pth'))
+
+        # Object Detection
+        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
+        print(store_model_weights(model, './obj.pth'))
+
+        # Segmentation
+        model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
+        print(store_model_weights(model, './segm.pth', strict=False))
+
+    Args:
+        model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
+        checkpoint_path (str): The path of the checkpoint we will load.
+        checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
+            Default: "model".
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+    Returns:
+        output_path (str): The location where the weights are saved.
+    """
+    # Store the new model next to the checkpoint_path
+    checkpoint_path = os.path.abspath(checkpoint_path)
+    output_dir = os.path.dirname(checkpoint_path)
+
+    # Deep copy to avoid side-effects on the model object.
+    model = copy.deepcopy(model)
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+    # Load the weights to the model to validate that everything works
+    # and remove unnecessary weights (such as auxiliaries, etc)
+    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
+
+    tmp_path = os.path.join(output_dir, str(model.__hash__()))
+    torch.save(model.state_dict(), tmp_path)
+
+    sha256_hash = hashlib.sha256()
+    with open(tmp_path, "rb") as f:
+        # Read and update hash string value in blocks of 4K
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+    hh = sha256_hash.hexdigest()
+
+    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
+    os.replace(tmp_path, output_path)
+
+    return output_path
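
As a usage sketch (not part of this commit), the two new helpers can be chained to turn a handful of epoch checkpoints into a single release-ready weights file. The checkpoint paths, the mobilenet_v3_large architecture, and the assumption that the training script stores the state dict under the 'model' key of each checkpoint are illustrative only:

    import torch
    from torchvision import models as M

    # Assumes this runs next to references/classification/utils.py from this commit.
    from utils import average_checkpoints, store_model_weights

    # Average the weights of the last few epoch checkpoints (hypothetical paths).
    ckpts = ["./model_27.pth", "./model_28.pth", "./model_29.pth"]
    averaged_state = average_checkpoints(ckpts)

    # Persist the averaged checkpoint so store_model_weights() can read it back;
    # its 'model' entry now holds the averaged state dict.
    torch.save(averaged_state, "./model_averaged.pth")

    # Validate the averaged weights against the architecture and emit a file named
    # weights-<first 8 chars of its sha256 hash>.pth next to the checkpoint.
    model = M.mobilenet_v3_large(pretrained=False)
    print(store_model_weights(model, "./model_averaged.pth", checkpoint_key="model"))

Naming the output after a prefix of its SHA256 digest mirrors the hash-verified file-naming scheme used for torchvision's published weights, so the resulting file can be uploaded for release as-is.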
