diff --git a/ext/GraphVizSimpleExt.jl b/ext/GraphVizSimpleExt.jl index 01d5e2810..a68487959 100644 --- a/ext/GraphVizSimpleExt.jl +++ b/ext/GraphVizSimpleExt.jl @@ -7,15 +7,17 @@ else end import Dagger -import Dagger: Chunk, Thunk, Processor +import Dagger: Chunk, Thunk, DTask, Processor import Dagger: show_logs +import Dagger: istask, dependents +import Dagger: unwrap_weak import Dagger.TimespanLogging: Timespan ### DAG-based graphing global _part_labels = Dict() -function write_node(ctx, io, t::Chunk, c) +function write_node(io, t::Chunk, c, ctx=nothing) _part_labels[t]="part_$c" c+1 end @@ -32,27 +34,109 @@ function node_id(t::Chunk) _part_labels[t] end +function node_name(t::Thunk) + "n_$(t.id)" +end + +function node_name(t::Chunk) + _part_labels[t] +end + +function node_name(name::String) + "n_$name" +end + +function node_name(id) + "n_$id" +end + +# Modified version of the function from Dagger compute.jl +function custom_dependents(node::Thunk) + deps = Dict{Union{Thunk,Chunk}, Set{Thunk}}() + visited = Set{Thunk}() + to_visit = Set{Thunk}() + push!(to_visit, node) + while !isempty(to_visit) + next = pop!(to_visit) + (next in visited) && continue + if !haskey(deps, next) + deps[next] = Set{Thunk}() + end + for inp in next.syncdeps + unwrapped = unwrap_weak(inp) + if (unwrapped === nothing) + continue + end + inp = unwrapped + if istask(inp) || (inp isa Chunk) + s = get!(()->Set{Thunk}(), deps, inp) + push!(s, next) + if istask(inp) && !(inp in visited) + push!(to_visit, inp) + end + end + end + push!(visited, next) + end + return deps +end + +# Writing DAG using DTask involves unwrapping WeakRefs, which might return `nothing` if garbage collected, +# so the part of the DAG is not displayed. This is an unstable behavior, so disabled by default. +function write_dag(io, e::DTask, stable::Bool=true) + if (stable) + throw(ArgumentError("Writing DAG for DTask is not supported by default. 
Use the logs instead.")) + else + t = convert_to_thunk(e) + write_dag(io, t) + end +end + function write_dag(io, t::Thunk) !istask(t) && return - deps = dependents(t) + + # Chunk/Thunk nodes + deps = custom_dependents(t) c=1 for k in keys(deps) - c = write_node(nothing, io, k, c) + c = write_node(io, k, c) end for (k, v) in deps for dep in v if isa(k, Union{Chunk, Thunk}) - println(io, "$(node_id(k)) -> $(node_id(dep))") + println(io, "$(node_name(k)) -> $(node_name(dep))") + end + end + end + + # Argument nodes (not Chunks/Thunks) + argmap = Dict{Int,Vector}() + getargs!(argmap, t) + argids = IdDict{Any,String}() + for id in keys(argmap) + for (argidx,arg) in argmap[id] + name = "arg_$(argidx)_to_$(id)" + if !isimmutable(arg) + if arg in keys(argids) + name = argids[arg] + else + argids[arg] = name + c = write_node(io, arg, c, name) + end + else + c = write_node(io, arg, c, name) end + # Arg-to-compute edges + write_edge(io, name, id) end end end ### Timespan-based graphing -pretty_time(ts::Timespan) = pretty_time(ts.finish-ts.start) -function pretty_time(t) - r(t) = round(t; digits=3) +pretty_time(ts::Timespan; digits::Integer=3) = pretty_time(ts.finish-ts.start; digits=digits) +function pretty_time(t; digits::Integer=3) + r(t) = round(t; digits) if t > 1000^3 "$(r(t/(1000^3))) s" elseif t > 1000^2 @@ -99,47 +183,53 @@ _proc_shape(ctx, proc::Processor) = get!(ctx.proc_to_shape, typeof(proc)) do end _proc_shape(ctx, ::Nothing) = "ellipse" -function write_node(ctx, io, t::Thunk, c) +function write_node(io, t::Thunk, c, ctx=nothing) f = isa(t.f, Function) ? 
"$(t.f)" : "fn" - println(io, "n_$(t.id) [label=\"$f - $(t.id)\"];") + println(io, "$(node_name(t)) [label=\"$f - $(t.id)\"];") c end dec(x) = Base.dec(x, 0, false) -function write_node(ctx, io, t, c, id=dec(hash(t))) +function write_node(io, t, c, ctx, id=dec(hash(t))) l = replace(node_label(t), "\""=>"") proc = node_proc(t) color = _proc_color(ctx, proc) shape = _proc_shape(ctx, proc) - println(io, "n_$id [label=\"$l\",color=\"$color\",shape=\"$shape\",penwidth=5];") + println(io, "$(node_name(id)) [label=\"$l\",color=\"$color\",shape=\"$shape\",penwidth=5];") c end -function write_node(ctx, io, ts::Timespan, c) +function write_node(io, t, c, name::String) + l = replace(node_label(t), "\""=>"") + println(io, "$(node_name(name)) [label=\"$l\"];") + c +end + +function write_node(io, ts::Timespan, c, ctx; times_digits::Integer=3) (;thunk_id, processor) = ts.id (;f) = ts.timeline f = isa(f, Function) ? "$f" : "fn" - t_comp = pretty_time(ts) + t_comp = pretty_time(ts; digits=times_digits) color = _proc_color(ctx, processor) shape = _proc_shape(ctx, processor) # TODO: t_log = log(ts.finish - ts.start) / 5 ctx.id_to_proc[thunk_id] = processor - println(io, "n_$thunk_id [label=\"$f\n$t_comp\",color=\"$color\",shape=\"$shape\",penwidth=5];") + println(io, "$(node_name(thunk_id)) [label=\"$f\n$t_comp\",color=\"$color\",shape=\"$shape\",penwidth=5];") # TODO: "\n Thunk $(ts.id)\nResult Type: $res_type\nResult Size: $sz_comp\", c end -function write_edge(ctx, io, ts_move::Timespan, logs, inputname=nothing, inputarg=nothing) +function write_edge(io, ts_move::Timespan, logs, ctx, inputname=nothing, inputarg=nothing; times_digits::Integer=3) (;thunk_id, id) = ts_move.id (;f,) = ts_move.timeline - t_move = pretty_time(ts_move) + t_move = pretty_time(ts_move; digits=times_digits) if id > 0 - print(io, "n_$id -> n_$thunk_id [label=\"Move: $t_move") + print(io, "$(node_name(id)) -> $(node_name(thunk_id)) [label=\"Move: $t_move") color_src = _proc_color(ctx, id) else @assert 
inputname !== nothing @assert inputarg !== nothing - print(io, "n_$inputname -> n_$thunk_id [label=\"Move: $t_move") + print(io, "$(node_name(inputname)) -> $(node_name(thunk_id)) [label=\"Move: $t_move") proc = node_proc(inputarg) color_src = _proc_color(ctx, proc) end @@ -148,15 +238,26 @@ function write_edge(ctx, io, ts_move::Timespan, logs, inputname=nothing, inputar println(io, "\",color=\"$color_src;0.5:$color_dst\",penwidth=2];") end -write_edge(ctx, io, from::String, to::String) = println(io, "n_$from -> n_$to;") +write_edge(io, from::String, to::String, ctx=nothing) = println(io, "$(node_name(from)) -> $(node_name(to));") +write_edge(io, from::String, to::Int, ctx=nothing) = println(io, "$(node_name(from)) -> $(node_name(to));") + +convert_to_thunk(t::Thunk) = t +convert_to_thunk(t::DTask) = Dagger.Sch._find_thunk(t) getargs!(d, t) = nothing + function getargs!(d, t::Thunk) raw_inputs = map(last, t.inputs) d[t.id] = [filter(x->!istask(x[2]), collect(enumerate(raw_inputs)))...,] foreach(i->getargs!(d, i), raw_inputs) end -function write_dag(io, t, logs::Vector) + +function getargs!(d, e::DTask) + getargs!(d, convert_to_thunk(e)) +end + +# DTask is not used in the current implementation, as it would be unstable, and the logs provide all the necessary information +function write_dag(io, logs::Vector, t::Union{Thunk, DTask, Nothing}=nothing; times_digits::Integer=3) ctx = (proc_to_color = Dict{Processor,String}(), proc_colors = Colors.distinguishable_colors(128), proc_color_idx = Ref{Int}(1), @@ -164,46 +265,70 @@ function write_dag(io, t, logs::Vector) proc_shapes = ("ellipse","box","triangle"), proc_shape_idx = Ref{Int}(1), id_to_proc = Dict{Int,Processor}()) - argmap = Dict{Int,Vector}() - getargs!(argmap, t) + c = 1 # Compute nodes for ts in filter(x->x.category==:compute, logs) - c = write_node(ctx, io, ts, c) + c = write_node(io, ts, c, ctx; times_digits=times_digits) end - # Argument nodes - argnodemap = Dict{Int,Vector{String}}() + + # Argument nodes & 
edges + argmap = Dict{Int,Vector}() argids = IdDict{Any,String}() - for id in keys(argmap) - nodes = String[] - arg_c = 1 - for (argidx,arg) in argmap[id] - name = "arg_$(argidx)_to_$(id)" + if (isa(t, Thunk)) # Then can get info from the Thunk + getargs!(argmap, t) + argnodemap = Dict{Int,Vector{String}}() + for id in keys(argmap) + nodes = String[] + arg_c = 1 + for (argidx,arg) in argmap[id] + name = "arg_$(argidx)_to_$(id)" + if !isimmutable(arg) + if arg in keys(argids) + name = argids[arg] + else + argids[arg] = name + c = write_node(io, arg, c, ctx, name) + end + push!(nodes, name) + else + c = write_node(io, arg, c, ctx, name) + push!(nodes, name) + end + # Arg-to-compute edges + for ts in filter(x->x.category==:move && + x.id.thunk_id==id && + x.id.id==-argidx, logs) + write_edge(io, ts, logs, ctx, name, arg; times_digits=times_digits) + end + arg_c += 1 + end + argnodemap[id] = nodes + end + else # Rely on the logs only + for ts in filter(x->x.category==:move && x.id.id < 0, logs) + (;thunk_id, id) = ts.id + arg = ts.timeline[2] + name = "arg_$(-id)_to_$(thunk_id)" if !isimmutable(arg) if arg in keys(argids) name = argids[arg] else argids[arg] = name - c = write_node(ctx, io, arg, c, name) + c = write_node(io, arg, c, ctx, name) end - push!(nodes, name) else - c = write_node(ctx, io, arg, c, name) - push!(nodes, name) + c = write_node(io, arg, c, ctx, name) end + # Arg-to-compute edges - for ts in filter(x->x.category==:move && - x.id.thunk_id==id && - x.id.id==-argidx, logs) - write_edge(ctx, io, ts, logs, name, arg) - end - arg_c += 1 + write_edge(io, ts, logs, ctx, name, arg; times_digits=times_digits) end - argnodemap[id] = nodes end + # Move edges for ts in filter(x->x.category==:move && x.id.id>0, logs) - write_edge(ctx, io, ts, logs) + write_edge(io, ts, logs, ctx; times_digits=times_digits) end #= FIXME: Legend (currently it's laid out horizontally) println(io, """ @@ -224,21 +349,28 @@ function write_dag(io, t, logs::Vector) =# end -function 
_show_plan(io::IO, t) +function _show_plan(io::IO, t::Union{Thunk,DTask}) println(io, """strict digraph { graph [layout=dot,rankdir=LR];""") write_dag(io, t) println(io, "}") end -function _show_plan(io::IO, t::Thunk, logs::Vector{Timespan}) +function _show_plan(io::IO, logs::Vector; times_digits::Integer=3) println(io, """strict digraph { graph [layout=dot,rankdir=LR];""") - write_dag(io, t, logs) + write_dag(io, logs; times_digits) println(io, "}") end +function _show_plan(io::IO, t::Union{Thunk,DTask}, logs::Vector{Timespan}; times_digits::Integer=3) + println(io, """strict digraph { + graph [layout=dot,rankdir=LR];""") + write_dag(io, logs, t; times_digits) + println(io, "}") +end + +Dagger.show_logs(io::IO, t::Union{Thunk,DTask}, ::Val{:graphviz_simple}) = _show_plan(io, t) +Dagger.show_logs(io::IO, logs::Vector{Timespan}, ::Val{:graphviz_simple}; times_digits::Integer=3) = _show_plan(io, logs; times_digits=times_digits) +Dagger.show_logs(io::IO, t::Union{Thunk,DTask}, logs::Vector{Timespan}, ::Val{:graphviz_simple}; times_digits::Integer=3) = _show_plan(io, t, logs; times_digits=times_digits) -show_logs(io::IO, t::Thunk, ::Val{:graphviz_simple}) = _show_plan(io, t) -show_logs(io::IO, logs::Vector{Timespan}, ::Val{:graphviz_simple}) = _show_plan(io, logs) -show_logs(io::IO, t::Thunk, logs::Vector{Timespan}, ::Val{:graphviz_simple}) = _show_plan(io, t, logs) end diff --git a/src/visualization.jl b/src/visualization.jl index 5fe443b5a..0587b264d 100644 --- a/src/visualization.jl +++ b/src/visualization.jl @@ -15,10 +15,10 @@ Returns a string representation of the logs of a task `t` and/or logs object `lo """ function show_logs end -show_logs(io::IO, logs, vizmode::Symbol; options...) = - show_logs(io, logs, Val{vizmode}(); options...) +show_logs(io::IO, arg, vizmode::Symbol; options...) = + show_logs(io, arg, Val{vizmode}(); options...) show_logs(io::IO, t, logs, vizmode::Symbol; options...) = - show_logs(io, t, Val{vizmode}(); options...) 
+ show_logs(io, t, logs, Val{vizmode}(); options...) show_logs(io::IO, ::T, ::Val{vizmode}; options...) where {T,vizmode} = throw(ArgumentError("show_logs: Task/logs type `$T` not supported for visualization mode `$(repr(vizmode))`")) show_logs(io::IO, ::T, ::Logs, ::Val{vizmode}; options...) where {T,Logs,vizmode} =