Skip to content

Replace torch.norm with torch.linalg.vector_norm for PyTorch future update #2660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/prototype/test_blockwise_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
x = torch.randn(N, K).cuda()
qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
error = torch.norm(x - x_reconstructed) / torch.norm(x)
error = torch.linalg.vector_norm(x - x_reconstructed) / torch.linalg.vector_norm(x)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.1, "Quant-Dequant error is too high"
Expand All @@ -66,7 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
error = torch.norm(C - C_q) / torch.norm(C)
error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.1, "Quantize gemm error is too high"
4 changes: 3 additions & 1 deletion test/prototype/test_quantized_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap):

def snr(ref, actual):
error = actual - ref
return 20 * torch.log10(ref.norm() / error.norm())
return 20 * torch.log10(
torch.linalg.vector_norm(ref) / torch.linalg.vector_norm(error)
)

assert snr(outputs_ref, outputs_int8mp) > 20
assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
Expand Down
4 changes: 2 additions & 2 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x: The original tensor.
y: The tensor to compare to the original tensor.
"""
Ps = torch.norm(x)
Pn = torch.norm(x - y)
Ps = torch.linalg.vector_norm(x)
Pn = torch.linalg.vector_norm(x - y)
return 20 * torch.log10(Ps / Pn)


Expand Down
10 changes: 5 additions & 5 deletions torchao/prototype/parq/quant/lsbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool =
r = r.sub(v * binary_sign(r))

# compute least squares error, then select the `v` minimizes it
costs = r.norm(dim=dim)
costs = torch.linalg.vector_norm(r, dim=dim)
indices = costs.argmin(dim=dim, keepdim=True)
v_best = v_cands.gather(1, indices)
return v_best
Expand Down Expand Up @@ -196,10 +196,10 @@ def quantize_optimal_2bits(
V1V2.append((v1, v2))
assert len(V1V2) > 0, "LSBQ 2-bit optimal: No solution found."
# find the best solution with least-square quantization error
min_error = p.norm()
min_error = torch.linalg.vector_norm(p)
for v1v2 in V1V2:
r = binary_quant_residue(p, v1v2)
error = r.norm()
error = torch.linalg.vector_norm(r)
if error < min_error:
min_error = error
q = p - r
Expand Down Expand Up @@ -244,14 +244,14 @@ def quantize_optimal_ternary(
v_feasible.append(v)
assert len(v_feasible) > 0, "LSBQ ternary optimal: No solution found."
# find the best solution with least-square quantization error
min_error = p.norm()
min_error = torch.linalg.vector_norm(p)
q_best = torch.zeros_like(p)
v_best = torch.zeros_like(v)
for v in v_feasible:
Q = v * torch.tensor([-1.0, 0.0, 1.0], device=p.device)
boundaries = v * torch.tensor([-0.5, 0.5], device=p.device)
q = Q[torch.bucketize(p, boundaries)]
error = torch.linalg.norm(p - q)
error = torch.linalg.vector_norm(p - q)
if error < min_error:
min_error = error
q_best = q
Expand Down
10 changes: 6 additions & 4 deletions torchao/prototype/quantization/codebook/codebook_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def choose_qparams_codebook(
dim=(-1), keepdim=True
).values # Shape: [*input_size[:-1], num_scale_blocks, 1]
else:
scales = input.norm(
dim=(-1), keepdim=True
scales = torch.linalg.vector_norm(
input, dim=-1, keepdim=True
) # Shape: [*input_size[:-1], num_scale_blocks, 1]
scales = torch.clamp(scales, min=1e-9)

Expand Down Expand Up @@ -228,12 +228,14 @@ def _kmeans_greedy_init(data: torch.Tensor, k: int) -> torch.Tensor:
running_min_distances = torch.full(
(data.shape[0],), torch.inf, device=data.device, dtype=data.dtype
)
data_norm_squared = data.norm(p=2, dim=1).square()
data_norm_squared = torch.linalg.vector_norm(data, dim=1).square()

for i in range(k):
clusters[i] = data[running_min_distances.argmax()]
distances_to_cluster_i = (
data_norm_squared - 2 * data @ clusters[i] + clusters[i].norm().square()
data_norm_squared
- 2 * data @ clusters[i]
+ torch.linalg.vector_norm(clusters[i]).square()
)
running_min_distances = torch.minimum(
running_min_distances, distances_to_cluster_i, out=running_min_distances
Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def update_mask(self, module, tensor_name, **kwargs):
)
# take norm over all but first dim
dims = tuple(range(1, weights.dim()))
saliency = weights.norm(dim=dims, p=1)
saliency = torch.linalg.vector_norm(weights, dim=dims, ord=1)

# handle weights in 4 groups
split_size = len(mask) // 4
Expand Down
6 changes: 5 additions & 1 deletion torchao/prototype/sparsity/pruner/saliency_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import torch

from .base_structured_sparsifier import BaseStructuredSparsifier


Expand All @@ -26,7 +28,9 @@ def update_mask(self, module, tensor_name, **kwargs):
raise Exception(
"Structured pruning can only be applied to a 2+dim weight tensor!"
)
saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
saliency = -torch.linalg.vector_norm(
weights, dim=tuple(range(1, weights.dim())), ord=1
)
assert saliency.shape == mask.shape

num_to_pick = int(len(mask) * kwargs["sparsity_level"])
Expand Down
2 changes: 1 addition & 1 deletion torchao/sparsity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, x_orig):
new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
y = x.permute(new_axis_list)
y = torch.flatten(y, start_dim=1)
norm = torch.norm(y, dim=1) ** 2
norm = torch.linalg.vector_norm(y, dim=1) ** 2

if self.norm.numel() == 0:
self.norm.resize_(norm.shape)
Expand Down