Skip to content

Commit 0649d0d

Browse files
authored
fix various fma issues (#533)
This makes `src/builtins.jl` in sync with `bin/generate_builtins.jl` again. A lot of this code was also out-of-date, now that we only support 1.6, so this includes some cleanup.
1 parent f74effe commit 0649d0d

File tree

4 files changed

+75
-0
lines changed

4 files changed

+75
-0
lines changed

bin/generate_builtins.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,32 @@ function maybe_evaluate_builtin(frame, call_expr, expand::Bool)
268268
"""
269269
if isa(f, Core.IntrinsicFunction)
270270
cargs = getargs(args, frame)
271+
@static if isdefined(Core.Intrinsics, :have_fma)
272+
if f === Core.Intrinsics.have_fma && length(cargs) == 1
273+
cargs1 = cargs[1]
274+
if cargs1 == Float64
275+
return Some{Any}(FMA_FLOAT64[])
276+
elseif cargs1 == Float32
277+
return Some{Any}(FMA_FLOAT32[])
278+
elseif cargs1 == Float16
279+
return Some{Any}(FMA_FLOAT16[])
280+
end
281+
end
282+
end
283+
if f === Core.Intrinsics.muladd_float && length(cargs) == 3
284+
a, b, c = cargs
285+
Ta, Tb, Tc = typeof(a), typeof(b), typeof(c)
286+
if !(Ta == Tb == Tc)
287+
error("muladd_float: types of a, b, and c must match")
288+
end
289+
if Ta == Float64 && FMA_FLOAT64[]
290+
f = Core.Intrinsics.fma_float
291+
elseif Ta == Float32 && FMA_FLOAT32[]
292+
f = Core.Intrinsics.fma_float
293+
elseif Ta == Float16 && FMA_FLOAT16[]
294+
f = Core.Intrinsics.fma_float
295+
end
296+
end
271297
return Some{Any}(ccall(:jl_f_intrinsic_call, Any, (Any, Ptr{Any}, UInt32), f, cargs, length(cargs)))
272298
""")
273299
print(io,

src/builtins.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,32 @@ function maybe_evaluate_builtin(frame, call_expr, expand::Bool)
297297
end
298298
if isa(f, Core.IntrinsicFunction)
299299
cargs = getargs(args, frame)
300+
@static if isdefined(Core.Intrinsics, :have_fma)
301+
if f === Core.Intrinsics.have_fma && length(cargs) == 1
302+
cargs1 = cargs[1]
303+
if cargs1 == Float64
304+
return Some{Any}(FMA_FLOAT64[])
305+
elseif cargs1 == Float32
306+
return Some{Any}(FMA_FLOAT32[])
307+
elseif cargs1 == Float16
308+
return Some{Any}(FMA_FLOAT16[])
309+
end
310+
end
311+
end
312+
if f === Core.Intrinsics.muladd_float && length(cargs) == 3
313+
a, b, c = cargs
314+
Ta, Tb, Tc = typeof(a), typeof(b), typeof(c)
315+
if !(Ta == Tb == Tc)
316+
error("muladd_float: types of a, b, and c must match")
317+
end
318+
if Ta == Float64 && FMA_FLOAT64[]
319+
f = Core.Intrinsics.fma_float
320+
elseif Ta == Float32 && FMA_FLOAT32[]
321+
f = Core.Intrinsics.fma_float
322+
elseif Ta == Float16 && FMA_FLOAT16[]
323+
f = Core.Intrinsics.fma_float
324+
end
325+
end
300326
return Some{Any}(ccall(:jl_f_intrinsic_call, Any, (Any, Ptr{Any}, UInt32), f, cargs, length(cargs)))
301327
end
302328
if isa(f, typeof(kwinvoke))

src/packagedef.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ function set_compiled_methods()
116116
push!(compiled_modules, Base.Threads)
117117
end
118118

119+
_have_fma_compiled(::Type{T}) where {T} = Core.Intrinsics.have_fma(T)
120+
121+
const FMA_FLOAT64 = Ref(false)
122+
const FMA_FLOAT32 = Ref(false)
123+
const FMA_FLOAT16 = Ref(false)
124+
119125
function __init__()
120126
set_compiled_methods()
121127
COVERAGE[] = Base.JLOptions().code_coverage
@@ -144,6 +150,12 @@ function __init__()
144150
# compiled_calls[(qsym, RT, Core.svec(AT...), Core.Compiler)] = f
145151
# precompile(f, AT)
146152
# end
153+
154+
@static if isdefined(Base, :have_fma)
155+
FMA_FLOAT64[] = _have_fma_compiled(Float64)
156+
FMA_FLOAT32[] = _have_fma_compiled(Float32)
157+
FMA_FLOAT16[] = _have_fma_compiled(Float16)
158+
end
147159
end
148160

149161
include("precompile.jl")

test/interpret.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,3 +895,14 @@ end
895895
iscallexpr(ex::Expr) = ex.head === :call
896896
@test (@interpret iscallexpr(:(sin(3.14))))
897897
end
898+
899+
if isdefined(Base, :have_fma)
900+
f_fma() = Base.have_fma(Float64)
901+
@testset "fma" begin
902+
@test (@interpret f_fma()) == f_fma()
903+
a, b, c = (1.0585073227945125, -0.00040303348596386557, 1.5051263504758005e-16)
904+
@test (@interpret muladd(a, b, c)) === muladd(a,b,c)
905+
a = 1.0883740903666346; b = 2/3
906+
@test (@interpret a^b) === a^b
907+
end
908+
end

0 commit comments

Comments
 (0)