diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 2bd0d1878c..a0886dd898 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -14,7 +14,7 @@ ) from torchao.utils import TORCH_VERSION_AFTER_2_4, is_fbcode -class TestModel(nn.Module): +class ToyModel(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(128, 256, bias=False) @@ -36,7 +36,7 @@ def test_runtime_weight_sparsification(self): from torch.sparse import SparseSemiStructuredTensorCUSPARSELT input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = TestModel().half().cuda() + model = ToyModel().half().cuda() model_c = copy.deepcopy(model) for name, mod in model.named_modules(): @@ -77,7 +77,7 @@ def test_runtime_weight_sparsification_compile(self): from torch.sparse import SparseSemiStructuredTensorCUSPARSELT input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = TestModel().half().cuda() + model = ToyModel().half().cuda() model_c = copy.deepcopy(model) for name, mod in model.named_modules():