Commit 3f93283

Merge pull request #155 from lcskrishna/cl/indentation

fixed flake8 issues

2 parents a75f8e4 + e68953f

File tree

5 files changed (+274, −224 lines)

test/common.py

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ def wrapper(*args, **kwargs):
         fn(*args, **kwargs)
     return wrapper
 
+
 def skipIfNoLapack(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
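
The single blank line added above is presumably a flake8 E302 fix ("expected 2 blank lines, got 1"): PEP 8 wants two blank lines before a top-level function definition. A minimal sketch of the rule, with hypothetical function names:

def first():
    return None


def second():  # two blank lines above this def satisfy E302
    return None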

test/common_nn.py

Lines changed: 21 additions & 23 deletions

@@ -41,7 +41,7 @@ def get_weight(m):
         constructor_args=(10, 8),
         input_size=(4, 10),
         reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Linear',
@@ -103,28 +103,28 @@ def get_weight(m):
         constructor_args=(1,),
         input_size=(10, 20),
         reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmax2d',
         input_size=(1, 3, 10, 20),
         reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, False)),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='LogSoftmax',
         constructor_args=(1,),
         input_size=(10, 20),
         reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='LogSoftmax',
         constructor_args=(1,),
         input_size=(1, 3, 10, 20),
         reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
         desc='multiparam',
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='ELU',
@@ -204,61 +204,59 @@ def get_weight(m):
         input_size=(2, 3, 4),
         desc='1d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
         input_size=(2, 3, 4, 5),
         desc='2d',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        #test_cuda = (not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
         constructor_args=(3,),
         input_size=(2, 3, 4, 5),
         desc='2d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
         input_size=(2, 3, 4, 5, 6),
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
         desc='3d',
-        #test_cuda = (not TEST_WITH_ROCM)
     ),
     dict(
         module_name='PReLU',
         constructor_args=(3,),
         input_size=(2, 3, 4, 5, 6),
         desc='3d_multiparam',
         reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softsign',
         input_size=(3, 2, 5),
         reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmin',
         constructor_args=(1,),
         input_size=(10, 20),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Softmin',
         constructor_args=(1,),
         input_size=(2, 3, 5, 10),
         desc='multidim',
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='Tanhshrink',
         input_size=(2, 3, 4, 5),
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
 ]

@@ -591,7 +589,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
             (i.numel() if get_reduction(m) else 1),
         check_gradgrad=False,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='BCELoss',
@@ -602,7 +600,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             (i.numel() if get_reduction(m) else 1),
         desc='weights',
         check_gradgrad=False,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='CrossEntropyLoss',
@@ -623,7 +621,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='HingeEmbeddingLoss',
@@ -634,7 +632,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MultiLabelMarginLoss',
@@ -661,7 +659,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
         reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
         check_gradgrad=False,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MultiMarginLoss',
@@ -740,7 +738,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='CosineEmbeddingLoss',
@@ -751,7 +749,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
         desc='margin',
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MarginRankingLoss',
@@ -760,7 +758,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MarginRankingLoss',
@@ -771,7 +769,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
             marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
         check_sum_reduction=True,
-        test_cuda = (not TEST_WITH_ROCM)
+        test_cuda=(not TEST_WITH_ROCM)
     ),
 ]
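
All of the common_nn.py edits fix the same flake8 complaint, presumably E251 ("unexpected spaces around keyword / parameter equals"): inside a call such as dict(...), keyword arguments are written name=value with no spaces around the '=', while a plain assignment statement keeps them. The two commented-out #test_cuda lines in the PReLU '2d' and '3d' entries are dropped outright rather than reformatted, which accounts for the two extra deletions (23 removed vs. 21 added). A minimal runnable sketch of the spacing rule; TEST_WITH_ROCM is defined locally here purely for illustration:

TEST_WITH_ROCM = False  # plain assignment: spaces around '=' are correct per PEP 8

entry = dict(
    module_name='Linear',
    test_cuda=(not TEST_WITH_ROCM),  # keyword argument: no spaces around '=' (E251)
)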
