|
| 1 | +#ifndef NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_ |
| 2 | +#define NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_ |
| 3 | + |
| 4 | +#include "hwy/highway.h" |
| 5 | + |
| 6 | +#include <cfenv> |
| 7 | +#include <type_traits> |
| 8 | + |
| 9 | +namespace npsr { |
| 10 | + |
/// Option tags for `Precise<...>`. Each tag is an empty type; `Precise`
/// detects a tag by checking whether its type appears in the template
/// argument pack.
struct _NoLargeArgument {};
struct _NoSpecialCases {};
struct _NoExceptions {};
struct _LowAccuracy {};

/// Ready-made tag values. They are `inline` so that every translation unit
/// including this header refers to the same object: a plain namespace-scope
/// `constexpr` variable has internal linkage, which gives each translation
/// unit its own copy and makes odr-use from inline/template code an ODR
/// hazard.
inline constexpr auto kNoLargeArgument = _NoLargeArgument{};
inline constexpr auto kNoSpecialCases = _NoSpecialCases{};
inline constexpr auto kNoExceptions = _NoExceptions{};
inline constexpr auto kLowAccuracy = _LowAccuracy{};
| 19 | + |
/// Rounding-related option tags for `Precise<...>`.
///
/// The nested empty types are the tags themselves; the `k…` members are
/// ready-made values of those types.
struct Round {
  struct _Force {};
  struct _Nearest {};
  struct _Down {};
  struct _Up {};
  struct _Zero {};

  static constexpr _Force kForce{};
  static constexpr _Nearest kNearest{};
#if 0 // not used yet
  static constexpr _Down kDown{};
  static constexpr _Up kUp{};
  static constexpr _Zero kZero{};
#endif
};
| 34 | + |
/// Subnormal-handling option tags for `Precise<...>`.
///
/// The nested empty types are the tags themselves; only the IEEE 754 tag
/// currently has a ready-made value.
struct Subnormal {
  struct _DAZ {};
  struct _FTZ {};
  struct _IEEE754 {};

#if 0 // not used yet
  static constexpr _DAZ kDAZ{};
  static constexpr _FTZ kFTZ{};
#endif
  static constexpr _IEEE754 kIEEE754{};
};
| 45 | + |
/// Named aliases for the `<cfenv>` floating-point exception macros, plus
/// `kNone` for "no exception". The types are deduced because the type of the
/// `FE_*` macros is left to the implementation.
struct FPExceptions {
  static constexpr auto kNone{0};
  static constexpr auto kInvalid{FE_INVALID};
  static constexpr auto kDivByZero{FE_DIVBYZERO};
  static constexpr auto kOverflow{FE_OVERFLOW};
  static constexpr auto kUnderflow{FE_UNDERFLOW};
};
| 53 | + |
| 54 | +template <typename... Args> class Precise { |
| 55 | +public: |
| 56 | + Precise() { |
| 57 | + if constexpr (!kNoExceptions) { |
| 58 | + fegetexceptflag(&_exceptions, FE_ALL_EXCEPT); |
| 59 | + } |
| 60 | + if constexpr (kRoundForce) { |
| 61 | + _rounding_mode = fegetround(); |
| 62 | + int new_mode = _NewRoundingMode(); |
| 63 | + if (_rounding_mode != new_mode) { |
| 64 | + _retrieve_rounding_mode = true; |
| 65 | + fesetround(new_mode); |
| 66 | + } |
| 67 | + } |
| 68 | + } |
| 69 | + void FlushExceptions() { fesetexceptflag(&_exceptions, FE_ALL_EXCEPT); } |
| 70 | + |
| 71 | + void Raise(int errors) { |
| 72 | + static_assert(!kNoExceptions, |
| 73 | + "Cannot raise exceptions in NoExceptions mode"); |
| 74 | + _exceptions |= errors; |
| 75 | + } |
| 76 | + ~Precise() { |
| 77 | + FlushExceptions(); |
| 78 | + if constexpr (kRoundForce) { |
| 79 | + if (_retrieve_rounding_mode) { |
| 80 | + fesetround(_rounding_mode); |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + static constexpr bool kNoExceptions = |
| 85 | + (std::is_same_v<_NoExceptions, Args> || ...); |
| 86 | + static constexpr bool kNoLargeArgument = |
| 87 | + (std::is_same_v<_NoLargeArgument, Args> || ...); |
| 88 | + static constexpr bool kNoSpecialCases = |
| 89 | + (std::is_same_v<_NoSpecialCases, Args> || ...); |
| 90 | + static constexpr bool kLowAccuracy = |
| 91 | + (std::is_same_v<_LowAccuracy, Args> || ...); |
| 92 | + // defaults to high accuracy if no low accuracy flag is set |
| 93 | + static constexpr bool kHighAccuracy = !kLowAccuracy; |
| 94 | + // defaults to large argument support if no no large argument flag is set |
| 95 | + static constexpr bool kLargeArgument = !kNoLargeArgument; |
| 96 | + // defaults to special cases support if no no special cases flag is set |
| 97 | + static constexpr bool kSpecialCases = !kNoSpecialCases; |
| 98 | + // defaults to exception support if no no exception flag is set |
| 99 | + static constexpr bool kException = !kNoExceptions; |
| 100 | + |
| 101 | + static constexpr bool kRoundForce = |
| 102 | + (std::is_same_v<Round::_Force, Args> || ...); |
| 103 | + static constexpr bool _kRoundNearest = |
| 104 | + (std::is_same_v<Round::_Nearest, Args> || ...); |
| 105 | + static constexpr bool kRoundZero = |
| 106 | + (std::is_same_v<Round::_Zero, Args> || ...); |
| 107 | + static constexpr bool kRoundDown = |
| 108 | + (std::is_same_v<Round::_Down, Args> || ...); |
| 109 | + static constexpr bool kRoundUp = (std::is_same_v<Round::_Up, Args> || ...); |
| 110 | + // only one rounding mode can be set |
| 111 | + static_assert((_kRoundNearest + kRoundDown + kRoundUp + kRoundZero) <= 1, |
| 112 | + "Only one rounding mode can be set at a time"); |
| 113 | + // if no rounding mode is set, default to round nearest |
| 114 | + static constexpr bool kRoundNearest = |
| 115 | + _kRoundNearest || (!kRoundDown && !kRoundUp && !kRoundZero); |
| 116 | + |
| 117 | + static constexpr bool kDAZ = (std::is_same_v<Subnormal::_DAZ, Args> || ...); |
| 118 | + static constexpr bool kFTZ = (std::is_same_v<Subnormal::_FTZ, Args> || ...); |
| 119 | + static constexpr bool _kIEEE754 = |
| 120 | + (std::is_same_v<Subnormal::_IEEE754, Args> || ...); |
| 121 | + static_assert(!_kIEEE754 || !(kDAZ || kFTZ), |
| 122 | + "IEEE754 mode cannot be used " |
| 123 | + "with Denormals Are Zero (DAZ) or Flush To Zero (FTZ) " |
| 124 | + "subnormal handling"); |
| 125 | + static constexpr bool kIEEE754 = _kIEEE754 || !(kDAZ || kFTZ); |
| 126 | + |
| 127 | +private: |
| 128 | + int _NewRoundingMode() const { |
| 129 | + if constexpr (kRoundDown) { |
| 130 | + return FE_DOWNWARD; |
| 131 | + } else if constexpr (kRoundUp) { |
| 132 | + return FE_UPWARD; |
| 133 | + } else if constexpr (kRoundZero) { |
| 134 | + return FE_TOWARDZERO; |
| 135 | + } else { |
| 136 | + return FE_TONEAREST; |
| 137 | + } |
| 138 | + } |
| 139 | + int _rounding_mode = 0; |
| 140 | + bool _retrieve_rounding_mode = false; |
| 141 | + fexcept_t _exceptions; |
| 142 | +}; |
| 143 | + |
| 144 | +} // namespace npsr |
| 145 | + |
| 146 | +#endif // NUMPY_SIMD_ROUTINES_NPSR_COMMON_H_ |
0 commit comments