diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 65c85922e2a..59ab8b0946f 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -317,7 +317,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}mask_fcn{i+1}.{type}" new_key = f"{prefix}{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict, diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d277f130ad3..954dfdba1a1 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -45,7 +45,8 @@ def _v1_to_v2_weights(state_dict, prefix): for type in ["weight", "bias"]: old_key = f"{prefix}conv.{2*i}.{type}" new_key = f"{prefix}conv.{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) def _default_anchorgen(): diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 2b7ccb7b9ae..39f82ca323b 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -56,7 +56,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}conv.{type}" new_key = f"{prefix}conv.0.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict, diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 056ecbdc120..9062405a997 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -128,7 +128,8 @@ def _load_from_state_dict( for type in ["weight", "bias"]: old_key = f"{prefix}{block}.{i}.{type}" new_key = f"{prefix}{block}.{i}.0.{type}" - state_dict[new_key] = state_dict.pop(old_key) + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) super()._load_from_state_dict( state_dict,