diff --git a/KernelBench/level1/92_cumsum_exclusive.py b/KernelBench/level1/92_cumsum_exclusive.py index 0c819da7..aebefb4d 100644 --- a/KernelBench/level1/92_cumsum_exclusive.py +++ b/KernelBench/level1/92_cumsum_exclusive.py @@ -14,7 +14,7 @@ def __init__(self, dim): self.dim = dim def forward(self, x): - exclusive_cumsum = torch.cat((torch.zeros_like(x.select(self.dim, 0).unsqueeze(self.dim)), x), dim=self.dim)[:-1] + exclusive_cumsum = torch.cat((torch.zeros_like(x.select(self.dim, 0).unsqueeze(self.dim)), x), dim=self.dim).narrow(self.dim, 0, x.size(dim)) return torch.cumsum(exclusive_cumsum, dim=self.dim) batch_size = 32768