From 7dd2d4ab8af6af643d9c8bc1c6a166547f494b4a Mon Sep 17 00:00:00 2001 From: ntianhe ren Date: Sat, 4 Feb 2023 13:01:58 +0800 Subject: [PATCH 1/2] support encoder decoder checkpoint in dino --- projects/dino/configs/models/dino_r50.py | 3 +++ projects/dino/modeling/dino_transformer.py | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/projects/dino/configs/models/dino_r50.py b/projects/dino/configs/models/dino_r50.py index 5a3507df..1b07c2bf 100644 --- a/projects/dino/configs/models/dino_r50.py +++ b/projects/dino/configs/models/dino_r50.py @@ -56,6 +56,7 @@ num_layers=6, post_norm=False, num_feature_levels="${..num_feature_levels}", + use_checkpoint=False ), decoder=L(DINOTransformerDecoder)( embed_dim=256, @@ -66,6 +67,7 @@ num_layers=6, return_intermediate=True, num_feature_levels="${..num_feature_levels}", + use_checkpoint=False, ), num_feature_levels=4, two_stage_num_proposals="${..num_queries}", @@ -74,6 +76,7 @@ num_classes=80, num_queries=900, aux_loss=True, + use_checkpoint=False, criterion=L(DINOCriterion)( num_classes="${..num_classes}", matcher=L(HungarianMatcher)( diff --git a/projects/dino/modeling/dino_transformer.py b/projects/dino/modeling/dino_transformer.py index 5abd2510..eca04167 100644 --- a/projects/dino/modeling/dino_transformer.py +++ b/projects/dino/modeling/dino_transformer.py @@ -27,6 +27,7 @@ ) from detrex.utils import inverse_sigmoid +from fairscale.nn.checkpoint import checkpoint_wrapper class DINOTransformerEncoder(TransformerLayerSequence): def __init__( @@ -39,6 +40,7 @@ def __init__( num_layers: int = 6, post_norm: bool = False, num_feature_levels: int = 4, + use_checkpoint: bool = False, ): super(DINOTransformerEncoder, self).__init__( transformer_layers=BaseTransformerLayer( @@ -69,6 +71,11 @@ def __init__( else: self.post_norm_layer = None + # use encoder checkpoint + if use_checkpoint: + for layer in self.layers: + layer = checkpoint_wrapper(layer) + def forward( self, query, @@ -110,7 +117,8 @@ def __init__( num_layers: int = 6, return_intermediate: bool = True, num_feature_levels: int = 4, - look_forward_twice=True, + look_forward_twice: bool = True, + use_checkpoint: bool = True, ): super(DINOTransformerDecoder, self).__init__( transformer_layers=BaseTransformerLayer( @@ -149,6 +157,11 @@ def __init__( self.look_forward_twice = look_forward_twice self.norm = nn.LayerNorm(embed_dim) + # decoder checkpoint + if use_checkpoint: + for layer in self.layers: + layer = checkpoint_wrapper(layer) + def forward( self, query, From a6db12103d90dc5c7aea661a34874841351e268a Mon Sep 17 00:00:00 2001 From: ntianhe ren Date: Sat, 4 Feb 2023 13:05:01 +0800 Subject: [PATCH 2/2] refine config --- projects/dino/configs/models/dino_r50.py | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/dino/configs/models/dino_r50.py b/projects/dino/configs/models/dino_r50.py index 1b07c2bf..23e1df72 100644 --- a/projects/dino/configs/models/dino_r50.py +++ b/projects/dino/configs/models/dino_r50.py @@ -76,7 +76,6 @@ num_classes=80, num_queries=900, aux_loss=True, - use_checkpoint=False, criterion=L(DINOCriterion)( num_classes="${..num_classes}", matcher=L(HungarianMatcher)(