
Commit 0e8155a: "upload files"

1 parent 2ab46da, commit 0e8155a
8 files changed: 1624 additions, 0 deletions

Blocks.py

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
import torch
import torch.nn as nn


'''def conv_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model
'''


def conv(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv2d,
         BN=False, ws=False, activ=nn.LeakyReLU(0.2), gainWS=2):
    # WScaleLayer is only needed when ws=True; it is not defined in this file and is
    # assumed to be available from elsewhere in this repository.
    convlayer = layer(nin, nout, kernel_size, stride=stride, padding=padding, bias=bias)
    layers = []
    if ws:
        layers.append(WScaleLayer(convlayer, gain=gainWS))
    if BN:
        layers.append(nn.BatchNorm2d(nout))
    if activ is not None:
        if activ == nn.PReLU:
            # to avoid sharing the same parameter, activ must be set to nn.PReLU (without '()')
            layers.append(activ(num_parameters=1))
        else:
            # if activ == nn.PReLU(), the parameter will be shared for the whole network!
            layers.append(activ)
    # ws is a bool, so the conv layer is inserted at index int(ws):
    # position 0 when ws=False, position 1 (right after WScaleLayer) when ws=True.
    layers.insert(ws, convlayer)
    return nn.Sequential(*layers)

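
# Illustrative usage sketch (added for clarity, not part of the original file).
# It shows how `conv` composes Conv2d -> BatchNorm2d -> activation, and why the
# PReLU *class* (rather than an instance) is passed to avoid parameter sharing.
def _demo_conv():
    block = conv(3, 16, BN=True, activ=nn.PReLU)   # this block owns its own PReLU parameter
    x = torch.randn(2, 3, 32, 32)
    out = block(x)                                 # 3x3 conv, stride 1, padding 1 keeps H x W
    assert out.shape == (2, 16, 32, 32)
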
class ResidualConv(nn.Module):
    def __init__(self, nin, nout, bias=False, BN=False, ws=False, activ=nn.LeakyReLU(0.2)):
        super(ResidualConv, self).__init__()

        convs = [conv(nin, nout, bias=bias, BN=BN, ws=ws, activ=activ),
                 conv(nout, nout, bias=bias, BN=BN, ws=ws, activ=None)]
        self.convs = nn.Sequential(*convs)

        res = []
        if nin != nout:
            # 1x1 projection on the skip path when the channel counts differ
            res.append(conv(nin, nout, kernel_size=1, padding=0, bias=False, BN=BN, ws=ws, activ=None))
        self.res = nn.Sequential(*res)

        activation = []
        if activ is not None:
            if activ == nn.PReLU:
                # to avoid sharing the same parameter, activ must be set to nn.PReLU (without '()')
                activation.append(activ(num_parameters=1))
            else:
                # if activ == nn.PReLU(), the parameter will be shared for the whole network!
                activation.append(activ)
        self.activation = nn.Sequential(*activation)

    def forward(self, input):
        out = self.convs(input)
        return self.activation(out + self.res(input))


def upSampleConv_Res(nin, nout, upscale=2, bias=False, BN=False, ws=False, activ=nn.LeakyReLU(0.2)):
    return nn.Sequential(
        nn.Upsample(scale_factor=upscale),
        ResidualConv(nin, nout, bias=bias, BN=BN, ws=ws, activ=activ)
    )

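
# Illustrative usage sketch (added for clarity, not part of the original file).
# ResidualConv maps nin -> nout channels, projecting the skip path with a 1x1 conv when
# the channel counts differ; upSampleConv_Res upsamples (nearest neighbour by default)
# before applying the same residual block.
def _demo_residual_blocks():
    x = torch.randn(1, 32, 16, 16)
    res = ResidualConv(32, 64)
    assert res(x).shape == (1, 64, 16, 16)            # channels change, H x W preserved
    up = upSampleConv_Res(32, 64, upscale=2)
    assert up(x).shape == (1, 64, 32, 32)             # spatial size doubled, then residual conv
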

def conv_block(in_dim, out_dim, act_fn, kernel_size=3, stride=1, padding=1, dilation=1):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model


def conv_block_1(in_dim, out_dim):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=1),
        nn.BatchNorm2d(out_dim),
        nn.PReLU(),
    )
    return model


def conv_block_Asym(in_dim, out_dim, kernelSize):
    # Note: the hard-coded padding of 2 keeps H x W unchanged only for kernelSize == 5.
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([2, 0])),
        nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, 2])),
        nn.BatchNorm2d(out_dim),
        nn.PReLU(),
    )
    return model


def conv_block_Asym_Inception(in_dim, out_dim, kernel_size, padding, dilation=1):
    # Factorized k x 1 followed by 1 x k convolution; each conv is dilated along the axis
    # its kernel spans, matching the padding of padding * dilation applied on that axis.
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=[kernel_size, 1], padding=tuple([padding * dilation, 0]), dilation=(dilation, 1)),
        nn.BatchNorm2d(out_dim),
        nn.ReLU(),
        nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernel_size], padding=tuple([0, padding * dilation]), dilation=(1, dilation)),
        nn.BatchNorm2d(out_dim),
        nn.ReLU(),
    )
    return model


def conv_block_Asym_Inception_WithIncreasedFeatMaps(in_dim, mid_dim, out_dim, kernel_size, padding, dilation=1):
    model = nn.Sequential(
        nn.Conv2d(in_dim, mid_dim, kernel_size=[kernel_size, 1], padding=tuple([padding * dilation, 0]), dilation=(dilation, 1)),
        nn.BatchNorm2d(mid_dim),
        nn.ReLU(),
        nn.Conv2d(mid_dim, out_dim, kernel_size=[1, kernel_size], padding=tuple([0, padding * dilation]), dilation=(1, dilation)),
        nn.BatchNorm2d(out_dim),
        nn.ReLU(),
    )
    return model

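
# Illustrative usage sketch (added for clarity, not part of the original file).
# With padding = (kernel_size - 1) // 2, the factorized dilated convolutions above keep
# the spatial size unchanged, because the padding is scaled by the dilation factor.
def _demo_asym_inception():
    x = torch.randn(1, 16, 24, 24)
    block = conv_block_Asym_Inception(16, 16, kernel_size=3, padding=1, dilation=2)
    assert block(x).shape == (1, 16, 24, 24)
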

def conv_block_Asym_ERFNet(in_dim, out_dim, kernelSize, padding, drop, dilation):
    # Factorized convolutions as in ERFNet's non-bottleneck-1D block: a plain k x 1 / 1 x k
    # pair followed by a dilated pair. The third conv declares in_dim input channels but
    # receives out_dim channels, so the block only runs when in_dim == out_dim.
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([padding, 0]), bias=True),
        nn.ReLU(),
        nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, padding]), bias=True),
        nn.BatchNorm2d(out_dim, eps=1e-03),
        nn.ReLU(),
        nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([padding * dilation, 0]), bias=True, dilation=(dilation, 1)),
        nn.ReLU(),
        nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, padding * dilation]), bias=True, dilation=(1, dilation)),
        nn.BatchNorm2d(out_dim, eps=1e-03),
        nn.Dropout2d(drop),
    )
    return model


def conv_block_3_3(in_dim, out_dim):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_dim),
        nn.PReLU(),
    )
    return model


# TODO: Change order of block: BN + Activation + Conv
def conv_decod_block(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model


def dilation_conv_block(in_dim, out_dim, act_fn, stride_val, dil_val):
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=stride_val, padding=1, dilation=dil_val),
        nn.BatchNorm2d(out_dim),
        act_fn,
    )
    return model

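
# Illustrative usage sketch (added for clarity, not part of the original file).
# conv_decod_block uses a transposed convolution with stride=2 and output_padding=1,
# which exactly doubles the spatial size of its input.
def _demo_decoder_block():
    x = torch.randn(1, 64, 14, 14)
    decod = conv_decod_block(64, 32, nn.ReLU())
    assert decod(x).shape == (1, 32, 28, 28)
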

# Pooling helpers; the suffix indicates the downsampling factor (the stride).
def maxpool():
    pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    return pool


def avrgpool05():
    pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
    return pool


def avrgpool025():
    pool = nn.AvgPool2d(kernel_size=2, stride=4, padding=0)
    return pool


def avrgpool0125():
    pool = nn.AvgPool2d(kernel_size=2, stride=8, padding=0)
    return pool


def maxpool_1_4():
    pool = nn.MaxPool2d(kernel_size=2, stride=4, padding=0)
    return pool


def maxpool_1_8():
    pool = nn.MaxPool2d(kernel_size=2, stride=8, padding=0)
    return pool


def maxpool_1_16():
    pool = nn.MaxPool2d(kernel_size=2, stride=16, padding=0)
    return pool


def maxpool_1_32():
    pool = nn.MaxPool2d(kernel_size=2, stride=32, padding=0)
    return pool


def conv_block_3(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    )
    return model

def classificationNet(D_in):
    H = 400
    D_out = 1
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, int(H / 4)),
        torch.nn.ReLU(),
        torch.nn.Linear(int(H / 4), D_out)
    )

    return model
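
# Illustrative usage sketch (added for clarity, not part of the original file).
# classificationNet is a small MLP head: D_in -> 400 -> 100 -> 1 with ReLU in between,
# producing a single logit per sample.
def _demo_classification_net():
    net = classificationNet(D_in=512)
    feats = torch.randn(4, 512)          # e.g. pooled feature vectors from an encoder
    logits = net(feats)
    assert logits.shape == (4, 1)
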
