From 95362d7a23de81bc4e7d200bbbcad02a5377f2ef Mon Sep 17 00:00:00 2001
From: priyanshu-hawk
Date: Wed, 29 Mar 2023 20:40:10 +0530
Subject: [PATCH] Weight Init Update

---
 ML/Pytorch/GANs/2. DCGAN/model.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/ML/Pytorch/GANs/2. DCGAN/model.py b/ML/Pytorch/GANs/2. DCGAN/model.py
index 04b52d9d..2220acd7 100644
--- a/ML/Pytorch/GANs/2. DCGAN/model.py
+++ b/ML/Pytorch/GANs/2. DCGAN/model.py
@@ -15,7 +15,7 @@ def __init__(self, channels_img, features_d):
         super(Discriminator, self).__init__()
         self.disc = nn.Sequential(
             # input: N x channels_img x 64 x 64
-            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
+            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1, bias=False),
             nn.LeakyReLU(0.2),
             # _block(in_channels, out_channels, kernel_size, stride, padding)
             self._block(features_d, features_d * 2, 4, 2, 1),
@@ -36,8 +36,8 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
                 padding,
                 bias=False,
             ),
-            # nn.BatchNorm2d(out_channels),
-            nn.LeakyReLU(0.2),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
         )
 
     def forward(self, x):
@@ -54,7 +54,7 @@ def __init__(self, channels_noise, channels_img, features_g):
             self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
             self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
             nn.ConvTranspose2d(
-                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
+                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1, bias=False
             ),
             # Output: N x channels_img x 64 x 64
             nn.Tanh(),
@@ -70,7 +70,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
                 padding,
                 bias=False,
             ),
-            # nn.BatchNorm2d(out_channels),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(),
         )
 
@@ -78,11 +78,13 @@ def forward(self, x):
         return self.net(x)
 
 
-def initialize_weights(model):
-    # Initializes weights according to the DCGAN paper
-    for m in model.modules():
-        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
-            nn.init.normal_(m.weight.data, 0.0, 0.02)
+def initialize_weights(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:  # this will prevent the Discriminator from converging to 0
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
 
 
 def test():
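
Usage note (illustrative, not part of the patch): initialize_weights now receives a
single module rather than the whole model, so callers must route it through
nn.Module.apply, which invokes the function recursively on every submodule; calling
initialize_weights(gen) directly would match no classname and silently do nothing.
A minimal sketch of the assumed call pattern follows; the constructor values
(noise dim 100, 3 image channels, 64 base features) are placeholder assumptions,
not taken from this patch:

    from model import Discriminator, Generator, initialize_weights

    # Hyperparameter values here are illustrative assumptions; substitute
    # whatever the training script actually uses.
    gen = Generator(channels_noise=100, channels_img=3, features_g=64)
    disc = Discriminator(channels_img=3, features_d=64)

    # nn.Module.apply visits every submodule recursively, so each Conv,
    # ConvTranspose, and BatchNorm layer is handed to initialize_weights.
    gen.apply(initialize_weights)
    disc.apply(initialize_weights)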