From f9a5a4006f008f402141e952180485c8b5bc3c15 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 20 Apr 2022 11:06:47 +0100 Subject: [PATCH 1/2] Convert weights only if `old_key` is in `state_dict` --- torchvision/models/detection/mask_rcnn.py | 3 ++- torchvision/models/detection/retinanet.py | 3 ++- torchvision/models/detection/rpn.py | 3 ++- torchvision/ops/feature_pyramid_network.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 65c85922e2a..59ab8b0946f 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -317,7 +317,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}mask_fcn{i+1}.{type}" new_key = f"{prefix}{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict, diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d277f130ad3..954dfdba1a1 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -45,7 +45,8 @@ def _v1_to_v2_weights(state_dict, prefix): for type in ["weight", "bias"]: old_key = f"{prefix}conv.{2*i}.{type}" new_key = f"{prefix}conv.{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) def _default_anchorgen(): diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 2b7ccb7b9ae..39f82ca323b 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -56,7 +56,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}conv.{type}" new_key = f"{prefix}conv.0.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict, diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 056ecbdc120..33e00c1deec 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -128,7 +128,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}{block}.{i}.{type}" new_key = f"{prefix}{block}.{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict, From 43fcfcbd4e8375cfe37ad241a102b706d1678122 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 20 Apr 2022 11:24:22 +0100 Subject: [PATCH 2/2] Fix linter --- torchvision/ops/feature_pyramid_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 33e00c1deec..9062405a997 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -129,7 +129,7 @@ def _load_from_state_dict( old_key = f"{prefix}{block}.{i}.{type}" new_key = f"{prefix}{block}.{i}.0.{type}" if old_key in state_dict: - state_dict[new_key] = state_dict.pop(old_key) + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict,