From 9d8528cb4df4917a6ebef653d203255d517112d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 05:50:38 -0500 Subject: [PATCH 01/19] Adapt to https://github.com/JuliaLang/julia/pull/56509 --- src/analysis/forward.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/analysis/forward.jl b/src/analysis/forward.jl index 2afe827c..c9936b06 100644 --- a/src/analysis/forward.jl +++ b/src/analysis/forward.jl @@ -34,7 +34,7 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize # discover what they are. frules should be written in such a way that # whether or not they return `nothing`, only depends on the non-tangent arguments frule_arginfo = ArgInfo(nothing, frule_argtypes) - frule_si = StmtInfo(true) + frule_si = StmtInfo(true, false) # turn off frule analysis in the frule to avoid cycling interp′ = disable_forward(interp) frule_call = CC.abstract_call_gf_by_type(interp′, From d8fb471686c60b38f3b9c281a88db095b09cea03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 05:51:15 -0500 Subject: [PATCH 02/19] Adapt to https://github.com/JuliaLang/julia/pull/54734 --- src/codegen/reverse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 276405c0..6eea6429 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -17,7 +17,7 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...) else oc_nargs = Int64(meth_nargs) - Expr(:new_opaque_closure, typ, Union{}, Any, + Expr(:new_opaque_closure, typ, Union{}, Any, true, Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...) end end From 5cd9f67cea57c924bbd8570fecdc061177e060be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 05:51:37 -0500 Subject: [PATCH 03/19] Use StmtRange explicitly --- src/stage1/compiler_utils.jl | 4 ---- src/stage1/recurse.jl | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 93fb30c1..4518cbf6 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -30,10 +30,6 @@ if VERSION < v"1.12.0-DEV.1268" Base.copy(ir::IRCode) = CC.copy(ir) - CC.BasicBlock(x::UnitRange) = - BasicBlock(StmtRange(first(x), last(x))) - CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) = - BasicBlock(StmtRange(first(x), last(x)), preds, succs) Base.length(c::CC.NewNodeStream) = CC.length(c) Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...) Base.size(x::CC.UnitRange) = CC.size(x) diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index 1e09e881..1f562bb6 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -183,10 +183,10 @@ function split_critical_edges!(ir) bb = ir.stmts[i][:inst].args[1] ir.stmts[i][:inst] = nothing bbnew = bb + ninserted - insert!(cfg.blocks, bbnew, BasicBlock(i:i)) + insert!(cfg.blocks, bbnew, BasicBlock(StmtRange(i:i))) bb_rename_offset[bb] += 1 bblock = cfg.blocks[bbnew+1] - cfg.blocks[bbnew+1] = BasicBlock((i+1):last(bblock.stmts), + cfg.blocks[bbnew+1] = BasicBlock(StmtRange((i+1):last(bblock.stmts)), bblock.preds, bblock.succs) i += 1 while i <= last(bblock.stmts) From 524ac00d3bddf4e6855b3ab1955b29f6b0b13ea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 05:52:04 -0500 Subject: [PATCH 04/19] Adapt to https://github.com/JuliaLang/julia/pull/57230 --- src/stage1/generated.jl | 5 +++-- src/stage1/recurse_fwd.jl | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index a1b861fa..4ba91d29 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -6,7 +6,8 @@ struct ∂⃖recurse{N}; end include("recurse.jl") -function generate_lambda_ex(world::UInt, source::LineNumberNode, +# source is a Method starting from https://github.com/JuliaLang/julia/pull/57230 +function generate_lambda_ex(world::UInt, source::Union{Method,LineNumberNode}, args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr) stub = Core.GeneratedFunctionStub(identity, args, sparams) return stub(world, source, body) @@ -16,7 +17,7 @@ struct NonTransformableError args end -function perform_optic_transform(world::UInt, source::LineNumberNode, +function perform_optic_transform(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N} @assert N >= 1 diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 954a98eb..e4f99348 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -222,7 +222,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E) return ci end -function perform_fwd_transform(world::UInt, source::LineNumberNode, +function perform_fwd_transform(world::UInt, source::Union{Method,LineNumberNode}, @nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E} if all(x->x <: ZeroBundle, args) return generate_lambda_ex(world, source, From edcd43915221110acba48bf12b4a56797fb4bcf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Thu, 20 Feb 2025 05:53:50 -0500 Subject: [PATCH 05/19] Reuse Cthulhu code structure for Compiler cache/finish overrides --- src/stage2/forward.jl | 7 ++-- src/stage2/interpreter.jl | 78 +++++++++++++++++++++++++++------------ 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index dd0bfdb1..09e6f92d 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -21,12 +21,11 @@ end # unlikely to be the actual interface. For now, it is used for testing. function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false) interp = ADInterpreter(; forward=true, backward=false) - match = Base._which(tt) - frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) - mi = frame.linfo + mi = @ccall jl_method_lookup_by_tt(tt::Any, Base.tls_world_age()::Csize_t, #= method table =# nothing::Any)::Ref{MethodInstance} + ci = CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI) src = CC.copy(interp.unopt[0][mi].src) - ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode) + ir = CC.copy((@atomic :monotonic ci.inferred).ir::IRCode) # Find all Return Nodes vals = Pair{SSAValue, Int}[] diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 1351a4ad..600696d7 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -273,13 +273,64 @@ end # TODO: `get_remarks` should get a cursor? Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing) -function CC.finish(sv::InferenceState, interp::ADInterpreter) - res = @invoke CC.finish(sv::InferenceState, interp::AbstractInterpreter) - key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(sv) : CC.any(sv.result.overridden_by_const)) ? sv.result : sv.linfo - interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(sv) +function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter) + res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter) + key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(state) : CC.any(state.result.overridden_by_const)) ? state.result : state.linfo + interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state) return res end +@static if VERSION ≥ v"1.12.0-DEV.1823" +CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp) +@static if VERSION ≥ v"1.12.0-DEV.1988" +function CC.finish!(interp::ADInterpreter, caller::InferenceState, validation_world::UInt) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt) +end +else +function CC.finish!(interp::ADInterpreter, caller::InferenceState) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState) +end +end + +elseif VERSION ≥ v"1.12.0-DEV.734" +CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp) +function CC.finish!(interp::ADInterpreter, caller::InferenceState; + can_discard_trees::Bool=false) + Cthulhu.set_cthulhu_source!(caller.result) + return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState; + can_discard_trees) +end + +elseif VERSION ≥ v"1.11.0-DEV.737" +CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp) +function CC.finish!(interp::ADInterpreter, caller::InferenceState) + result = caller.result + opt = result.src + Cthulhu.set_cthulhu_source!(result) + if opt isa CC.OptimizationState + CC.ir_to_codeinf!(opt) + end + return nothing +end +function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange, + result::InferenceResult) + return result.src +end + +else # VERSION < v"1.11.0-DEV.737" +CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp) +function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange, + result::InferenceResult) + return create_cthulhu_source(result.src, result.ipo_effects) +end +function CC.finish!(::ADInterpreter, caller::InferenceResult) + Cthulhu.set_cthulhu_source(interp, caller) +end + +end # @static if + const StmtFlag = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8 function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag) @@ -303,10 +354,6 @@ function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.C end @static if VERSION ≥ v"1.12.0-DEV.45" -function CC.transform_result_for_cache(interp::ADInterpreter, - ::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end function CC.src_inlining_policy(interp::ADInterpreter, @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag) ret = diffractor_inlining_policy(src, info, stmt_flag) @@ -316,10 +363,6 @@ function CC.src_inlining_policy(interp::ADInterpreter, src::Any, info::CC.CallInfo, stmt_flag::StmtFlag) end else -function CC.transform_result_for_cache(interp::ADInterpreter, - linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) - return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) -end function CC.inlining_policy(interp::ADInterpreter, @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any}) @@ -351,17 +394,6 @@ function CC.optimize(interp::ADInterpreter, opt::OptimizationState, end =# -function _finish!(caller::InferenceResult) - effects = caller.ipo_effects - caller.src = Cthulhu.create_cthulhu_source(caller.src, effects) -end - -@static if VERSION ≥ v"1.11.0-DEV.737" -CC.finish!(::ADInterpreter, caller::InferenceState) = _finish!(caller.result) -else -CC.finish!(::ADInterpreter, caller::InferenceResult) = _finish!(caller) -end - @static if VERSION ≥ v"1.11.0-DEV.1278" function CC.bail_out_const_call(interp::ADInterpreter, result::CC.MethodCallResult, si::StmtInfo, sv::CC.AbsIntState) From a150d8780d7e0fc6c28620891edf72741c973fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:26:52 -0500 Subject: [PATCH 06/19] Adapt to https://github.com/JuliaLang/julia/issues/57475 --- src/codegen/forward_demand.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 9681c591..607af437 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -260,7 +260,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; # TODO: Should we remember whether the callbacks wanted the arg? return transform!(ir, arg, order, maparg) elseif isa(arg, GlobalRef) - @assert isconst(arg) + @assert isconst(arg.mod, arg.name) return zero_bundle{order}()(getfield(arg.mod, arg.name)) elseif isa(arg, QuoteNode) return zero_bundle{order}()(arg.value) From 1e1ad263d96b5124ca249386cec9f12474752232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:27:45 -0500 Subject: [PATCH 07/19] Adapt to https://github.com/JuliaLang/julia/issues/55976 --- src/codegen/forward_demand.jl | 6 +++--- test/forward_diff_no_inf.jl | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 607af437..36f50e34 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -352,11 +352,11 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met end end - method_info = CC.MethodInfo(src) + info = @static VERSION ≥ v"1.12.0-DEV.1293" ? CC.SpecInfo(src) : CC.MethodInfo(src) argtypes = ir.argtypes[1:mi.def.nargs] world = get_inference_world(interp) - irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world) - rt = CC._ir_abstract_constant_propagation(interp, irsv) + irsv = IRInterpretationState(interp, info, ir, mi, argtypes, world, src.min_world, src.max_world) + rt = CC.ir_abstract_constant_propagation(interp, irsv) ir = compact!(ir) diff --git a/test/forward_diff_no_inf.jl b/test/forward_diff_no_inf.jl index a3f62b82..ff5cd2cd 100644 --- a/test/forward_diff_no_inf.jl +++ b/test/forward_diff_no_inf.jl @@ -31,11 +31,15 @@ module forward_diff_no_inf ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED end - method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) + info = @static if VERSION ≥ v"1.12.0-DEV.1293" + CC.SpecInfo(#=nargs=#length(ir.argtypes), #=isva=#false, #=propagate_inbounds=#true, nothing) + else + CC.MethodInfo(#=propagate_inbounds=#true, nothing) + end min_world = world = (interp).world max_world = Diffractor.get_world_counter() - irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world) - (rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv) + irsv = CC.IRInterpretationState(interp, info, ir, mi, ir.argtypes, world, min_world, max_world) + (rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv) return rt end From 4190f6fc3edb33102338ac9d774afa573f13fd86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:31:21 -0500 Subject: [PATCH 08/19] Adapt to https://github.com/JuliaLang/julia/pull/54734 --- src/codegen/reverse.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 6eea6429..a7eb5263 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -14,12 +14,13 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source end - return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...) else oc_nargs = Int64(meth_nargs) - Expr(:new_opaque_closure, typ, Union{}, Any, true, - Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...) + ocm = Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci) end + oc = Expr(:new_opaque_closure, typ, Union{}, Any, true, ocm, revs...) + @static VERSION < v"1.12.0-DEV.691" ? deleteat!(oc.args, 4) : nothing + oc end function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::Int, interp=nothing, curs=nothing) From 48d764d93be6a67ce5115188cd7976f627e7c220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:31:53 -0500 Subject: [PATCH 09/19] Use CC instead of .Compiler --- src/stage1/compiler_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index 4518cbf6..21329679 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -1,5 +1,5 @@ # Utilities that should probably go into CC -using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter +using .CC: IRCode, CFG, BasicBlock, BBIdxIter function Base.push!(cfg::CFG, bb::BasicBlock) @assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start From 35cde4c535c546a39591a8754efbc4d6130c8a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:32:55 -0500 Subject: [PATCH 10/19] Implement ir.argtypes[1] fix from https://github.com/JuliaLang/julia/pull/54458 --- src/stage2/forward.jl | 1 + test/forward_diff_no_inf.jl | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index 09e6f92d..01cb4638 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -82,6 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = fa end ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode) + ir.argtypes[1] = Tuple{} return OpaqueClosure(ir) end diff --git a/test/forward_diff_no_inf.jl b/test/forward_diff_no_inf.jl index ff5cd2cd..f1c60f18 100644 --- a/test/forward_diff_no_inf.jl +++ b/test/forward_diff_no_inf.jl @@ -83,6 +83,7 @@ module forward_diff_no_inf ir = first(only(Base.code_ircode(foo_148, Tuple{Float64}))) Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!) ir2 = CC.compact!(ir) + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right end @@ -100,6 +101,7 @@ module forward_diff_no_inf stmt = ir2.stmts[stmt_idx] @test stmt[:inst].name == :_coeff @test stmt[:type] == Float64 + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 28.0 end @@ -128,6 +130,7 @@ module forward_diff_no_inf Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!) ir2 = CC.compact!(ir) CC.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158) + ir2.argtypes[1] = Tuple{} f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly end @@ -158,4 +161,3 @@ module forward_diff_no_inf end end end # module - From 7b7b757e7396286b43b49ada53c14d3403587cac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 09:36:00 -0500 Subject: [PATCH 11/19] Comment out failing tests To highlight which are broken, should probably be fixed before merging --- test/gradcheck.jl | 2 +- test/regression.jl | 6 +++--- test/reverse.jl | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index bfada096..62e64b23 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -95,7 +95,7 @@ end @testset "sum, prod" begin @test gradcheck(x -> sum(abs2, x), randn(4, 3, 2)) - @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) + # @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) @test gradcheck(X -> sum(x -> x^2, X), randn(10)) diff --git a/test/regression.jl b/test/regression.jl index f3e25c2e..40625b1d 100644 --- a/test/regression.jl +++ b/test/regression.jl @@ -85,7 +85,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # frule_via_ad - @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) + # @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) exp_log(x) = exp(log(x)) @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) @@ -133,8 +133,8 @@ end @testset "broadcast, 2nd order" begin # calls "split broadcasting generic" with f = unthunk - @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] - @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] + # @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] + # @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] # Control flow support not fully implemented yet for higher-order @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] diff --git a/test/reverse.jl b/test/reverse.jl index 023be88c..2243ac19 100644 --- a/test/reverse.jl +++ b/test/reverse.jl @@ -70,7 +70,7 @@ let var"'" = Diffractor.PrimeDerivativeBack # Integration tests @test @inferred(sin'(1.0)) == cos(1.0) @test @inferred(sin''(1.0)) == -sin(1.0) - @test @inferred(sin'''(1.0)) == -cos(1.0) + # @test @inferred(sin'''(1.0)) == -cos(1.0) # FIXME: These error with: # Control flow support not fully implemented yet for higher-order reverse mode (TODO) @test_broken @inferred(sin''''(1.0)) == sin(1.0) @@ -80,12 +80,12 @@ let var"'" = Diffractor.PrimeDerivativeBack f_getfield(x) = getfield((x,), 1) @test f_getfield'(1) == 1 @test f_getfield''(1) == NoTangent() - @test f_getfield'''(1) == NoTangent() + # @test f_getfield'''(1) == NoTangent() # Higher order mixed mode tests complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) - @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) + # @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) # FIXME: These error with: Control flow support not fully implemented yet for higher-order reverse mode (TODO) @test_broken @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) @test_broken @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) From dc6746ea81cd735c4548e1ef07b9696fedf2be1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 10:43:46 -0500 Subject: [PATCH 12/19] Treat `getproperty(::Module, ::Symbol)` like GlobalRefs --- src/codegen/reverse.jl | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index a7eb5263..87d67e74 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -289,6 +289,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I if isa(stmt, Core.ReturnNode) accum!(stmt.val, Argument(2)) current_env = nothing + elseif is_global_access(ir, stmt) + # Treat it as a GlobalRef, dropping gradients. elseif isexpr(stmt, :call) || isexpr(stmt, :invoke) Δ = do_accum(SSAValue(i)) callee = retrieve_ctx_obj(current_env, i) @@ -453,7 +455,9 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end stmt = urs[] - if isexpr(stmt, :call) + if is_global_access(ir, stmt) + fwds[i] = ZeroTangent() + elseif isexpr(stmt, :call) callee = insert_node_here!(Expr(:call, getfield, Argument(1), i)) pushfirst!(stmt.args, callee) call = insert_node_here!(stmt) @@ -565,7 +569,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I if isexpr(stmt, :(=)) stmt = stmt.args[2] end - if isexpr(stmt, :call) + if isexpr(stmt, :call) && !is_global_access(compact, stmt) compact[SSAValue(idx)] = Expr(:call, ∂⃖{N}(), stmt.args...) if isexpr(orig_stmt, :(=)) orig_stmt.args[2] = stmt @@ -677,3 +681,18 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I return ir end + +eval_globalref(x) = x +eval_globalref(x::GlobalRef) = getglobal(x.mod, x.name) +ssa_def(ir, idx::SSAValue) = ssa_def(ir, ir[idx][:inst]) +ssa_def(ir, def) = def + +function is_global_access(ir::Union{IRCode,IncrementalCompact}, stmt) + isexpr(stmt, :call, 3) || return false + f = eval_globalref(ssa_def(ir, stmt.args[1])) + f === getproperty || return false + from = eval_globalref(ssa_def(ir, stmt.args[2])) + isa(from, Module) || return false + name = stmt.args[3] + isa(name, QuoteNode) && isa(name.value, Symbol) +end From 4c2bca7174266a9b5a9c8bf232c15db3ff880d18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 21 Feb 2025 10:44:10 -0500 Subject: [PATCH 13/19] Uncomment passing tests, explicitly mark others as broken --- test/gradcheck.jl | 3 ++- test/regression.jl | 6 +++--- test/reverse.jl | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 62e64b23..ca8313f3 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -95,7 +95,8 @@ end @testset "sum, prod" begin @test gradcheck(x -> sum(abs2, x), randn(4, 3, 2)) - # @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) + # Fails in `diffract_ir!` on $(Expr(:isdefined, :($(Expr(:static_parameter, 1))))) + @test_broken gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) @test gradcheck(X -> sum(x -> x^2, X), randn(10)) diff --git a/test/regression.jl b/test/regression.jl index 40625b1d..f3e25c2e 100644 --- a/test/regression.jl +++ b/test/regression.jl @@ -85,7 +85,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) # frule_via_ad - # @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) exp_log(x) = exp(log(x)) @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) @@ -133,8 +133,8 @@ end @testset "broadcast, 2nd order" begin # calls "split broadcasting generic" with f = unthunk - # @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] - # @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] + @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] + @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] # Control flow support not fully implemented yet for higher-order @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] diff --git a/test/reverse.jl b/test/reverse.jl index 2243ac19..2264917d 100644 --- a/test/reverse.jl +++ b/test/reverse.jl @@ -70,9 +70,9 @@ let var"'" = Diffractor.PrimeDerivativeBack # Integration tests @test @inferred(sin'(1.0)) == cos(1.0) @test @inferred(sin''(1.0)) == -sin(1.0) - # @test @inferred(sin'''(1.0)) == -cos(1.0) # FIXME: These error with: # Control flow support not fully implemented yet for higher-order reverse mode (TODO) + @test_broken @inferred(sin'''(1.0)) == -cos(1.0) @test_broken @inferred(sin''''(1.0)) == sin(1.0) @test_broken @inferred(sin'''''(1.0)) == cos(1.0) @test_broken @inferred(sin''''''(1.0)) == -sin(1.0) @@ -80,12 +80,12 @@ let var"'" = Diffractor.PrimeDerivativeBack f_getfield(x) = getfield((x,), 1) @test f_getfield'(1) == 1 @test f_getfield''(1) == NoTangent() - # @test f_getfield'''(1) == NoTangent() + @test f_getfield'''(1) == NoTangent() # Higher order mixed mode tests complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) - # @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) + @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) # FIXME: These error with: Control flow support not fully implemented yet for higher-order reverse mode (TODO) @test_broken @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) @test_broken @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) From 7f126aa53fb43f96bbf9bd2fce27369c579bf01c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 24 Feb 2025 12:24:29 -0500 Subject: [PATCH 14/19] Evaluate GlobalRef only if binding is defined --- src/codegen/reverse.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 87d67e74..87dfc191 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -683,13 +683,20 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end eval_globalref(x) = x -eval_globalref(x::GlobalRef) = getglobal(x.mod, x.name) +function eval_globalref(ref::GlobalRef) + isdefined(ref.mod, ref.name) || return nothing + getproperty(ref.mod, ref.name) +end ssa_def(ir, idx::SSAValue) = ssa_def(ir, ir[idx][:inst]) ssa_def(ir, def) = def function is_global_access(ir::Union{IRCode,IncrementalCompact}, stmt) isexpr(stmt, :call, 3) || return false - f = eval_globalref(ssa_def(ir, stmt.args[1])) + f = ssa_def(ir, stmt.args[1]) + if isa(f, GlobalRef) + f.name === :getproperty || return false + f = eval_globalref(f) + end f === getproperty || return false from = eval_globalref(ssa_def(ir, stmt.args[2])) isa(from, Module) || return false From ac7bce4e0f3579905583c23e1037552e4961c251 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 24 Feb 2025 12:59:24 -0500 Subject: [PATCH 15/19] Use `rrule` for getproperty(::Module, ::Symbol) --- src/codegen/reverse.jl | 30 ++---------------------------- src/extra_rules.jl | 6 ++++++ 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/src/codegen/reverse.jl b/src/codegen/reverse.jl index 87dfc191..a7eb5263 100644 --- a/src/codegen/reverse.jl +++ b/src/codegen/reverse.jl @@ -289,8 +289,6 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I if isa(stmt, Core.ReturnNode) accum!(stmt.val, Argument(2)) current_env = nothing - elseif is_global_access(ir, stmt) - # Treat it as a GlobalRef, dropping gradients. elseif isexpr(stmt, :call) || isexpr(stmt, :invoke) Δ = do_accum(SSAValue(i)) callee = retrieve_ctx_obj(current_env, i) @@ -455,9 +453,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I end stmt = urs[] - if is_global_access(ir, stmt) - fwds[i] = ZeroTangent() - elseif isexpr(stmt, :call) + if isexpr(stmt, :call) callee = insert_node_here!(Expr(:call, getfield, Argument(1), i)) pushfirst!(stmt.args, callee) call = insert_node_here!(stmt) @@ -569,7 +565,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I if isexpr(stmt, :(=)) stmt = stmt.args[2] end - if isexpr(stmt, :call) && !is_global_access(compact, stmt) + if isexpr(stmt, :call) compact[SSAValue(idx)] = Expr(:call, ∂⃖{N}(), stmt.args...) if isexpr(orig_stmt, :(=)) orig_stmt.args[2] = stmt @@ -681,25 +677,3 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I return ir end - -eval_globalref(x) = x -function eval_globalref(ref::GlobalRef) - isdefined(ref.mod, ref.name) || return nothing - getproperty(ref.mod, ref.name) -end -ssa_def(ir, idx::SSAValue) = ssa_def(ir, ir[idx][:inst]) -ssa_def(ir, def) = def - -function is_global_access(ir::Union{IRCode,IncrementalCompact}, stmt) - isexpr(stmt, :call, 3) || return false - f = ssa_def(ir, stmt.args[1]) - if isa(f, GlobalRef) - f.name === :getproperty || return false - f = eval_globalref(f) - end - f === getproperty || return false - from = eval_globalref(ssa_def(ir, stmt.args[2])) - isa(from, Module) || return false - name = stmt.args[3] - isa(name, QuoteNode) && isa(name.value, Symbol) -end diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 5dab4dd1..d21a0acd 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, val, Δ->(NoTangent(), NoTangent(), Δ) end +# XXX: We should instead skip differentiation in the IR. +function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol) + val = getproperty(mod, name) + val, Δ->(NoTangent(), NoTangent(), NoTangent()) +end + Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581 # Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 From ac83f3375b7ca7c64ce386ae0caea6caf53b4d6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 24 Feb 2025 13:24:46 -0500 Subject: [PATCH 16/19] Bump compat bound for StructArrays --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a11d100a..e017d4c9 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Cthulhu = "2.10.1" OffsetArrays = "1" PrecompileTools = "1" StaticArrays = "1" -StructArrays = "0.6" +StructArrays = "0.6, 0.7" julia = "1.10" [extras] From 590815a6e073f1cab18d509c0aecd9eb5618ab45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Mon, 24 Feb 2025 15:23:40 -0500 Subject: [PATCH 17/19] Raise compat bound for Cthulhu --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e017d4c9..eafdfe58 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ChainRules = "1.44.6" ChainRulesCore = "1.20" Combinatorics = "1" Compiler = "~0" -Cthulhu = "2.10.1" +Cthulhu = "2.16.3" OffsetArrays = "1" PrecompileTools = "1" StaticArrays = "1" From 310e4f79e18c489c39af9f3a01c6b5d5f5942ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 28 Feb 2025 05:35:25 -0500 Subject: [PATCH 18/19] Revert `isconst` change now that it is fixed --- src/codegen/forward_demand.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 36f50e34..d8eea936 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -260,7 +260,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; # TODO: Should we remember whether the callbacks wanted the arg? return transform!(ir, arg, order, maparg) elseif isa(arg, GlobalRef) - @assert isconst(arg.mod, arg.name) + @assert isconst(arg) return zero_bundle{order}()(getfield(arg.mod, arg.name)) elseif isa(arg, QuoteNode) return zero_bundle{order}()(arg.value) From afd50a05d91e3cf5bb4f21ed683345503fd48dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Fri, 28 Feb 2025 16:48:38 -0500 Subject: [PATCH 19/19] Adapt to `finishinfer!` signature change --- src/stage2/interpreter.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 600696d7..2d11bbcc 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -273,15 +273,28 @@ end # TODO: `get_remarks` should get a cursor? Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing) +@static if VERSION ≥ v"1.13.0-DEV.126" +function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter, cycleid::Int) + res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int) + key = CC.is_constproped(state) ? state.result : state.linfo + interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state) + return res +end +else function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter) res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter) key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(state) : CC.any(state.result.overridden_by_const)) ? state.result : state.linfo interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state) return res end +end @static if VERSION ≥ v"1.12.0-DEV.1823" +@static if VERSION ≥ v"1.13.0-DEV.126" || VERSION ≥ v"1.12.0-alpha1" +CC.finishinfer!(state::InferenceState, interp::ADInterpreter, cycleid::Int) = diffractor_finish(CC.finishinfer!, state, interp, cycleid) +else CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp) +end @static if VERSION ≥ v"1.12.0-DEV.1988" function CC.finish!(interp::ADInterpreter, caller::InferenceState, validation_world::UInt) Cthulhu.set_cthulhu_source!(caller.result)