diff --git a/src/IRTools.jl b/src/IRTools.jl index b4bb473..4f9b039 100644 --- a/src/IRTools.jl +++ b/src/IRTools.jl @@ -47,7 +47,7 @@ let exports = :[ definitions, usages, dominators, domtree, domorder, domorder!, renumber, merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir, Simple, Loop, Multiple, reloop, stackify, functional, cond, WorkQueue, - Graph, liveness, interference, colouring, inline, + Graph, liveness, interference, colouring, inline, dependencies, # Reflection, Dynamo Meta, Lambda, meta, dynamo, transform, refresh, recurse!, self, varargs!, slots!, diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 11ee3a8..8302615 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -41,6 +41,97 @@ function usages(b::Block) return uses end +usages(st::Statement) = usages(st.expr) +usages(ex) = Set{Variable}() + +function usages(ex::Expr) + uses = Set{Variable}() + for x in ex.args + x isa Variable && push!(uses, x) + end + return uses +end + +function block_changes_deps(deps, ir, b) + for (v, st) in b + if haskey(deps, v) + (usages(st) ⊆ deps[v]) || return true + else + return true + end + end + + brs = branches(b) + for br in brs + if br.block > 0 + next_block = block(ir, br.block) + if !isempty(br.args) + for (x, y) in zip(arguments(next_block), br.args) + haskey(deps, x) && (y in deps[x]) && return true + end + end + end + end + return false +end + +function update_deps!(deps, v, direct) + set = get!(deps, v, Set{Variable}()) + union!(set, setdiff(direct, (v, ))) + + for x in direct + if (v != x) && haskey(deps, x) && !(deps[x] ⊆ set) + update_deps!(deps, v, deps[x]) + end + end + return deps +end + +""" + dependencies(ir::IR) + +Return the list of direct dependencies for each variable. +""" +function dependencies(ir::IR) + worklist = [block(ir, 1)] + deps = Dict() + while !isempty(worklist) + b = pop!(worklist) + for (v, st) in b + update_deps!(deps, v, usages(st)) + end + + brs = branches(b) + jump_next_block = true + for br in brs + if br.condition === nothing + jump_next_block = false + end + + if br.block > 0 # reachable + next_block = block(ir, br.block) + if !isempty(br.args) # pass arguments + for (x, y) in zip(arguments(next_block), br.args) + y isa Variable && update_deps!(deps, x, (y, )) + end + end + + if block_changes_deps(deps, ir, next_block) + push!(worklist, next_block) + end + end + end + + if jump_next_block + next_block = block(ir, b.id + 1) + if block_changes_deps(deps, ir, next_block) + push!(worklist, next_block) + end + end + end + return deps +end + function usecounts(ir::IR) counts = Dict{Variable,Int}() prewalk(ir) do x diff --git a/test/analysis.jl b/test/analysis.jl index 9123989..31940d6 100644 --- a/test/analysis.jl +++ b/test/analysis.jl @@ -1,5 +1,5 @@ using IRTools, Test -using IRTools: CFG, dominators, domtree +using IRTools: CFG, dominators, domtree, dependencies, var relu(x) = (y = x > 0 ? x : 0) ir = @code_ir relu(1) @@ -10,3 +10,45 @@ ir = @code_ir relu(1) @test domtree(CFG(ir)) == (1 => [2 => [], 3 => [], 4 => []]) @test domtree(CFG(ir)', entry = 4) == (4 => [1 => [], 2 => [], 3 => []]) + +function f(x) + x = sin(x) + y = cos(x) + + if x > 1 + x = cos(x) + 1 + else + x = y + 1 + end + return x +end + +ir = @code_ir f(1.0) + +deps = dependencies(ir) + +@test deps[var(9)] == Set(var.([2, 8, 7, 3, 4, 6])) +@test deps[var(8)] == Set(var.([3, 2, 4])) +@test deps[var(7)] == Set(var.([6, 3, 2])) +@test deps[var(6)] == Set(var.([3, 2])) +@test deps[var(5)] == Set(var.([3, 2])) +@test deps[var(4)] == Set(var.([3, 2])) +@test deps[var(3)] == Set([var(2)]) + +function pow(x, n) + r = 1 + while n > 0 + n -= 1 + r *= x + end + return r +end + +ir = @code_ir pow(1.0, 2) +deps = dependencies(ir) + +@test deps[var(8)] == Set(var.([5, 2])) +@test deps[var(7)] == Set(var.([4, 3])) +@test deps[var(6)] == Set(var.([4, 3])) +@test deps[var(5)] == Set(var.([2, 8])) +@test deps[var(4)] == Set(var.([3, 7]))