Skip to content

Commit 7777565

Browse files
jiayixu64facebook-github-bot
authored andcommitted
Enable Serving Calibration metric to visualize segment calibration of serving traffic.
Summary: * Pull Request resolved: #2201 * Problem: * Calibration metrics will show losses of full data after data consolidation, where the Calibration of each serving traffic will not be visible. * Solution: * Enable segment Calibration visualization to plot the Calibration across examples for each serving task separately, instead of on total volume of the consolidated data. * Enable on APS, follow the implementation for PyPer D49698301. * Usage: Add `SERVING_CALIBRATION` and task indices in `rec_metrics` of the model config. {F1741788763} Differential Revision: D59296724 fbshipit-source-id: a92c12af915c728cb49d6bbcf7925359e255d9a6
1 parent 6bd5c4e commit 7777565

File tree

5 files changed

+190
-0
lines changed

5 files changed

+190
-0
lines changed

torchrec/metrics/metric_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from torchrec.metrics.recall_session import RecallSessionMetric
4949
from torchrec.metrics.scalar import ScalarMetric
5050
from torchrec.metrics.segmented_ne import SegmentedNEMetric
51+
from torchrec.metrics.serving_calibration import ServingCalibrationMetric
5152
from torchrec.metrics.serving_ne import ServingNEMetric
5253
from torchrec.metrics.throughput import ThroughputMetric
5354
from torchrec.metrics.tower_qps import TowerQPSMetric
@@ -78,6 +79,7 @@
7879
RecMetricEnum.PRECISION: PrecisionMetric,
7980
RecMetricEnum.RECALL: RecallMetric,
8081
RecMetricEnum.SERVING_NE: ServingNEMetric,
82+
RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric,
8183
}
8284

8385

torchrec/metrics/metrics_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class RecMetricEnum(RecMetricEnumBase):
4040
PRECISION = "precision"
4141
RECALL = "recall"
4242
SERVING_NE = "serving_ne"
43+
SERVING_CALIBRATION = "serving_calibration"
4344

4445

4546
@dataclass(unsafe_hash=True, eq=True)

torchrec/metrics/metrics_namespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class MetricName(MetricNameBase):
7070
RECALL = "recall"
7171

7272
SERVING_NE = "serving_ne"
73+
SERVING_CALIBRATION = "serving_calibration"
7374

7475

7576
class MetricNamespaceBase(StrValueMixin, Enum):
@@ -109,6 +110,7 @@ class MetricNamespace(MetricNamespaceBase):
109110
RECALL = "recall"
110111

111112
SERVING_NE = "serving_ne"
113+
SERVING_CALIBRATION = "serving_calibration"
112114

113115

114116
class MetricPrefix(StrValueMixin, Enum):
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import Any, cast, Dict, List, Optional, Type
9+
10+
import torch
11+
from torchrec.metrics.calibration import compute_calibration, get_calibration_states
12+
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
13+
from torchrec.metrics.rec_metric import (
14+
MetricComputationReport,
15+
RecMetric,
16+
RecMetricComputation,
17+
RecMetricException,
18+
)
19+
20+
CALIBRATION_NUM = "calibration_num"
21+
CALIBRATION_DENOM = "calibration_denom"
22+
NUM_EXAMPLES = "num_examples"
23+
24+
25+
class ServingCalibrationMetricComputation(RecMetricComputation):
26+
r"""
27+
This class implements the RecMetricComputation for Calibration, which is the
28+
ratio between the prediction and the labels (conversions).
29+
30+
The constructor arguments are defined in RecMetricComputation.
31+
See the docstring of RecMetricComputation for more detail.
32+
"""
33+
34+
def __init__(self, *args: Any, **kwargs: Any) -> None:
35+
super().__init__(*args, **kwargs)
36+
self._add_state(
37+
CALIBRATION_NUM,
38+
torch.zeros(self._n_tasks, dtype=torch.double),
39+
add_window_state=True,
40+
dist_reduce_fx="sum",
41+
persistent=True,
42+
)
43+
self._add_state(
44+
CALIBRATION_DENOM,
45+
torch.zeros(self._n_tasks, dtype=torch.double),
46+
add_window_state=True,
47+
dist_reduce_fx="sum",
48+
persistent=True,
49+
)
50+
self._add_state(
51+
NUM_EXAMPLES,
52+
torch.zeros(self._n_tasks, dtype=torch.long),
53+
add_window_state=False,
54+
dist_reduce_fx="sum",
55+
persistent=True,
56+
)
57+
58+
def update(
59+
self,
60+
*,
61+
predictions: Optional[torch.Tensor],
62+
labels: torch.Tensor,
63+
weights: Optional[torch.Tensor],
64+
**kwargs: Dict[str, Any],
65+
) -> None:
66+
if predictions is None or weights is None:
67+
raise RecMetricException(
68+
"Inputs 'predictions' and 'weights' should not be None for CalibrationMetricComputation update"
69+
)
70+
num_samples = predictions.shape[-1]
71+
for state_name, state_value in get_calibration_states(
72+
labels, predictions, weights
73+
).items():
74+
state = getattr(self, state_name)
75+
state += state_value
76+
self._aggregate_window_state(state_name, state_value, num_samples)
77+
78+
num_examples_delta = torch.count_nonzero(weights, dim=-1)
79+
state_num_examples = getattr(self, NUM_EXAMPLES)
80+
state_num_examples += num_examples_delta
81+
82+
def _compute(self) -> List[MetricComputationReport]:
83+
return [
84+
MetricComputationReport(
85+
name=MetricName.CALIBRATION,
86+
metric_prefix=MetricPrefix.LIFETIME,
87+
value=compute_calibration(
88+
cast(torch.Tensor, self.calibration_num),
89+
cast(torch.Tensor, self.calibration_denom),
90+
),
91+
),
92+
MetricComputationReport(
93+
name=MetricName.CALIBRATION,
94+
metric_prefix=MetricPrefix.WINDOW,
95+
value=compute_calibration(
96+
self.get_window_state(CALIBRATION_NUM),
97+
self.get_window_state(CALIBRATION_DENOM),
98+
),
99+
),
100+
MetricComputationReport(
101+
name=MetricName.TOTAL_EXAMPLES,
102+
metric_prefix=MetricPrefix.DEFAULT,
103+
value=cast(torch.Tensor, self.num_examples).detach(),
104+
),
105+
]
106+
107+
108+
class ServingCalibrationMetric(RecMetric):
109+
_namespace: MetricNamespace = MetricNamespace.SERVING_CALIBRATION
110+
_computation_class: Type[RecMetricComputation] = ServingCalibrationMetricComputation
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python3
2+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
3+
4+
# pyre-strict
5+
6+
import unittest
7+
from typing import Dict, Type
8+
9+
import torch
10+
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
11+
from torchrec.metrics.serving_calibration import ServingCalibrationMetric
12+
from torchrec.metrics.test_utils import (
13+
metric_test_helper,
14+
rec_metric_value_test_launcher,
15+
TestMetric,
16+
)
17+
18+
19+
class TestServingCalibrationMetric(TestMetric):
20+
@staticmethod
21+
def _get_states(
22+
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
23+
) -> Dict[str, torch.Tensor]:
24+
calibration_num = torch.sum(predictions * weights)
25+
calibration_denom = torch.sum(labels * weights)
26+
num_samples = torch.tensor(labels.size()[0]).double()
27+
return {
28+
"calibration_num": calibration_num,
29+
"calibration_denom": calibration_denom,
30+
"num_samples": num_samples,
31+
}
32+
33+
@staticmethod
34+
def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
35+
return torch.where(
36+
states["calibration_denom"] <= 0.0,
37+
0.0,
38+
states["calibration_num"] / states["calibration_denom"],
39+
).double()
40+
41+
42+
WORLD_SIZE = 4
43+
44+
45+
class ServingCalibrationMetricTest(unittest.TestCase):
46+
clazz: Type[RecMetric] = ServingCalibrationMetric
47+
task_name: str = "calibration"
48+
49+
def test_unfused_calibration(self) -> None:
50+
rec_metric_value_test_launcher(
51+
target_clazz=ServingCalibrationMetric,
52+
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
53+
test_clazz=TestServingCalibrationMetric,
54+
metric_name=ServingCalibrationMetricTest.task_name,
55+
task_names=["t1", "t2", "t3"],
56+
fused_update_limit=0,
57+
compute_on_all_ranks=False,
58+
should_validate_update=False,
59+
world_size=WORLD_SIZE,
60+
entry_point=metric_test_helper,
61+
)
62+
63+
def test_fused_calibration(self) -> None:
64+
rec_metric_value_test_launcher(
65+
target_clazz=ServingCalibrationMetric,
66+
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
67+
test_clazz=TestServingCalibrationMetric,
68+
metric_name=ServingCalibrationMetricTest.task_name,
69+
task_names=["t1", "t2", "t3"],
70+
fused_update_limit=0,
71+
compute_on_all_ranks=False,
72+
should_validate_update=False,
73+
world_size=WORLD_SIZE,
74+
entry_point=metric_test_helper,
75+
)

0 commit comments

Comments
 (0)