Skip to content

Commit d906910

Browse files
committed
elif
1 parent 15f4e1a commit d906910

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

references/classification/transforms.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
4040
"""
4141
if batch.ndim != 4:
4242
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
43-
elif target.ndim != 1:
43+
if target.ndim != 1:
4444
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
45-
elif not batch.is_floating_point():
45+
if not batch.is_floating_point():
4646
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
47-
elif target.dtype != torch.int64:
47+
if target.dtype != torch.int64:
4848
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
4949

5050
if not self.inplace:
@@ -116,11 +116,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
116116
"""
117117
if batch.ndim != 4:
118118
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
119-
elif target.ndim != 1:
119+
if target.ndim != 1:
120120
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
121-
elif not batch.is_floating_point():
121+
if not batch.is_floating_point():
122122
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
123-
elif target.dtype != torch.int64:
123+
if target.dtype != torch.int64:
124124
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
125125

126126
if not self.inplace:

0 commit comments

Comments
 (0)