Skip to content

Commit b8a19ed

Browse files
author
Juntian Liu
authored
Add SNR numerical comparator
Differential Revision: D77159515 Pull Request resolved: #11859
1 parent 0c12dcd commit b8a19ed

File tree

8 files changed

+128
-7
lines changed

8 files changed

+128
-7
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,4 +729,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
729729
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
730730
)
731731
input_tensor = input_tensor.detach().cpu().double()
732+
733+
# Convert NaN to 0.0
734+
if torch.isnan(input_tensor).any():
735+
input_tensor = torch.nan_to_num(input_tensor)
732736
return input_tensor

devtools/inspector/numerical_comparator/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,21 @@ python_library(
2727
],
2828
)
2929

30+
python_library(
31+
name = "snr_numerical_comparator",
32+
srcs = ["snr_numerical_comparator.py"],
33+
deps = [
34+
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
35+
"//executorch/devtools/inspector:inspector_utils",
36+
],
37+
)
38+
3039
python_library(
3140
name = "lib",
3241
srcs = ["__init__.py"],
3342
deps = [
3443
":l1_numerical_comparator",
3544
":mse_numerical_comparator",
45+
":snr_numerical_comparator",
3646
],
3747
)

devtools/inspector/numerical_comparator/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,9 @@
1313
MSEComparator,
1414
)
1515

16+
from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import (
17+
SNRComparator,
18+
)
19+
1620

17-
__all__ = ["L1Comparator", "MSEComparator"]
21+
__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"]

devtools/inspector/numerical_comparator/l1_numerical_comparator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ def compare(self, a: Any, b: Any) -> float:
1919

2020
t_a = convert_to_float_tensor(a)
2121
t_b = convert_to_float_tensor(b)
22-
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23-
t_a = torch.nan_to_num(t_a)
24-
t_b = torch.nan_to_num(t_b)
2522

2623
try:
2724
res = torch.abs(t_a - t_b).sum().item()

devtools/inspector/numerical_comparator/mse_numerical_comparator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ def compare(self, a: Any, b: Any) -> float:
1919

2020
t_a = convert_to_float_tensor(a)
2121
t_b = convert_to_float_tensor(b)
22-
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23-
t_a = torch.nan_to_num(t_a)
24-
t_b = torch.nan_to_num(t_b)
2522

2623
try:
2724
res = float(torch.mean(torch.square(t_a - t_b)))
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from typing import Any
9+
10+
import torch
11+
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
12+
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
13+
NumericalComparatorBase,
14+
)
15+
16+
17+
class SNRComparator(NumericalComparatorBase):
18+
def compare(self, a: Any, b: Any) -> float:
19+
"""
20+
Compare the Signal-to-Noise Ratio (SNR) between two inputs
21+
Formula: SNR = 10 * log10(original_power / error_power)
22+
"""
23+
24+
t_a = convert_to_float_tensor(a)
25+
t_b = convert_to_float_tensor(b)
26+
27+
# Calculate the signal power and noise power
28+
original_power = torch.mean(torch.pow(t_a, 2))
29+
try:
30+
error = t_a - t_b
31+
error_power = torch.mean(torch.pow(error, 2))
32+
except Exception as e:
33+
raise ValueError(
34+
f"Error computing SNR difference between tensors: {str(e)}"
35+
)
36+
37+
# Calculate SNR
38+
snr = 10 * torch.log10(original_power / error_power)
39+
return snr.item()

devtools/inspector/tests/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ python_unittest(
7070
],
7171
)
7272

73+
python_unittest(
74+
name = "snr_comparator_test",
75+
srcs = ["snr_comparator_test.py"],
76+
deps = [
77+
"//executorch/devtools/inspector/numerical_comparator:lib",
78+
],
79+
)
80+
7381
python_library(
7482
name = "inspector_test_utils",
7583
srcs = [
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
import unittest
9+
10+
import torch
11+
12+
from executorch.devtools.inspector.numerical_comparator import SNRComparator
13+
14+
15+
class TestSNRComparator(unittest.TestCase):
16+
snr_comparator = SNRComparator()
17+
18+
def test_identical_tensors(self):
19+
# identical tensors --> error_power == 0 --> SNR is inf
20+
a = torch.tensor([[10, 4], [3, 4]])
21+
b = torch.tensor([[10, 4], [3, 4]])
22+
result = self.snr_comparator.compare(a, b)
23+
self.assertTrue(math.isinf(result) and result > 0)
24+
25+
def test_scalar(self):
26+
# original_power == 1, error_power == 1 --> SNR = 10 * log10(1/1) = 0
27+
a = 1
28+
b = 2
29+
result = self.snr_comparator.compare(a, b)
30+
self.assertAlmostEqual(result, 0.0)
31+
32+
def test_with_nans_replaced_with_zero(self):
33+
a = torch.tensor([float("nan"), 1.0])
34+
b = torch.tensor([0.0, 1.0])
35+
result = self.snr_comparator.compare(a, b)
36+
self.assertTrue(math.isinf(result) and result > 0)
37+
38+
def test_shape_mismatch_raises_exception(self):
39+
a = torch.tensor([1, 2, -1])
40+
b = torch.tensor([1, 1, -3, 4])
41+
with self.assertRaises(ValueError):
42+
self.snr_comparator.compare(a, b)
43+
44+
def test_2D_tensors(self):
45+
# original_power = mean([16, 81, 36, 16]) = 37.25
46+
# error = a - b = [3, 7, 3, -1] squared = [9, 49, 9, 1] mean = 68/4 = 17.0
47+
# SNR = 10 * log10(37.25/17.0)
48+
a = torch.tensor([[4, 9], [6, 4]])
49+
b = torch.tensor([[1, 2], [3, 5]])
50+
expected = 10 * math.log10(37.25 / 17.0)
51+
result = self.snr_comparator.compare(a, b)
52+
self.assertAlmostEqual(result, expected)
53+
54+
def test_list_of_tensors(self):
55+
# original_power = mean(4, 16, 25, 4]) = 12.25
56+
# error = a - b = [1, 2, 2, -3] squared = [1, 4, 4, 9] mean = 18/4 = 4.5
57+
# SNR = 10 * log10(37.25/17.0)
58+
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
59+
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
60+
expected = 10 * math.log10(12.25 / 4.5)
61+
result = self.snr_comparator.compare(a, b)
62+
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)