Skip to content

[clang] constexpr built-in fma function. #113020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ Non-comprehensive list of changes in this release
- Plugins can now define custom attributes that apply to statements
as well as declarations.
- ``__builtin_abs`` function can now be used in constant expressions.
- ``__builtin_fma`` function can now be used in constant expressions.

New Compiler Flags
------------------
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def FloorF16F128 : Builtin, F16F128MathTemplate {

def FmaF16F128 : Builtin, F16F128MathTemplate {
let Spellings = ["__builtin_fma"];
let Attributes = [FunctionWithBuiltinPrefix, NoThrow, ConstIgnoringErrnoAndExceptions];
let Attributes = [FunctionWithBuiltinPrefix, NoThrow, ConstIgnoringErrnoAndExceptions, Constexpr];
let Prototype = "T(T, T, T)";
}

Expand Down Expand Up @@ -3723,6 +3723,7 @@ def Fma : FPMathTemplate, LibBuiltin<"math.h"> {
let Attributes = [NoThrow, ConstIgnoringErrnoAndExceptions];
let Prototype = "T(T, T, T)";
let AddBuiltinPrefixedAlias = 1;
let OnlyBuiltinPrefixedAliasIsConstexpr = 1;
}

def Fmax : FPMathTemplate, LibBuiltin<"math.h"> {
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/AST/ByteCode/Floating.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ class Floating final {
return R->F.divide(B.F, RM);
}

static APFloat::opStatus fma(const Floating &A, const Floating &B,
const Floating &C, llvm::RoundingMode RM,
Floating *R) {
*R = Floating(A.F);
return R->F.fusedMultiplyAdd(B.F, C.F, RM);
}

static bool neg(const Floating &A, Floating *R) {
*R = -A;
return false;
Expand Down
37 changes: 37 additions & 0 deletions clang/lib/AST/ByteCode/InterpBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,19 @@ static bool retPrimValue(InterpState &S, CodePtr OpPC, APValue &Result,
#undef RET_CASE
}

/// Get rounding mode to use in evaluation of the specified expression.
///
/// If rounding mode is unknown at compile time, still try to evaluate the
/// expression. If the result is exact, it does not depend on rounding mode.
/// So return "tonearest" mode instead of "dynamic".
static llvm::RoundingMode getActiveRoundingMode(InterpState &S, const Expr *E) {
llvm::RoundingMode RM =
E->getFPFeaturesInEffect(S.getLangOpts()).getRoundingMode();
if (RM == llvm::RoundingMode::Dynamic)
RM = llvm::RoundingMode::NearestTiesToEven;
return RM;
}

static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
const InterpFrame *Frame,
const CallExpr *Call) {
Expand Down Expand Up @@ -549,6 +562,22 @@ static bool interp__builtin_fpclassify(InterpState &S, CodePtr OpPC,
return true;
}

static bool interp__builtin_fma(InterpState &S, CodePtr OpPC,
const InterpFrame *Frame, const Function *Func,
const CallExpr *Call) {
const Floating &X = getParam<Floating>(Frame, 0);
const Floating &Y = getParam<Floating>(Frame, 1);
const Floating &Z = getParam<Floating>(Frame, 2);
Floating Result;

llvm::RoundingMode RM = getActiveRoundingMode(S, Call);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have questions about the rounding mode. Is getActiveRoundingMode() trying to account for the FE_ROUND pragma setting? If that pragma isn't used and we aren't compiling in strict mode, this should give us "NearestTiesToEven", right?

But if the rounding mode is dynamic, then I think we need to know if the call is in a constexpr. Consider:

float f1() {
  fesetround(FE_UPWARD);
  constrexpr float x = __builtin_fma(1.0f, 0.0f, 0.1f); // constexpr evaluates at the default rounding mode?
  float y = __builtin_fma(1.0f, 0.1f, 0.1f); // Non-constexpr should use the dynamic rounding mode?
  return x - y;
}

I'm not 100% certain about the language rules here, but at least in C my understanding is that initialization of non-static, non-constexp expressions should be done "as if" evaluated at runtime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static llvm::RoundingMode getActiveRoundingMode(EvalInfo &Info, const Expr *E) {
llvm::RoundingMode RM =
E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()).getRoundingMode();
if (RM == llvm::RoundingMode::Dynamic)
RM = llvm::RoundingMode::NearestTiesToEven;
return RM;
}

I actually obtained this from the old constant evaluator, but I'm not sure if it has the same rounding issue.

if (Floating::fma(X, Y, Z, RM, &Result) != APFloat::opOK)
return false;

S.Stk.push<Floating>(Result);
return true;
}

// The C standard says "fabs raises no floating-point exceptions,
// even if x is a signaling NaN. The returned value is independent of
// the current rounding direction mode." Therefore constant folding can
Expand Down Expand Up @@ -1826,6 +1855,14 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const Function *F,
return false;
break;

case Builtin::BI__builtin_fma:
case Builtin::BI__builtin_fmaf:
case Builtin::BI__builtin_fmal:
case Builtin::BI__builtin_fmaf128:
if (!interp__builtin_fma(S, OpPC, Frame, F, Call))
return false;
break;

case Builtin::BI__builtin_fabs:
case Builtin::BI__builtin_fabsf:
case Builtin::BI__builtin_fabsl:
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15314,6 +15314,20 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
Result.changeSign();
return true;

case Builtin::BI__builtin_fma:
case Builtin::BI__builtin_fmaf:
case Builtin::BI__builtin_fmal:
case Builtin::BI__builtin_fmaf128: {
APFloat Y(0.), Z(0.);
if (!EvaluateFloat(E->getArg(0), Result, Info) ||
!EvaluateFloat(E->getArg(1), Y, Info) ||
!EvaluateFloat(E->getArg(2), Z, Info))
return false;

llvm::RoundingMode RM = getActiveRoundingMode(Info, E);
return Result.fusedMultiplyAdd(Y, Z, RM) == APFloat::opOK;
}

case Builtin::BI__arithmetic_fence:
return EvaluateFloat(E->getArg(0), Result, Info);

Expand Down
9 changes: 9 additions & 0 deletions clang/test/AST/ByteCode/builtin-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,15 @@ namespace fpclassify {
char classify_subnorm [__builtin_fpclassify(-1, -1, -1, +1, -1, 1.0e-38f)];
}

namespace fma {
static_assert(__builtin_fma(1.0, 1.0, 1.0) == 2.0);
static_assert(__builtin_fma(1.0, -1.0, 1.0) == 0.0);
static_assert(__builtin_fmaf(1.0f, 1.0f, 1.0f) == 2.0f);
static_assert(__builtin_fmaf(1.0f, -1.0f, 1.0f) == 0.0f);
static_assert(__builtin_fmal(1.0L, 1.0L, 1.0L) == 2.0L);
static_assert(__builtin_fmal(1.0L, -1.0L, 1.0L) == 0.0L);
} // namespace fma

namespace abs {
static_assert(__builtin_abs(14) == 14, "");
static_assert(__builtin_labs(14L) == 14L, "");
Expand Down
7 changes: 7 additions & 0 deletions clang/test/Sema/constant-builtins-2.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ long double g18 = __builtin_copysignl(1.0L, -1.0L);
__float128 g18_2 = __builtin_copysignf128(1.0q, -1.0q);
#endif

double g19 = __builtin_fma(1.0, 1.0, 1.0);
float g20 = __builtin_fmaf(1.0f, 1.0f, 1.0f);
long double g21 = __builtin_fmal(1.0L, 1.0L, 1.0L);
#if defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)
__float128 g21_2 = __builtin_fmaf128(1.0q, 1.0q, 1.0q);
#endif

char classify_nan [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nan(""))];
char classify_snan [__builtin_fpclassify(+1, -1, -1, -1, -1, __builtin_nans(""))];
char classify_inf [__builtin_fpclassify(-1, +1, -1, -1, -1, __builtin_inf())];
Expand Down
Loading