Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9d8528c
Adapt to https://github.com/JuliaLang/julia/pull/56509
serenity4 Feb 20, 2025
d8fb471
Adapt to https://github.com/JuliaLang/julia/pull/54734
serenity4 Feb 20, 2025
5cd9f67
Use StmtRange explicitly
serenity4 Feb 20, 2025
524ac00
Adapt to https://github.com/JuliaLang/julia/pull/57230
serenity4 Feb 20, 2025
edcd439
Reuse Cthulhu code structure for Compiler cache/finish overrides
serenity4 Feb 20, 2025
a150d87
Adapt to https://github.com/JuliaLang/julia/issues/57475
serenity4 Feb 21, 2025
1e1ad26
Adapt to https://github.com/JuliaLang/julia/issues/55976
serenity4 Feb 21, 2025
4190f6f
Adapt to https://github.com/JuliaLang/julia/pull/54734
serenity4 Feb 21, 2025
48d764d
Use CC instead of .Compiler
serenity4 Feb 21, 2025
35cde4c
Implement ir.argtypes[1] fix from https://github.com/JuliaLang/julia/…
serenity4 Feb 21, 2025
7b7b757
Comment out failing tests
serenity4 Feb 21, 2025
dc6746e
Treat `getproperty(::Module, ::Symbol)` like GlobalRefs
serenity4 Feb 21, 2025
4c2bca7
Uncomment passing tests, explicitly mark others as broken
serenity4 Feb 21, 2025
7f126aa
Evaluate GlobalRef only if binding is defined
serenity4 Feb 24, 2025
ac7bce4
Use `rrule` for getproperty(::Module, ::Symbol)
serenity4 Feb 24, 2025
ac83f33
Bump compat bound for StructArrays
serenity4 Feb 24, 2025
590815a
Raise compat bound for Cthulhu
serenity4 Feb 24, 2025
310e4f7
Revert `isconst` change now that it is fixed
serenity4 Feb 28, 2025
afd50a0
Adapt to `finishinfer!` signature change
serenity4 Feb 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ ChainRules = "1.44.6"
ChainRulesCore = "1.20"
Combinatorics = "1"
Compiler = "~0"
Cthulhu = "2.10.1"
Cthulhu = "2.16.3"
OffsetArrays = "1"
PrecompileTools = "1"
StaticArrays = "1"
StructArrays = "0.6"
StructArrays = "0.6, 0.7"
julia = "1.10"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/analysis/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
# discover what they are. frules should be written in such a way that
# whether or not they return `nothing`, only depends on the non-tangent arguments
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
frule_si = StmtInfo(true, false)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_gf_by_type(interp′,
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
end
end

method_info = CC.MethodInfo(src)
info = @static VERSION ≥ v"1.12.0-DEV.1293" ? CC.SpecInfo(src) : CC.MethodInfo(src)
argtypes = ir.argtypes[1:mi.def.nargs]
world = get_inference_world(interp)
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
rt = CC._ir_abstract_constant_propagation(interp, irsv)
irsv = IRInterpretationState(interp, info, ir, mi, argtypes, world, src.min_world, src.max_world)
rt = CC.ir_abstract_constant_propagation(interp, irsv)

ir = compact!(ir)

Expand Down
7 changes: 4 additions & 3 deletions src/codegen/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci,
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
end
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
else
oc_nargs = Int64(meth_nargs)
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
ocm = Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci)
end
oc = Expr(:new_opaque_closure, typ, Union{}, Any, true, ocm, revs...)
@static VERSION < v"1.12.0-DEV.691" ? deleteat!(oc.args, 4) : nothing
oc
end

function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::Int, interp=nothing, curs=nothing)
Expand Down
6 changes: 6 additions & 0 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
val, Δ->(NoTangent(), NoTangent(), Δ)
end

# XXX: We should instead skip differentiation in the IR.
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol)
val = getproperty(mod, name)
val, Δ->(NoTangent(), NoTangent(), NoTangent())
end

Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581

# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
Expand Down
6 changes: 1 addition & 5 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Utilities that should probably go into CC
using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter
using .CC: IRCode, CFG, BasicBlock, BBIdxIter

function Base.push!(cfg::CFG, bb::BasicBlock)
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
Expand Down Expand Up @@ -30,10 +30,6 @@ if VERSION < v"1.12.0-DEV.1268"

Base.copy(ir::IRCode) = CC.copy(ir)

CC.BasicBlock(x::UnitRange) =
BasicBlock(StmtRange(first(x), last(x)))
CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
Base.length(c::CC.NewNodeStream) = CC.length(c)
Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...)
Base.size(x::CC.UnitRange) = CC.size(x)
Expand Down
5 changes: 3 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ struct ∂⃖recurse{N}; end

include("recurse.jl")

function generate_lambda_ex(world::UInt, source::LineNumberNode,
# source is a Method starting from https://github.com/JuliaLang/julia/pull/57230
function generate_lambda_ex(world::UInt, source::Union{Method,LineNumberNode},
args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr)
stub = Core.GeneratedFunctionStub(identity, args, sparams)
return stub(world, source, body)
Expand All @@ -16,7 +17,7 @@ struct NonTransformableError
args
end

function perform_optic_transform(world::UInt, source::LineNumberNode,
function perform_optic_transform(world::UInt, source::Union{Method,LineNumberNode},
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
@assert N >= 1

Expand Down
4 changes: 2 additions & 2 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ function split_critical_edges!(ir)
bb = ir.stmts[i][:inst].args[1]
ir.stmts[i][:inst] = nothing
bbnew = bb + ninserted
insert!(cfg.blocks, bbnew, BasicBlock(i:i))
insert!(cfg.blocks, bbnew, BasicBlock(StmtRange(i:i)))
bb_rename_offset[bb] += 1
bblock = cfg.blocks[bbnew+1]
cfg.blocks[bbnew+1] = BasicBlock((i+1):last(bblock.stmts),
cfg.blocks[bbnew+1] = BasicBlock(StmtRange((i+1):last(bblock.stmts)),
bblock.preds, bblock.succs)
i += 1
while i <= last(bblock.stmts)
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
return ci
end

function perform_fwd_transform(world::UInt, source::LineNumberNode,
function perform_fwd_transform(world::UInt, source::Union{Method,LineNumberNode},
@nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E}
if all(x->x <: ZeroBundle, args)
return generate_lambda_ex(world, source,
Expand Down
8 changes: 4 additions & 4 deletions src/stage2/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ end
# unlikely to be the actual interface. For now, it is used for testing.
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false)
interp = ADInterpreter(; forward=true, backward=false)
match = Base._which(tt)
frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
mi = frame.linfo
mi = @ccall jl_method_lookup_by_tt(tt::Any, Base.tls_world_age()::Csize_t, #= method table =# nothing::Any)::Ref{MethodInstance}
ci = CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)

src = CC.copy(interp.unopt[0][mi].src)
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)
ir = CC.copy((@atomic :monotonic ci.inferred).ir::IRCode)

# Find all Return Nodes
vals = Pair{SSAValue, Int}[]
Expand Down Expand Up @@ -83,6 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = fa
end

ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode)
ir.argtypes[1] = Tuple{}

return OpaqueClosure(ir)
end
91 changes: 68 additions & 23 deletions src/stage2/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,76 @@ end
# TODO: `get_remarks` should get a cursor?
Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing)

function CC.finish(sv::InferenceState, interp::ADInterpreter)
res = @invoke CC.finish(sv::InferenceState, interp::AbstractInterpreter)
key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(sv) : CC.any(sv.result.overridden_by_const)) ? sv.result : sv.linfo
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(sv)
@static if VERSION ≥ v"1.13.0-DEV.126"
function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter, cycleid::Int)
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int)
key = CC.is_constproped(state) ? state.result : state.linfo
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state)
return res
end
else
function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter)
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter)
key = (@static VERSION ≥ v"1.12.0-DEV.317" ? CC.is_constproped(state) : CC.any(state.result.overridden_by_const)) ? state.result : state.linfo
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state)
return res
end
end

@static if VERSION ≥ v"1.12.0-DEV.1823"
@static if VERSION ≥ v"1.13.0-DEV.126" || VERSION ≥ v"1.12.0-alpha1"
CC.finishinfer!(state::InferenceState, interp::ADInterpreter, cycleid::Int) = diffractor_finish(CC.finishinfer!, state, interp, cycleid)
else
CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp)
end
@static if VERSION ≥ v"1.12.0-DEV.1988"
function CC.finish!(interp::ADInterpreter, caller::InferenceState, validation_world::UInt)
Cthulhu.set_cthulhu_source!(caller.result)
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
end
else
function CC.finish!(interp::ADInterpreter, caller::InferenceState)
Cthulhu.set_cthulhu_source!(caller.result)
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState)
end
end

elseif VERSION ≥ v"1.12.0-DEV.734"
CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp)
function CC.finish!(interp::ADInterpreter, caller::InferenceState;
can_discard_trees::Bool=false)
Cthulhu.set_cthulhu_source!(caller.result)
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState;
can_discard_trees)
end

elseif VERSION ≥ v"1.11.0-DEV.737"
CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp)
function CC.finish!(interp::ADInterpreter, caller::InferenceState)
result = caller.result
opt = result.src
Cthulhu.set_cthulhu_source!(result)
if opt isa CC.OptimizationState
CC.ir_to_codeinf!(opt)
end
return nothing
end
function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange,
result::InferenceResult)
return result.src
end

else # VERSION < v"1.11.0-DEV.737"
CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp)
function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange,
result::InferenceResult)
return create_cthulhu_source(result.src, result.ipo_effects)
end
function CC.finish!(::ADInterpreter, caller::InferenceResult)
Cthulhu.set_cthulhu_source(interp, caller)
end

end # @static if

const StmtFlag = @static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8
function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo),
Expand All @@ -303,10 +367,6 @@ function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.C
end

@static if VERSION ≥ v"1.12.0-DEV.45"
function CC.transform_result_for_cache(interp::ADInterpreter,
::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end
function CC.src_inlining_policy(interp::ADInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag)
ret = diffractor_inlining_policy(src, info, stmt_flag)
Expand All @@ -316,10 +376,6 @@ function CC.src_inlining_policy(interp::ADInterpreter,
src::Any, info::CC.CallInfo, stmt_flag::StmtFlag)
end
else
function CC.transform_result_for_cache(interp::ADInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end
function CC.inlining_policy(interp::ADInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag,
mi::MethodInstance, argtypes::Vector{Any})
Expand Down Expand Up @@ -351,17 +407,6 @@ function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
end
=#

function _finish!(caller::InferenceResult)
effects = caller.ipo_effects
caller.src = Cthulhu.create_cthulhu_source(caller.src, effects)
end

@static if VERSION ≥ v"1.11.0-DEV.737"
CC.finish!(::ADInterpreter, caller::InferenceState) = _finish!(caller.result)
else
CC.finish!(::ADInterpreter, caller::InferenceResult) = _finish!(caller)
end

@static if VERSION ≥ v"1.11.0-DEV.1278"
function CC.bail_out_const_call(interp::ADInterpreter, result::CC.MethodCallResult,
si::StmtInfo, sv::CC.AbsIntState)
Expand Down
14 changes: 10 additions & 4 deletions test/forward_diff_no_inf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ module forward_diff_no_inf
ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED
end

method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing)
info = @static if VERSION ≥ v"1.12.0-DEV.1293"
CC.SpecInfo(#=nargs=#length(ir.argtypes), #=isva=#false, #=propagate_inbounds=#true, nothing)
else
CC.MethodInfo(#=propagate_inbounds=#true, nothing)
end
min_world = world = (interp).world
max_world = Diffractor.get_world_counter()
irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world)
(rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv)
irsv = CC.IRInterpretationState(interp, info, ir, mi, ir.argtypes, world, min_world, max_world)
(rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv)
return rt
end

Expand Down Expand Up @@ -79,6 +83,7 @@ module forward_diff_no_inf
ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!)
ir2 = CC.compact!(ir)
ir2.argtypes[1] = Tuple{}
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right
end
Expand All @@ -96,6 +101,7 @@ module forward_diff_no_inf
stmt = ir2.stmts[stmt_idx]
@test stmt[:inst].name == :_coeff
@test stmt[:type] == Float64
ir2.argtypes[1] = Tuple{}
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(3.5) == 28.0
end
Expand Down Expand Up @@ -124,6 +130,7 @@ module forward_diff_no_inf
Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
ir2 = CC.compact!(ir)
CC.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
ir2.argtypes[1] = Tuple{}
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
end
Expand Down Expand Up @@ -154,4 +161,3 @@ module forward_diff_no_inf
end
end
end # module

3 changes: 2 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ end

@testset "sum, prod" begin
@test gradcheck(x -> sum(abs2, x), randn(4, 3, 2))
@test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10))
# Fails in `diffract_ir!` on $(Expr(:isdefined, :($(Expr(:static_parameter, 1)))))
@test_broken gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10))
@test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231
@test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10))
@test gradcheck(X -> sum(x -> x^2, X), randn(10))
Expand Down
2 changes: 1 addition & 1 deletion test/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
# Integration tests
@test @inferred(sin'(1.0)) == cos(1.0)
@test @inferred(sin''(1.0)) == -sin(1.0)
@test @inferred(sin'''(1.0)) == -cos(1.0)
# FIXME: These error with:
# Control flow support not fully implemented yet for higher-order reverse mode (TODO)
@test_broken @inferred(sin'''(1.0)) == -cos(1.0)
@test_broken @inferred(sin''''(1.0)) == sin(1.0)
@test_broken @inferred(sin'''''(1.0)) == cos(1.0)
@test_broken @inferred(sin''''''(1.0)) == -sin(1.0)
Expand Down
Loading