diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py
index 1c79ed9b23..89f8cf869e 100644
--- a/test/prototype/test_blockwise_triton.py
+++ b/test/prototype/test_blockwise_triton.py
@@ -41,7 +41,7 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
     x = torch.randn(N, K).cuda()
     qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
     x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
-    error = torch.norm(x - x_reconstructed) / torch.norm(x)
+    error = torch.linalg.vector_norm(x - x_reconstructed) / torch.linalg.vector_norm(x)
     print(f"Relative Error: {error.item():.6f}")

     assert error < 0.1, "Quant-Dequant error is too high"
@@ -66,7 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
     A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
     B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
     C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
-    error = torch.norm(C - C_q) / torch.norm(C)
+    error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C)
     print(f"Relative Error: {error.item():.6f}")

     assert error < 0.1, "Quantize gemm error is too high"
diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py
index 264c70abb6..702e4572d0 100644
--- a/test/prototype/test_quantized_training.py
+++ b/test/prototype/test_quantized_training.py
@@ -213,7 +213,9 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap):

         def snr(ref, actual):
             error = actual - ref
-            return 20 * torch.log10(ref.norm() / error.norm())
+            return 20 * torch.log10(
+                torch.linalg.vector_norm(ref) / torch.linalg.vector_norm(error)
+            )

         assert snr(outputs_ref, outputs_int8mp) > 20
         assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 625fb29235..5cb93ac0a0 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -144,8 +144,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x: The original tensor.
         y: The tensor to compare to the original tensor.
     """
-    Ps = torch.norm(x)
-    Pn = torch.norm(x - y)
+    Ps = torch.linalg.vector_norm(x)
+    Pn = torch.linalg.vector_norm(x - y)
     return 20 * torch.log10(Ps / Pn)


diff --git a/torchao/prototype/parq/quant/lsbq.py b/torchao/prototype/parq/quant/lsbq.py
index 2d9f4e4c1e..0154f3c543 100644
--- a/torchao/prototype/parq/quant/lsbq.py
+++ b/torchao/prototype/parq/quant/lsbq.py
@@ -70,7 +70,7 @@ def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool =
         r = r.sub(v * binary_sign(r))

     # compute least squares error, then select the `v` minimizes it
-    costs = r.norm(dim=dim)
+    costs = torch.linalg.vector_norm(r, dim=dim)
     indices = costs.argmin(dim=dim, keepdim=True)
     v_best = v_cands.gather(1, indices)
     return v_best
@@ -196,10 +196,10 @@ def quantize_optimal_2bits(
             V1V2.append((v1, v2))
     assert len(V1V2) > 0, "LSBQ 2-bit optimal: No solution found."
     # find the best solution with least-square quantization error
-    min_error = p.norm()
+    min_error = torch.linalg.vector_norm(p)
     for v1v2 in V1V2:
         r = binary_quant_residue(p, v1v2)
-        error = r.norm()
+        error = torch.linalg.vector_norm(r)
         if error < min_error:
             min_error = error
             q = p - r
@@ -244,14 +244,14 @@ def quantize_optimal_ternary(
             v_feasible.append(v)
     assert len(v_feasible) > 0, "LSBQ ternary optimal: No solution found."
     # find the best solution with least-square quantization error
-    min_error = p.norm()
+    min_error = torch.linalg.vector_norm(p)
     q_best = torch.zeros_like(p)
     v_best = torch.zeros_like(v)
     for v in v_feasible:
         Q = v * torch.tensor([-1.0, 0.0, 1.0], device=p.device)
         boundaries = v * torch.tensor([-0.5, 0.5], device=p.device)
         q = Q[torch.bucketize(p, boundaries)]
-        error = torch.linalg.norm(p - q)
+        error = torch.linalg.vector_norm(p - q)
         if error < min_error:
             min_error = error
             q_best = q
diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py
index ca81ce0453..201dc30f27 100644
--- a/torchao/prototype/quantization/codebook/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook/codebook_ops.py
@@ -198,8 +198,8 @@ def choose_qparams_codebook(
             dim=(-1), keepdim=True
         ).values  # Shape: [*input_size[:-1], num_scale_blocks, 1]
     else:
-        scales = input.norm(
-            dim=(-1), keepdim=True
+        scales = torch.linalg.vector_norm(
+            input, dim=-1, keepdim=True
         )  # Shape: [*input_size[:-1], num_scale_blocks, 1]

     scales = torch.clamp(scales, min=1e-9)
@@ -228,12 +228,14 @@ def _kmeans_greedy_init(data: torch.Tensor, k: int) -> torch.Tensor:
     running_min_distances = torch.full(
         (data.shape[0],), torch.inf, device=data.device, dtype=data.dtype
     )
-    data_norm_squared = data.norm(p=2, dim=1).square()
+    data_norm_squared = torch.linalg.vector_norm(data, dim=1).square()

     for i in range(k):
         clusters[i] = data[running_min_distances.argmax()]
         distances_to_cluster_i = (
-            data_norm_squared - 2 * data @ clusters[i] + clusters[i].norm().square()
+            data_norm_squared
+            - 2 * data @ clusters[i]
+            + torch.linalg.vector_norm(clusters[i]).square()
         )
         running_min_distances = torch.minimum(
             running_min_distances, distances_to_cluster_i, out=running_min_distances
diff --git a/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py
index c61a00b8e1..df9ed7cf5e 100644
--- a/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py
+++ b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py
@@ -43,7 +43,7 @@ def update_mask(self, module, tensor_name, **kwargs):
                     )
                 # take norm over all but first dim
                 dims = tuple(range(1, weights.dim()))
-                saliency = weights.norm(dim=dims, p=1)
+                saliency = torch.linalg.vector_norm(weights, dim=dims, ord=1)

                 # handle weights in 4 groups
                 split_size = len(mask) // 4
diff --git a/torchao/prototype/sparsity/pruner/saliency_pruner.py b/torchao/prototype/sparsity/pruner/saliency_pruner.py
index 5021bfca0d..4619773313 100644
--- a/torchao/prototype/sparsity/pruner/saliency_pruner.py
+++ b/torchao/prototype/sparsity/pruner/saliency_pruner.py
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import torch
+
 from .base_structured_sparsifier import BaseStructuredSparsifier


@@ -26,7 +28,9 @@ def update_mask(self, module, tensor_name, **kwargs):
             raise Exception(
                 "Structured pruning can only be applied to a 2+dim weight tensor!"
             )
-        saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
+        saliency = -torch.linalg.vector_norm(
+            weights, dim=tuple(range(1, weights.dim())), ord=1
+        )
         assert saliency.shape == mask.shape

         num_to_pick = int(len(mask) * kwargs["sparsity_level"])
diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py
index 24c0808a02..916fff6cd4 100644
--- a/torchao/sparsity/utils.py
+++ b/torchao/sparsity/utils.py
@@ -80,7 +80,7 @@ def forward(self, x_orig):
         new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
         y = x.permute(new_axis_list)
         y = torch.flatten(y, start_dim=1)
-        norm = torch.norm(y, dim=1) ** 2
+        norm = torch.linalg.vector_norm(y, dim=1) ** 2

         if self.norm.numel() == 0:
             self.norm.resize_(norm.shape)
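A quick illustrative sanity check (not part of the patch; the tensor shape and variable name are chosen only for the example) that the new calls are numerically equivalent to the ones they replace: torch.norm defaults to the 2-norm over the flattened tensor, which is exactly torch.linalg.vector_norm's default, and the pruners' p=1 keyword maps to ord=1.

import torch

x = torch.randn(8, 16)

# full-tensor 2-norm: torch.norm's default equals the flattened vector 2-norm
torch.testing.assert_close(torch.norm(x), torch.linalg.vector_norm(x))

# per-row L1 norm, the pattern used by the saliency pruners (p=1 -> ord=1)
torch.testing.assert_close(
    x.norm(dim=1, p=1), torch.linalg.vector_norm(x, dim=1, ord=1)
)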