Skip to content

Commit 0c70c26

Browse files
authored
Merge pull request #129 from iotamudelta/master
dataloader tests
2 parents d74ee71 + beaf876 commit 0c70c26

File tree

1 file changed

+0
-7
lines changed

1 file changed

+0
-7
lines changed

test/test_dataloader.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,12 @@ def test_growing_dataset(self):
338338
self.assertEqual(len(dataloader_shuffle), 5)
339339

340340
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
341-
@skipIfRocm
342341
def test_sequential_pin_memory(self):
343342
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
344343
for input, target in loader:
345344
self.assertTrue(input.is_pinned())
346345
self.assertTrue(target.is_pinned())
347346

348-
@skipIfRocm
349347
def test_multiple_dataloaders(self):
350348
loader1_it = iter(DataLoader(self.dataset, num_workers=1))
351349
loader2_it = iter(DataLoader(self.dataset, num_workers=2))
@@ -446,7 +444,6 @@ def test_batch_sampler(self):
446444
self._test_batch_sampler(num_workers=4)
447445

448446
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
449-
@skipIfRocm
450447
def test_shuffle_pin_memory(self):
451448
loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
452449
for input, target in loader:
@@ -479,7 +476,6 @@ def test_error_workers(self):
479476

480477
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
481478
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
482-
@skipIfRocm
483479
def test_partial_workers(self):
484480
r"""Check that workers exit even if the iterator is not exhausted."""
485481
for pin_memory in (True, False):
@@ -652,7 +648,6 @@ def setUp(self):
652648
self.dataset = StringDataset()
653649

654650
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
655-
@skipIfRocm
656651
def test_shuffle_pin_memory(self):
657652
loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
658653
for batch_ndx, (s, n) in enumerate(loader):
@@ -696,7 +691,6 @@ def test_sequential_batch(self):
696691
self.assertEqual(n[1], idx + 1)
697692

698693
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
699-
@skipIfRocm
700694
def test_pin_memory(self):
701695
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
702696
for batch_ndx, sample in enumerate(loader):
@@ -736,7 +730,6 @@ def _run_ind_worker_queue_test(self, batch_size, num_workers):
736730
if current_worker_idx == num_workers:
737731
current_worker_idx = 0
738732

739-
@skipIfRocm
740733
def test_ind_worker_queue(self):
741734
for batch_size in (8, 16, 32, 64):
742735
for num_workers in range(1, 6):

0 commit comments

Comments
 (0)