diff --git a/src/lib.rs b/src/lib.rs index 661b67b..45428f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -281,40 +281,87 @@ impl Complex { /// /// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`. #[inline] - pub fn sqrt(self) -> Self { - if self.im.is_zero() { - if self.re.is_sign_positive() { - // simple positive real √r, and copy `im` for its sign - Self::new(self.re.sqrt(), self.im) + pub fn sqrt(mut self) -> Self { + // complex sqrt algorithm based on the algorithm from + // dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks + // to increase accuracy. Compared to a naive implementationt that + // reuses the complex exp/ln implementations this algorithm has better + // accuarcy since both (real) sqrt and (real) hypot are garunteed to + // round perfectly. It's also faster since this implementation requires + // less transcendental functions and those it does use (sqrt/hypto) are + // faster comparted to exp/sin/cos. + // + // The musl libc implementation was referenced while implementing the + // algorithm here: + // https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c + + // TODO: rounding for very tiny subnormal numbers isn't perfect yet so + // the assert shown fails in the very worst case this leads to about + // 10% accuracy loss (see example below). As the magnitude increase the + // error quickly drops to basically zero. + // + // glibc handles that (but other implementations like musl and numpy do + // not) by upscaling very small values. That upscaling (and particularly + // it's reversal) are weird and hard to understand (and rely on mantissa + // bit size which we can't get out of the trait). In general the glibc + // implementation is ever so subtley different and I wouldn't want to + // introduce bugs by trying to adapt the underflow handling. + // + // assert_eq!( + // Complex64::new(5.212e-324, 5.212e-324).sqrt(), + // Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162) + // ); + + // specical cases for correct nan/inf handling + // see https://en.cppreference.com/w/c/numeric/complex/csqrt + + if self.re.is_zero() && self.im.is_zero() { + // 0 +/- 0 i + return Self::new(T::zero(), self.im); + } + if self.im.is_infinite() { + // inf +/- inf i + return Self::new(T::infinity(), self.im); + } + if self.re.is_nan() { + // nan + nan i + return Self::new(self.re, T::nan()); + } + if self.re.is_infinite() { + // √(inf +/- NaN i) = inf +/- NaN i + // √(inf +/- x i) = inf +/- 0 i + // √(-inf +/- NaN i) = NaN +/- inf i + // √(-inf +/- x i) = 0 +/- inf i + + // if im is inf (or nan) this is nan, otherwise it's zero + #[allow(clippy::eq_op)] + let zero_or_nan = self.im - self.im; + if self.re.is_sign_negative() { + return Self::new(zero_or_nan.abs(), self.re.copysign(self.im)); } else { - // √(r e^(iπ)) = √r e^(iπ/2) = i√r - // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r - let re = T::zero(); - let im = (-self.re).sqrt(); - if self.im.is_sign_positive() { - Self::new(re, im) - } else { - Self::new(re, -im) - } - } - } else if self.re.is_zero() { - // √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2) - // √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2) - let one = T::one(); - let two = one + one; - let x = (self.im.abs() / two).sqrt(); - if self.im.is_sign_positive() { - Self::new(x, x) - } else { - Self::new(x, -x) + return Self::new(self.re, zero_or_nan.copysign(self.im)); } + } + let two = T::one() + T::one(); + let four = two + two; + let overflow = T::max_value() / (T::one() + T::sqrt(two)); + let max_magnitude = self.re.abs().max(self.im.abs()); + let scale = max_magnitude >= overflow; + if scale { + self = self / four; + } + if self.re.is_sign_negative() { + let tmp = ((-self.re + self.norm()) / two).sqrt(); + self.re = self.im.abs() / (two * tmp); + self.im = tmp.copysign(self.im); } else { - // formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) - let one = T::one(); - let two = one + one; - let (r, theta) = self.to_polar(); - Self::from_polar(r.sqrt(), theta / two) + self.re = ((self.re + self.norm()) / two).sqrt(); + self.im = self.im / (two * self.re); + } + if scale { + self = self * two; } + self } /// Computes the principal value of the cube root of `self`. @@ -2065,6 +2112,164 @@ pub(crate) mod test { } } + #[test] + fn test_sqrt_nan() { + assert!(close_naninf( + Complex64::new(f64::INFINITY, f64::NAN).sqrt(), + Complex64::new(f64::INFINITY, f64::NAN), + )); + assert!(close_naninf( + Complex64::new(f64::NAN, f64::INFINITY).sqrt(), + Complex64::new(f64::INFINITY, f64::INFINITY), + )); + assert!(close_naninf( + Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(), + Complex64::new(f64::NAN, f64::NEG_INFINITY), + )); + assert!(close_naninf( + Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(), + Complex64::new(f64::NAN, f64::INFINITY), + )); + assert!(close_naninf( + Complex64::new(-0.0, 0.0).sqrt(), + Complex64::new(0.0, 0.0), + )); + for x in (-100..100).map(f64::from) { + assert!(close_naninf( + Complex64::new(x, f64::INFINITY).sqrt(), + Complex64::new(f64::INFINITY, f64::INFINITY), + )); + assert!(close_naninf( + Complex64::new(f64::NAN, x).sqrt(), + Complex64::new(f64::NAN, f64::NAN), + )); + // √(inf + x i) = inf + 0 i + assert!(close_naninf( + Complex64::new(f64::INFINITY, x).sqrt(), + Complex64::new(f64::INFINITY, 0.0.copysign(x)), + )); + // √(-inf + x i) = 0 + inf i + assert!(close_naninf( + Complex64::new(f64::NEG_INFINITY, x).sqrt(), + Complex64::new(0.0, f64::INFINITY.copysign(x)), + )); + } + } + + + fn test_sqrt_rounding() { + fn naive_sqrt(c: Complex64) -> Complex64 { + let (r, theta) = c.to_polar(); + Complex64::from_polar(r.sqrt(), theta / 2.0) + } + + fn ulp_l1(a: Complex64, b: Complex64) -> u64 { + let re_ulp = a.re.to_bits().abs_diff(b.re.to_bits()); + let im_ulp = a.im.to_bits().abs_diff(b.im.to_bits()); + re_ulp + im_ulp + } + fn close_to_ulp(a: Complex64, b: Complex64, ulp: usize) -> bool { + ulp_l1(a, b) <= ulp as u64 + } + + #[track_caller] + fn check_sqrt(re: f64, im: f64, exact_sqrt_re: f64, exact_sqrt_im: f64) { + let sqrt = Complex::new(re, im).sqrt(); + assert_eq!(sqrt, Complex::new(exact_sqrt_re, exact_sqrt_im)); + let naive_sqrt = naive_sqrt(Complex::new(re, im)); + assert_ne!(naive_sqrt, sqrt, "invalid testcase {re} {im}"); + let roundtrip = sqrt * sqrt; + let naive_roundtrip = naive_sqrt * naive_sqrt; + assert!( + ulp_l1(roundtrip, Complex::new(re, im)) + <= ulp_l1(naive_roundtrip, Complex::new(re, im)), + "{} {} {}", + Complex::new(re, im), + roundtrip, + naive_roundtrip + ) + } + + #[track_caller] + fn check_sqrt_roundtrip(re: f64, im: f64, ulp: usize) { + let sqrt = Complex::new(re, im).sqrt(); + let roundtrip = sqrt * sqrt; + assert!( + close_to_ulp(roundtrip, Complex::new(re, im), ulp), + "roundtrip failed for {re} + j{im}: {roundtrip}" + ); + let naive_sqrt = naive_sqrt(Complex::new(re, im)); + let naive_roundtrip = naive_sqrt * naive_sqrt; + assert!( + !close_to_ulp(naive_roundtrip, Complex::new(re, im), ulp), + "invalid testcase {re} + j{im} {naive_roundtrip} {roundtrip}" + ); + } + + // some hand-collected testcases that roundtrip perfectly with a + // sophisticated sqrt implementation but not a naive one .This can + // look a bit cherry picked (and it is) but during all my cherry + // picking i didn't find a single case which had worse rounding + check_sqrt_roundtrip(-1e200, 1e100, 0); + // with naive implementation there is an error in both re and im part + // but with the implementation here only on the re part + check_sqrt_roundtrip(1.0 / 3.0, 1.0 / 3.0, 1); + check_sqrt_roundtrip(-1.0 / 3.0, 1.0 / 3.0, 1); + check_sqrt_roundtrip(-0.2, 0.1, 1); + check_sqrt_roundtrip(-0.45, 0.1, 1); + check_sqrt_roundtrip(-std::f64::consts::TAU, std::f64::consts::PI, 1); + // both algorithms don't have the strongest showing here (8 ulp vs 9) but + // 0.0999999999999999-0.45i instead of 0.10000000000000012-0.45000000000000007i + // seems much better since the error is only in the re (and not im) + check_sqrt_roundtrip(0.1, -0.45, 8); + + // reference values were computed with numpy but are identical + // with musl and glibc, showing that we round correctly both + // in reasonable ranges and extremes cases. All of these tests + // fail with a naive sqrt implementation based on phase shift (this + // is checked as part of the tests). + // + // The testcases were generated by the following python script: + // + // import numpy as np + // vals = [ + // (0.1, 0.1), + // (0.1, 1 / 3), + // (1 / 3, 0.1), + // (1 / 3, 1 / 3), + // (1.1, 1e-100), + // (1e-100, 0.1), + // (1e-100, 1.1), + // (1e-100, 1e-100), + // ] + // for re, im in vals: + // reference = np.sqrt(re + im * 1j) + // print(f"check_sqrt({re}, {im}, {reference.real}, {reference.imag});") + + check_sqrt(0.1, 0.1, 0.34743442276011566, 0.14391204994250742); + check_sqrt( + 0.1, + 0.3333333333333333, + 0.4732917794361556, + 0.3521435907152684, + ); + check_sqrt( + 0.3333333333333333, + 0.1, + 0.5836709476652998, + 0.08566470577300687, + ); + check_sqrt( + 0.3333333333333333, + 0.3333333333333333, + 0.6343255686650054, + 0.26274625350107117, + ); + check_sqrt(1.1, 1e-100, 1.0488088481701516, 4.767312946227961e-101); + check_sqrt(1e-100, 1e-100, 1.09868411346781e-50, 4.550898605622274e-51); + check_sqrt(0.1, -0.45, 0.5296117553758811, -0.4248395125601222); + } + #[test] fn test_cbrt() { assert!(close(_0_0i.cbrt(), _0_0i));