|
| 1 | +import math |
| 2 | +from functools import partial |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch import Tensor |
| 7 | + |
| 8 | +from ..utils import channel_bucketize |
| 9 | +from .proxmap import ProxMap |
| 10 | + |
| 11 | + |
def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
    """Return an AMP ``custom_fwd`` decorator compatible with old and new PyTorch.

    Newer PyTorch exposes ``torch.amp.custom_fwd`` (which requires a
    ``device_type``); older releases only provide the CUDA-specific
    ``torch.cuda.amp.custom_fwd``. Dispatch on availability so callers get a
    ready-to-use decorator either way.
    """
    amp_mod = getattr(torch, "amp", None)
    if amp_mod is not None and hasattr(amp_mod, "custom_fwd"):
        return partial(amp_mod.custom_fwd, device_type="cuda", cast_inputs=cast_inputs)
    return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)
| 19 | + |
| 20 | + |
def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
    """Decreasing sigmoid-shaped schedule over the half-open interval [t1, t2).

    'Mirror' because it decreases (a reflected sigmoid) and 'normalized'
    because the raw sigmoid values are rescaled so the result is exactly 1
    at t == t1 and tends to 0 as t -> t2. The steepness s controls curvature:
    the schedule is nearly linear for s < 1.
    """
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    progress = (t - t1) / (t2 - t1)  # position within [t1, t2) mapped to [0, 1)
    raw = 1 / (1 + math.exp(s * (progress - 0.5)))  # reflected sigmoid at t
    # Raw sigmoid values at the interval endpoints, used for rescaling.
    at_start = 1 / (1 + math.exp(-0.5 * s))  # value when progress == 0
    at_end = 1 / (1 + math.exp(0.5 * s))  # value when progress == 1
    return (raw - at_end) / (at_start - at_end)  # rescaled into (0, 1]
| 32 | + |
| 33 | + |
class ProxPARQ(ProxMap):
    """Proximal map for PARQ: anneals parameters from full precision to hard
    quantization on the grid Q over the step interval [anneal_start, anneal_end).
    """

    def __init__(
        self, anneal_start: int, anneal_end: int, steepness: float = 10
    ) -> None:
        # anneal_start/anneal_end: optimizer step counts bounding the
        # soft-to-hard quantization schedule; steepness shapes the
        # normalized_mirror_sigmoid used for annealing.
        assert anneal_start < anneal_end, "PARQ annealing: start before end."
        assert steepness > 0, "PARQ annealing steepness should be positive."
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end
        self.steepness = steepness

    @torch.no_grad()
    @amp_custom_fwd(cast_inputs=torch.float32)
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> float:
        """Prox-map of PARQ with gradual annealing to hard quantization.

        Updates ``p`` in place and returns the inverse slope of the
        piecewise-affine map (1.0 = identity/no quantization, 0.0 = hard
        quantization). ``Q`` holds the sorted quantization values; when
        ``dim`` is not None, quantization is per-channel and ``Q`` is
        presumably 2-D with one row of values per channel — TODO confirm
        against callers. ``q`` may carry a precomputed hard-quantized
        tensor once annealing has finished.
        """

        if step_count < self.anneal_start:
            # Before annealing starts: leave p untouched (identity prox).
            inv_slope = 1.0
        elif step_count >= self.anneal_end:
            # Annealing finished: snap p to the nearest grid point.
            inv_slope = 0.0
            if q is None:
                # hard quantization to the nearest point in Q
                Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
                if dim is None:
                    q = Q[torch.bucketize(p, Q_mid)]
                else:
                    q = Q.gather(1, channel_bucketize(p, Q_mid))
            p.copy_(q)
        else:
            # Mid-annealing: inv_slope decreases from 1 toward 0, so the
            # prox map pulls p increasingly hard toward interval boundaries.
            inv_slope = normalized_mirror_sigmoid(
                step_count, self.anneal_start, self.anneal_end, self.steepness
            )
            # it is important to clamp idx-1 and then clamping idx itself
            # idx_1[k] == idx[k] iff p[k] > Q.max() or p[k] <= Q.min()
            if dim is None:
                idx = torch.bucketize(p, Q)  # locate quant interval
                idx_lower = (idx - 1).clamp_(min=0)  # index of lower bound
                idx_upper = idx.clamp(max=Q.numel() - 1)  # index of upper bound
                q_lower = Q[idx_lower]  # lower boundary of interval
                q_upper = Q[idx_upper]  # upper boundary of interval
                center = (q_lower + q_upper) / 2  # center of interval
                # concise implementation of piecewise-affine prox map
                q = (center + (p - center) / inv_slope).clamp_(min=q_lower, max=q_upper)
            else:
                # Per-channel variant: same piecewise-affine map, with
                # interval lookup done row-wise via channel_bucketize.
                idx = channel_bucketize(p, Q)
                idx_lower = (idx - 1).clamp_(min=0)
                idx_upper = idx.clamp(max=Q.size(1) - 1)
                q_lower = Q.gather(1, idx_lower)
                q_upper = Q.gather(1, idx_upper)
                center = (q_lower + q_upper) / 2
                q = torch.minimum(
                    torch.maximum(center + (p - center) / inv_slope, q_lower), q_upper
                )
            # in-place update of model parameters
            p.copy_(q)
        return inv_slope
0 commit comments