Add DETR model #5922

Draft: wants to merge 16 commits into main
355 changes: 355 additions & 0 deletions torchvision/models/detection/detr.py
@@ -0,0 +1,355 @@
import copy
import math
from typing import Optional, Callable, Tuple

import torch
from torch import nn, Tensor


def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class TransformerEncoderLayer(nn.Module):
"""Transormer Encoder Layer"""

def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
norm_first: bool = False,
activation_layer: Callable[..., nn.Module] = nn.ReLU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)

self.norm1 = norm_layer(d_model)
self.norm2 = norm_layer(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

self.activation = activation_layer()

self.norm_first = norm_first

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos_embed: Optional[Tensor] = None,
    ) -> Tensor:
        x = src
        if self.norm_first:
            x = self.norm1(x)
        # DETR adds the positional encoding to the queries and keys of every
        # attention layer, not only to the input of the first one.
        if pos_embed is not None:
            q = k = x + pos_embed
        else:
            q = k = x
        x = src + self.dropout1(
            self.self_attn(q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        )
if not self.norm_first:
x = self.norm1(x)
if self.norm_first:
x = x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(self.norm2(x))))))
else:
x = self.norm2(x + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(x))))))
return x


class TransformerEncoder(nn.Module):
"""Transformer Encoder"""

def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm

def forward(
self,
        src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos_embed: Optional[Tensor] = None,
) -> Tensor:
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos_embed=pos_embed)
if self.norm is not None:
output = self.norm(output)
return output


class TransformerDecoder(nn.Module):
"""Transformer Decoder"""

def __init__(
self,
decoder_layer: nn.Module,
num_layers: int,
return_intermediate: bool = False,
norm: Optional[nn.Module] = None,
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate

def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos_embed: Optional[Tensor] = None,
query_pos_embed: Optional[Tensor] = None,
) -> Tensor:
output = tgt

intermediate = []
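        # With return_intermediate=True, keep the (normed) output of every decoder
        # layer so DETR can supervise auxiliary prediction heads on each of them.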
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos_embed=pos_embed,
                query_pos_embed=query_pos_embed,
)
if self.return_intermediate:
intermediate.append(self.norm(output))

if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)

if self.return_intermediate:
return torch.stack(intermediate)

return output.unsqueeze(0)


class TransformerDecoderLayer(nn.Module):
"""Transformer Decoder Layer"""

def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
norm_first: bool = False,
activation_layer: Callable[..., nn.Module] = nn.ReLU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)

self.norm1 = norm_layer(d_model)
self.norm2 = norm_layer(d_model)
self.norm3 = norm_layer(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

self.activation = activation_layer()
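        # norm_first toggles the pre-norm layout (LayerNorm before each sublayer)
        # versus the post-norm layout used by the original DETR release.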
self.norm_first = norm_first

def _with_pos_embed(self, x: Tensor, pos_embed: Optional[Tensor] = None) -> Tensor:
return x + pos_embed if pos_embed is not None else x

def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos_embed: Optional[Tensor] = None,
query_pos_embed: Optional[Tensor] = None,
    ) -> Tensor:
x = tgt
if self.norm_first:
x = self.norm1(x)
q = k = self._with_pos_embed(x, query_pos_embed)
x = tgt + self.dropout1(
self.self_attn(q, k, value=x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
)
if not self.norm_first:
x = self.norm1(x)
if self.norm_first:
x = x + self.dropout2(
self.multihead_attn(
query=self._with_pos_embed(self.norm2(x), query_pos_embed),
key=self._with_pos_embed(memory, pos_embed),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
)
x = x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(self.norm3(x))))))
else:
x = self.norm2(
x
+ self.dropout2(
self.multihead_attn(
query=self._with_pos_embed(x, query_pos_embed),
key=self._with_pos_embed(memory, pos_embed),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
)
)
            x = self.norm3(x + self.dropout3(self.linear2(self.dropout(self.activation(self.linear1(x))))))
        return x


class Transformer(nn.Module):
"""Transformer"""

def __init__(
self,
d_model: int = 512,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dim_feedforward: int = 2048,
dropout: float = 0.1,
norm_first: bool = False,
return_intermediate_dec: bool = False,
activation_layer: Callable[..., nn.Module] = nn.ReLU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
):
super().__init__()
self.d_model = d_model
self.nhead = nhead

encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, norm_first, activation_layer, norm_layer
)
encoder_norm = norm_layer(d_model) if norm_first else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, norm_first, activation_layer, norm_layer
)
decoder_norm = norm_layer(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm)

for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

def forward(self, src: Tensor, mask: Tensor, query_pos_embed: Tensor, pos_embed: Tensor) -> Tuple[Tensor, Tensor]:
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_pos_embed = query_pos_embed.unsqueeze(1).repeat(1, bs, 1)
mask = mask.flatten(1)

tgt = torch.zeros_like(query_pos_embed) # TODO: torch.fx
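        # Object queries start as zeros; the learned query embedding enters through
        # the attention layers (added to q and k) rather than through the input itself.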
        memory = self.encoder(src, src_key_padding_mask=mask, pos_embed=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask, pos_embed=pos_embed, query_pos_embed=query_pos_embed
        )
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class DETR:
pass


class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""

def __init__(
self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale

def forward(self, tensors: Tensor, mask: Tensor) -> Tensor:
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
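        # Cumulative sums over the unpadded region give every pixel a (1-based)
        # row/column coordinate, so the encoding adapts to per-image padding.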
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=tensors.device)
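        # Each sin/cos channel pair shares a geometric frequency, as in
        # "Attention Is All You Need": temperature ** (2k / num_pos_feats).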
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
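        # Interleave sin/cos per pair, then concatenate the y and x halves into a
        # (N, 2 * num_pos_feats, H, W) map matching the feature layout.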
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos


class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""

def __init__(self, num_pos_feats: int = 256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()

def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)

def forward(self, tensors: Tensor, mask: Optional[Tensor] = None) -> Tensor:
h, w = tensors.shape[-2:]
i = torch.arange(w, device=tensors.device)
j = torch.arange(h, device=tensors.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
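        # Broadcast the column embeddings over rows and the row embeddings over
        # columns, then stack channel-wise into a (N, 2 * num_pos_feats, H, W) map.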
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(tensors.shape[0], 1, 1, 1)
)
return pos
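

# A minimal smoke test, not part of the proposed API: the sizes below are
# illustrative assumptions, chosen so that d_model matches the channel count of
# the positional map (2 * num_pos_feats for the sine embedding).
if __name__ == "__main__":
    d_model, num_queries, batch, h, w = 256, 100, 2, 16, 16

    transformer = Transformer(d_model=d_model, nhead=8, num_encoder_layers=2, num_decoder_layers=2)
    pos_embedding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)

    src = torch.rand(batch, d_model, h, w)  # stand-in for a backbone feature map (N, C, H, W)
    mask = torch.zeros(batch, h, w, dtype=torch.bool)  # True would mark padded pixels
    query_pos = torch.rand(num_queries, d_model)  # stand-in for nn.Embedding(num_queries, d_model).weight

    pos = pos_embedding(src, mask)  # (N, C, H, W) positional map
    hs, memory = transformer(src, mask, query_pos, pos)
    print(hs.shape, memory.shape)  # (1, N, num_queries, C) and (N, C, H, W)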