From ef079d3e47e88857ea09d430150bf103abb35405 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 29 Apr 2022 10:02:40 +0800 Subject: [PATCH 1/8] Create detr.py --- torchvision/models/detection/detr.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 torchvision/models/detection/detr.py diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py new file mode 100644 index 00000000000..fa18967ffc1 --- /dev/null +++ b/torchvision/models/detection/detr.py @@ -0,0 +1,5 @@ +import torch + + +class DETR: + pass From 8d336048fbd1c3b14dee875aec501fb45ba1afa8 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 8 May 2022 21:48:32 +0800 Subject: [PATCH 2/8] add transformer --- torchvision/models/detection/detr.py | 254 +++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index fa18967ffc1..b659afcc6b6 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -1,4 +1,258 @@ +import copy +from typing import Optional, List, Callable + import torch +from torch import nn, Tensor + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class TransformerEncoderLayer(nn.Module): + """"Transormer Encoder Layer""" + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + norm_first: bool = False, + activation_layer: Callable[..., nn.Module] = nn.ReLU, + norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = norm_layer(d_model) + self.norm2 = norm_layer(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = activation_layer() + + self.norm_first = norm_first + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos_embedding: Optional[Tensor] = None + ): + x = src + if self.norm_first: + x = self.norm1(x) + if pos_embedding is not None: + q = k = x + pos_embedding + else: + q = k = x + x = src + self.dropout1(self.self_attn(q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]) + if not self.norm_first: + x = self.norm1(x) + if self.norm_first: + x = x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x)))))) + else: + x = self.norm2(x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(x)))))) + return x + + +class TransformerEncoder(nn.Module): + """Transformer Encoder""" + + def __init__(self, encoder_layer: nn.Module, num_layers: int , norm: Optional[nn.Module] = None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos_embed: Optional[Tensor] = None + ): + output = src + for layer in self.layers: + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + pos_embedding=pos_embed) + if self.norm is not None: + output = self.norm(output) + return output + + +class TransformerDecoder(nn.Module): + """Transformer Decoder""" + + def __init__( + self, + decoder_layer: nn.Module, + num_layers: int, + return_intermediate: bool = False, + norm: Optional[nn.Module] = None + ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos_embed: Optional[Tensor] = None, + query_pos_embed: Optional[Tensor] = None + ): + output = tgt + + intermediate = [] + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos_embed, query_pos=query_pos_embed) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerDecoderLayer(nn.Module): + """Transformer Decoder Layer""" + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + norm_first: bool = False, + activation_layer: Callable[..., nn.Module] = nn.ReLU, + norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = norm_layer(d_model) + self.norm2 = norm_layer(d_model) + self.norm3 = norm_layer(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = activation_layer() + self.norm_first = norm_first + + def _with_pos_embed(x: Tensor, pos_embed: Tensor = None): + return x + pos_embed if pos_embed is not None else x + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos_embed: Optional[Tensor] = None, + query_pos_embed: Optional[Tensor] = None + ): + x = tgt + if self.norm_first: + x = self.norm1(x) + q = k = self._with_pos_embed(x, query_pos_embed) + x = tgt + self.dropout1(self.self_attn(q, k, value=x, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0]) + if not self.norm_first: + x = self.norm1(x) + if self.norm_first: + x = x + self.dropout2(self.multihead_attn( + query=self.with_pos_embed(self.norm2(x), query_pos_embed), + key=self.with_pos_embed(memory, pos_embed), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0]) + x = x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(self.norm3(x)))))) + else: + x = self.norm2(x + self.dropout2(self.multihead_attn( + query=self.with_pos_embed(x, query_pos_embed), + key=self.with_pos_embed(memory, pos_embed), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0])) + x = self.norm3(x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(x)))))) + + +class Transformer(nn.Module): + """Transformer""" + + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + norm_first: bool = False, + return_intermediate_dec: bool = False, + activation_layer: Callable[..., nn.Module] = nn.ReLU, + norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm + ): + super().__init__() + self.d_model = d_model + self.nhead = nhead + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, norm_first, activation_layer, norm_layer) + encoder_norm = norm_layer(d_model) if norm_first else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, norm_first, activation_layer, norm_layer) + decoder_norm = norm_layer(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src: Tensor, mask: Tensor, query_embed: Tensor, pos_embed: Tensor): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) class DETR: From 73ca0e8304d244756919201f48ce5ce58808b0cb Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 8 May 2022 22:27:58 +0800 Subject: [PATCH 3/8] fix lint --- torchvision/models/detection/detr.py | 104 ++++++++++++++++----------- 1 file changed, 61 insertions(+), 43 deletions(-) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index b659afcc6b6..6aebd6c03eb 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -1,5 +1,5 @@ import copy -from typing import Optional, List, Callable +from typing import Optional, Callable import torch from torch import nn, Tensor @@ -10,8 +10,8 @@ def _get_clones(module, N): class TransformerEncoderLayer(nn.Module): - """"Transormer Encoder Layer""" - + """Transormer Encoder Layer""" + def __init__( self, d_model: int, @@ -34,7 +34,7 @@ def __init__( self.dropout2 = nn.Dropout(dropout) self.activation = activation_layer() - + self.norm_first = norm_first def forward( @@ -42,20 +42,22 @@ def forward( src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - pos_embedding: Optional[Tensor] = None + pos_embed: Optional[Tensor] = None, ): x = src if self.norm_first: x = self.norm1(x) if pos_embedding is not None: - q = k = x + pos_embedding + q = k = x + pos_embed else: q = k = x x = src + self.dropout1(self.self_attn(q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]) if not self.norm_first: x = self.norm1(x) if self.norm_first: - x = x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x)))))) + x = x + self.dropout2( + self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x))))) + ) else: x = self.norm2(x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(x)))))) return x @@ -63,8 +65,8 @@ def forward( class TransformerEncoder(nn.Module): """Transformer Encoder""" - - def __init__(self, encoder_layer: nn.Module, num_layers: int , norm: Optional[nn.Module] = None): + + def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers @@ -75,12 +77,11 @@ def forward( src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - pos_embed: Optional[Tensor] = None + pos_embed: Optional[Tensor] = None, ): output = src for layer in self.layers: - output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - pos_embedding=pos_embed) + output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos_embed=pos_embed) if self.norm is not None: output = self.norm(output) return output @@ -88,13 +89,13 @@ def forward( class TransformerDecoder(nn.Module): """Transformer Decoder""" - + def __init__( self, decoder_layer: nn.Module, num_layers: int, return_intermediate: bool = False, - norm: Optional[nn.Module] = None + norm: Optional[nn.Module] = None, ): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) @@ -111,17 +112,22 @@ def forward( tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, - query_pos_embed: Optional[Tensor] = None + query_pos_embed: Optional[Tensor] = None, ): output = tgt intermediate = [] for layer in self.layers: - output = layer(output, memory, tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - pos=pos_embed, query_pos=query_pos_embed) + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos_embed=pos_embed, + query_pos=query_pos_embed + ) if self.return_intermediate: intermediate.append(self.norm(output)) @@ -167,7 +173,7 @@ def __init__( self.activation = activation_layer() self.norm_first = norm_first - + def _with_pos_embed(x: Tensor, pos_embed: Tensor = None): return x + pos_embed if pos_embed is not None else x @@ -180,35 +186,46 @@ def forward( tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, - query_pos_embed: Optional[Tensor] = None + query_pos_embed: Optional[Tensor] = None, ): x = tgt if self.norm_first: x = self.norm1(x) q = k = self._with_pos_embed(x, query_pos_embed) - x = tgt + self.dropout1(self.self_attn(q, k, value=x, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0]) + x = tgt + self.dropout1( + self.self_attn(q, k, value=x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + ) if not self.norm_first: x = self.norm1(x) if self.norm_first: - x = x + self.dropout2(self.multihead_attn( - query=self.with_pos_embed(self.norm2(x), query_pos_embed), - key=self.with_pos_embed(memory, pos_embed), - value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask)[0]) + x = x + self.dropout2( + self.multihead_attn( + query=self.with_pos_embed(self.norm2(x), query_pos_embed), + key=self.with_pos_embed(memory, pos_embed), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask + )[0] + ) x = x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(self.norm3(x)))))) else: - x = self.norm2(x + self.dropout2(self.multihead_attn( - query=self.with_pos_embed(x, query_pos_embed), - key=self.with_pos_embed(memory, pos_embed), - value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask)[0])) + x = self.norm2( + x + + self.dropout2( + self.multihead_attn( + query=self.with_pos_embed(x, query_pos_embed), + key=self.with_pos_embed(memory, pos_embed), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask + )[0] + ) + ) x = self.norm3(x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(x)))))) class Transformer(nn.Module): """Transformer""" - + def __init__( self, d_model: int = 512, @@ -220,19 +237,21 @@ def __init__( norm_first: bool = False, return_intermediate_dec: bool = False, activation_layer: Callable[..., nn.Module] = nn.ReLU, - norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm + norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, ): - super().__init__() + super().__init__() self.d_model = d_model self.nhead = nhead - encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, - dropout, norm_first, activation_layer, norm_layer) + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, norm_first, activation_layer, norm_layer + ) encoder_norm = norm_layer(d_model) if norm_first else None self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) - decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, - dropout, norm_first, activation_layer, norm_layer) + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, norm_first, activation_layer, norm_layer + ) decoder_norm = norm_layer(d_model) self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm) @@ -250,8 +269,7 @@ def forward(self, src: Tensor, mask: Tensor, query_embed: Tensor, pos_embed: Ten tgt = torch.zeros_like(query_embed) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, - pos=pos_embed, query_pos=query_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) From f642bee25ee0c11d5939ba709ec38ac3ffdd946a Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 8 May 2022 22:43:49 +0800 Subject: [PATCH 4/8] fix lint --- torchvision/models/detection/detr.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index 6aebd6c03eb..16ccf049e5d 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -51,13 +51,13 @@ def forward( q = k = x + pos_embed else: q = k = x - x = src + self.dropout1(self.self_attn(q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]) + x = src + self.dropout1( + self.self_attn(q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + ) if not self.norm_first: x = self.norm1(x) if self.norm_first: - x = x + self.dropout2( - self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x))))) - ) + x = x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x)))))) else: x = self.norm2(x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(x)))))) return x @@ -126,7 +126,7 @@ def forward( tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos_embed=pos_embed, - query_pos=query_pos_embed + query_pos=query_pos_embed, ) if self.return_intermediate: intermediate.append(self.norm(output)) @@ -202,8 +202,9 @@ def forward( self.multihead_attn( query=self.with_pos_embed(self.norm2(x), query_pos_embed), key=self.with_pos_embed(memory, pos_embed), - value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, )[0] ) x = x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(self.norm3(x)))))) @@ -216,12 +217,12 @@ def forward( key=self.with_pos_embed(memory, pos_embed), value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask + key_padding_mask=memory_key_padding_mask, )[0] ) ) x = self.norm3(x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(x)))))) - + class Transformer(nn.Module): """Transformer""" @@ -259,17 +260,17 @@ def __init__( if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, src: Tensor, mask: Tensor, query_embed: Tensor, pos_embed: Tensor): + def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: Tensor): # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + query_pos_embed = query_pos_embed.unsqueeze(1).repeat(1, bs, 1) mask = mask.flatten(1) tgt = torch.zeros_like(query_embed) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_pos_embed) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) From daae9ee837dbad267966d3cd0b08fcc0ac2f303b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 8 May 2022 22:52:59 +0800 Subject: [PATCH 5/8] fix bug --- torchvision/models/detection/detr.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index 16ccf049e5d..2b020d5b585 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -174,7 +174,7 @@ def __init__( self.activation = activation_layer() self.norm_first = norm_first - def _with_pos_embed(x: Tensor, pos_embed: Tensor = None): + def _with_pos_embed(self, x: Tensor, pos_embed: Tensor = None): return x + pos_embed if pos_embed is not None else x def forward( @@ -200,8 +200,8 @@ def forward( if self.norm_first: x = x + self.dropout2( self.multihead_attn( - query=self.with_pos_embed(self.norm2(x), query_pos_embed), - key=self.with_pos_embed(memory, pos_embed), + query=self._with_pos_embed(self.norm2(x), query_pos_embed), + key=self._with_pos_embed(memory, pos_embed), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, @@ -213,8 +213,8 @@ def forward( x + self.dropout2( self.multihead_attn( - query=self.with_pos_embed(x, query_pos_embed), - key=self.with_pos_embed(memory, pos_embed), + query=self._with_pos_embed(x, query_pos_embed), + key=self._with_pos_embed(memory, pos_embed), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask, @@ -268,7 +268,7 @@ def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: query_pos_embed = query_pos_embed.unsqueeze(1).repeat(1, bs, 1) mask = mask.flatten(1) - tgt = torch.zeros_like(query_embed) + tgt = torch.zeros_like(query_embed) # TODO: torch.fx memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_pos_embed) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) From 219ef477889737a1a2b14a33486ec6fe272f11e5 Mon Sep 17 00:00:00 2001 From: Song Lin <865626@163.com> Date: Mon, 9 May 2022 22:44:52 +0800 Subject: [PATCH 6/8] Add PositionEmbedding --- torchvision/models/detection/detr.py | 69 ++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index 2b020d5b585..fee7e85577e 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -1,10 +1,12 @@ import copy from typing import Optional, Callable +import math import torch from torch import nn, Tensor + def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) @@ -276,3 +278,70 @@ def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: class DETR: pass + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensors,mask): + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=tensors.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensors,mask=None): + h, w = tensors.shape[-2:] + i = torch.arange(w, device=tensors.device) + j = torch.arange(h, device=tensors.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(tensors.shape[0], 1, 1, 1) + return pos + + +if __name__ == '__main__': From de3dfb45b92ca1a3b5f71df6ffdf0cf7dd7a4f3c Mon Sep 17 00:00:00 2001 From: tripleMu <865626@163.com> Date: Fri, 13 May 2022 20:30:59 +0800 Subject: [PATCH 7/8] Add type hints and fix some wrong type --- torchvision/models/detection/detr.py | 58 ++++++++++++++++------------ 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index fee7e85577e..82a3720a570 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -1,13 +1,12 @@ import copy -from typing import Optional, Callable - import math +from typing import Optional, Callable, Tuple + import torch from torch import nn, Tensor - -def _get_clones(module, N): +def _get_clones(module: nn.Module, N: int) -> nn.Module: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) @@ -22,7 +21,7 @@ def __init__( dropout: float = 0.1, norm_first: bool = False, activation_layer: Callable[..., nn.Module] = nn.ReLU, - norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) @@ -45,7 +44,7 @@ def forward( src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, - ): + ) -> Tensor: x = src if self.norm_first: x = self.norm1(x) @@ -80,7 +79,7 @@ def forward( mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, - ): + ) -> Tensor: output = src for layer in self.layers: output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos_embed=pos_embed) @@ -115,7 +114,7 @@ def forward( memory_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, query_pos_embed: Optional[Tensor] = None, - ): + ) -> Tensor: output = tgt intermediate = [] @@ -156,7 +155,7 @@ def __init__( dropout: float = 0.1, norm_first: bool = False, activation_layer: Callable[..., nn.Module] = nn.ReLU, - norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) @@ -176,7 +175,7 @@ def __init__( self.activation = activation_layer() self.norm_first = norm_first - def _with_pos_embed(self, x: Tensor, pos_embed: Tensor = None): + def _with_pos_embed(self, x: Tensor, pos_embed: Optional[Tensor] = None) -> Tensor: return x + pos_embed if pos_embed is not None else x def forward( @@ -189,7 +188,7 @@ def forward( memory_key_padding_mask: Optional[Tensor] = None, pos_embed: Optional[Tensor] = None, query_pos_embed: Optional[Tensor] = None, - ): + ) -> None: x = tgt if self.norm_first: x = self.norm1(x) @@ -240,7 +239,7 @@ def __init__( norm_first: bool = False, return_intermediate_dec: bool = False, activation_layer: Callable[..., nn.Module] = nn.ReLU, - norm_layer: Callable[..., torch.nn.Module] = nn.LayerNorm, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ): super().__init__() self.d_model = d_model @@ -262,7 +261,7 @@ def __init__( if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: Tensor): + def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: Tensor) -> Tuple[Tensor, Tensor]: # flatten NxCxHxW to HWxNxC bs, c, h, w = src.shape src = src.flatten(2).permute(2, 0, 1) @@ -270,7 +269,7 @@ def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: query_pos_embed = query_pos_embed.unsqueeze(1).repeat(1, bs, 1) mask = mask.flatten(1) - tgt = torch.zeros_like(query_embed) # TODO: torch.fx + tgt = torch.zeros_like(query_embed) # TODO: torch.fx memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_pos_embed) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) @@ -285,7 +284,10 @@ class PositionEmbeddingSine(nn.Module): This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature @@ -296,7 +298,7 @@ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=N scale = 2 * math.pi self.scale = scale - def forward(self, tensors,mask): + def forward(self, tensors: Tensor, mask: Tensor) -> Tensor: assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) @@ -321,7 +323,8 @@ class PositionEmbeddingLearned(nn.Module): """ Absolute pos embedding, learned. """ - def __init__(self, num_pos_feats=256): + + def __init__(self, num_pos_feats: int = 256): super().__init__() self.row_embed = nn.Embedding(50, num_pos_feats) self.col_embed = nn.Embedding(50, num_pos_feats) @@ -331,17 +334,22 @@ def reset_parameters(self): nn.init.uniform_(self.row_embed.weight) nn.init.uniform_(self.col_embed.weight) - def forward(self, tensors,mask=None): + def forward(self, tensors: Tensor, mask: Optional[Tensor] = None) -> Tensor: h, w = tensors.shape[-2:] i = torch.arange(w, device=tensors.device) j = torch.arange(h, device=tensors.device) x_emb = self.col_embed(i) y_emb = self.row_embed(j) - pos = torch.cat([ - x_emb.unsqueeze(0).repeat(h, 1, 1), - y_emb.unsqueeze(1).repeat(1, w, 1), - ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(tensors.shape[0], 1, 1, 1) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(tensors.shape[0], 1, 1, 1) + ) return pos - - -if __name__ == '__main__': From 8de1dba59efe47afb45a8b9abbe048522a0b316b Mon Sep 17 00:00:00 2001 From: q3394101 <865626@163.com> Date: Tue, 17 May 2022 14:32:57 +0800 Subject: [PATCH 8/8] Fix error --- torchvision/models/detection/detr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/detr.py b/torchvision/models/detection/detr.py index 82a3720a570..ec99bd86095 100644 --- a/torchvision/models/detection/detr.py +++ b/torchvision/models/detection/detr.py @@ -6,7 +6,7 @@ from torch import nn, Tensor -def _get_clones(module: nn.Module, N: int) -> nn.Module: +def _get_clones(module: nn.Module, N: int) -> nn.ModuleList: return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) @@ -43,7 +43,7 @@ def forward( src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - pos_embed: Optional[Tensor] = None, + pos_embedding: Optional[Tensor] = None, ) -> Tensor: x = src if self.norm_first: @@ -269,7 +269,7 @@ def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: query_pos_embed = query_pos_embed.unsqueeze(1).repeat(1, bs, 1) mask = mask.flatten(1) - tgt = torch.zeros_like(query_embed) # TODO: torch.fx + tgt = torch.zeros_like(query_pos_embed) # TODO: torch.fx memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_pos_embed) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)