datadeps: Implement an optimizing scheduler #592

Draft: wants to merge 9 commits into base: master
4 changes: 4 additions & 0 deletions Project.toml
@@ -11,6 +11,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
MetricsTracker = "9a9c6fec-044d-4a27-aa18-2b01ca4026eb"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -33,6 +34,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

@@ -41,6 +43,7 @@ DistributionsExt = "Distributions"
GraphVizExt = "GraphViz"
GraphVizSimpleExt = "Colors"
JSON3Ext = "JSON3"
JuMPExt = "JuMP"
PlotsExt = ["DataFrames", "Plots"]
PythonExt = "PythonCall"

@@ -54,6 +57,7 @@ Distributions = "0.25"
GraphViz = "0.2"
Graphs = "1"
JSON3 = "1"
JuMP = "1"
MacroTools = "0.5"
MemPool = "0.4.11"
OnlineStats = "1"
163 changes: 163 additions & 0 deletions ext/JuMPExt.jl
@@ -0,0 +1,163 @@
module JuMPExt

if isdefined(Base, :get_extension)
using JuMP
else
using ..JuMP
end

using Dagger
using Dagger.Distributed
import MetricsTracker as MT
import Graphs: edges, nv, outdegree

struct JuMPScheduler
optimizer
Z::Float64
JuMPScheduler(optimizer) = new(optimizer, 10)
end
function Dagger.datadeps_create_schedule(sched::JuMPScheduler, state, specs_tasks)
astate = state.alias_state
g, task_to_id = astate.g, astate.task_to_id
id_to_task = Dict(id => task for (task, id) in task_to_id)
ntasks = length(specs_tasks)
nprocs = length(state.all_procs)
id_to_proc = Dict(i => p for (i, p) in enumerate(state.all_procs))

# Estimate the time each task will take to execute on each processor,
# and the time it will take to transfer data between processors
task_times = zeros(UInt64, ntasks, nprocs)
xfer_times = zeros(Int, nprocs, nprocs)
lock(MT.GLOBAL_METRICS_CACHE) do cache
for (spec, task) in specs_tasks
id = task_to_id[task]
for p in 1:nprocs
# When searching for a task runtime estimate, if no estimate has been
# recorded for this exact processor, we fall back to the closest available one:
# Exact match > same proc type, same node > same proc type, any node > any proc type

sig = Dagger.Sch.signature(spec.f, map(pos_arg->pos_arg[1] => Dagger.unwrap_inout_value(pos_arg[2]), spec.args))
proc = state.all_procs[p]
@warn "Use node, not worker id!" maxlog=1
pid = Dagger.root_worker_id(proc)

# Try exact match
match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
MT.LookupExact(Dagger.ProcessorMetric(), proc))
result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
if result !== nothing
task_times[id, p] = result
continue
end

# Try same proc type, same node
match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
MT.LookupSubtype(Dagger.ProcessorMetric(), typeof(proc)),
MT.LookupCustom(Dagger.ProcessorMetric(), other_proc->Dagger.root_worker_id(other_proc)==pid))
result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
if result !== nothing
task_times[id, p] = result
continue
end

# Try same proc type, any node
match_on = (MT.LookupExact(Dagger.SignatureMetric(), sig),
MT.LookupSubtype(Dagger.ProcessorMetric(), typeof(proc)))
result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
if result !== nothing
task_times[id, p] = result
continue
end

# Try any processor with an exact signature match
match_on = MT.LookupExact(Dagger.SignatureMetric(), sig)
result = MT.cache_lookup(cache, Dagger, :execute!, MT.TimeMetric(), match_on)::Union{UInt64, Nothing}
if result !== nothing
task_times[id, p] = result
continue
end

# If no information is available, use a random guess
task_times[id, p] = UInt64(rand(1:1_000_000))
end
end

# FIXME: Actually fill this with estimated xfer times
@warn "Assuming all xfer times are 1" maxlog=1
for dst in 1:nprocs
for src in 1:nprocs
if src == dst # FIXME: Or if space is shared
xfer_times[src, dst] = 0
else
# FIXME: sum(currently non-local task arg size) / xfer_speed
xfer_times[src, dst] = 1
end
end
end
end

@warn "If no edges exist, this will fail" maxlog=1
γ = Dict{Tuple{Int, Int}, Matrix{Int}}()
for (i, j) in Tuple.(edges(g))
γ[(i, j)] = copy(xfer_times)
end

a_kls = Tuple.(edges(g))
m = Model(sched.optimizer)
JuMP.set_silent(m)

# Start time of each task
@variable(m, t[1:ntasks] >= 0)
# End time of last task
@variable(m, t_last_end >= 0)

# 1 if task k is assigned to proc p
@variable(m, s[1:ntasks, 1:nprocs], Bin)

# Each task is assigned to exactly one processor
@constraint(m, [k in 1:ntasks], sum(s[k, :]) == 1)

# Penalties for moving between procs
if length(a_kls) > 0
@variable(m, p[a_kls] >= 0)

for (k, l) in a_kls
for p1 in 1:nprocs
for p2 in 1:nprocs
p1 == p2 && continue
# Task l occurs after task k if the procs are different,
# thus there is a penalty
@constraint(m, p[(k, l)] >= (s[k, p1] + s[l, p2] - 1) * γ[(k, l)][p1, p2])
end
end

# Task l occurs after task k
@constraint(m, t[k] + task_times[k, :]' * s[k, :] + p[(k, l)] <= t[l])
end
else
@variable(m, p >= 0)
end

for l in filter(n -> outdegree(g, n) == 0, 1:nv(g))
# DAG ends after the last task
@constraint(m, t[l] + task_times[l, :]' * s[l, :] <= t_last_end)
end

# Minimize the total runtime of the DAG
# TODO: Do we need to bias towards earlier start times?
@objective(m, Min, sched.Z*t_last_end + sum(t) .+ sum(p))

# Solve the model
optimize!(m)

# Extract the schedule from the model
task_to_proc = Dict{DTask, Dagger.Processor}()
for k in 1:ntasks
proc_id = findfirst(identity, value.(s[k, :]) .== 1)
task_to_proc[id_to_task[k]] = id_to_proc[proc_id]
end

return task_to_proc
end

end # module JuMPExt
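For reference, here is a minimal standalone sketch of the same makespan formulation on toy data (three tasks in a chain, two processors). HiGHS is assumed purely as an example MILP solver and is not a dependency of this PR; the task times and transfer penalties are made up.

    using JuMP, HiGHS

    ntasks, nprocs = 3, 2
    task_times = [10 20; 30 15; 10 10]       # runtime of task k on proc p (arbitrary units)
    xfer = [0 1; 1 0]                        # transfer penalty between procs
    dag_edges = [(1, 2), (2, 3)]             # task 1 -> 2 -> 3

    m = Model(HiGHS.Optimizer)
    set_silent(m)
    @variable(m, t[1:ntasks] >= 0)           # start time of each task
    @variable(m, t_last_end >= 0)            # makespan
    @variable(m, s[1:ntasks, 1:nprocs], Bin) # task-to-processor assignment
    @variable(m, pnl[dag_edges] >= 0)        # cross-processor transfer penalties
    @constraint(m, [k in 1:ntasks], sum(s[k, :]) == 1)
    for (k, l) in dag_edges, p1 in 1:nprocs, p2 in 1:nprocs
        p1 == p2 && continue
        @constraint(m, pnl[(k, l)] >= (s[k, p1] + s[l, p2] - 1) * xfer[p1, p2])
    end
    for (k, l) in dag_edges
        @constraint(m, t[k] + task_times[k, :]' * s[k, :] + pnl[(k, l)] <= t[l])
    end
    @constraint(m, t[3] + task_times[3, :]' * s[3, :] <= t_last_end)  # task 3 is the only sink
    @objective(m, Min, 10 * t_last_end + sum(t) + sum(pnl))
    optimize!(m)
    value.(s)                                # chosen assignment, one 1 per row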
22 changes: 22 additions & 0 deletions lib/MetricsTracker/LICENSE.md
@@ -0,0 +1,22 @@
MetricsTracker.jl is licensed under the MIT "Expat" License:

> Copyright (c) 2024: Julian P Samaroo and contributors
>
> Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the
> "Software"), to deal in the Software without restriction, including
> without limitation the rights to use, copy, modify, merge, publish,
> distribute, sublicense, and/or sell copies of the Software, and to
> permit persons to whom the Software is furnished to do so, subject to
> the following conditions:
>
> The above copyright notice and this permission notice shall be
> included in all copies or substantial portions of the Software.
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
> EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
> MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
> IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
> CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
> TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
> SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16 changes: 16 additions & 0 deletions lib/MetricsTracker/Project.toml
@@ -0,0 +1,16 @@
name = "MetricsTracker"
uuid = "9a9c6fec-044d-4a27-aa18-2b01ca4026eb"
authors = ["Julian P Samaroo <[email protected]>"]
version = "0.1.0"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"

[compat]
MacroTools = "0.5.13"
ScopedValues = "1.2.1"
Serialization = "1.11.0"
TaskLocalValues = "0.1.1"
17 changes: 17 additions & 0 deletions lib/MetricsTracker/src/MetricsTracker.jl
@@ -0,0 +1,17 @@
module MetricsTracker

import MacroTools: @capture
import ScopedValues: ScopedValue, @with
import TaskLocalValues: TaskLocalValue

include("types.jl")
include("metrics.jl")
include("lookup.jl")
include("io.jl")
include("builtins.jl")
# FIXME
#include("analysis.jl")
#include("aggregate.jl")
#include("decision.jl")

end # module MetricsTracker
17 changes: 17 additions & 0 deletions lib/MetricsTracker/src/aggregate.jl
@@ -0,0 +1,17 @@
abstract type AbstractAggregator <: AbstractAnalysis end

#### Built-in Aggregators ####

struct SimpleAverageAggregator{T} <: AbstractAggregator
inner::T
end
required_metrics(agg::SimpleAverageAggregator, ::Val{context}, ::Val{op}) where {context,op} =
RequiredMetrics((context, op) => [agg.inner])
function run_analysis(agg::SimpleAverageAggregator, ::Val{context}, ::Val{op}, @nospecialize(args...)) where {context,op}
prev = fetch_metric_cached(agg, context, op, args...)
next = fetch_metric(agg.inner, context, op, args...)
if prev === nothing || next === nothing
return next
end
return (prev + next) ÷ 2
end
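Illustrative arithmetic only: the `÷ 2` update above is an exponentially-weighted running average with weight 1/2. For successive runtimes of 100, 200, and 1000 ns:

    prev = 100
    prev = (prev + 200) ÷ 2     # 150
    prev = (prev + 1000) ÷ 2    # 575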
46 changes: 46 additions & 0 deletions lib/MetricsTracker/src/analysis.jl
@@ -0,0 +1,46 @@
const RequiredMetrics = Dict{Tuple{Module,Symbol},Vector{AnalysisOrMetric}}
const RequiredMetricsAny = Vector{AnalysisOrMetric}
const NO_REQUIRED_METRICS = RequiredMetrics()
required_metrics(::AnalysisOrMetric, _, _) = NO_REQUIRED_METRICS

function fetch_metric(m::AnalysisOrMetric, mod::Module, context::Symbol, key, extra; cached=false)
@assert !COLLECTING_METRICS[] "Nesting analysis and metrics collection not yet supported"
# Check if this is already cached
cache = local_metrics_cache(mod, context, key)
if cached
return cache[m]
end
# FIXME: Proper invalidation support
if m isa AbstractMetric
if haskey(cache, m)
value = cache[m]
@debug "-- HIT for ($mod, $context) $m [$key] = $value"
return value
else
# The metric isn't available yet
@debug "-- MISS for ($mod, $context) $m [$key]"
return nothing
end
elseif m isa AbstractAnalysis
# Run the analysis
@debug "Running ($mod, $context) $m [$key]"
value = run_analysis(m, Val{nameof(mod)}(), Val{context}(), key, extra)
# TODO: Allocate the correct Dict type
get!(Dict, cache, m)[key] = value
@debug "Finished ($mod, $context) $m [$key] = $value"
return value
end
end

#### Built-in Analyses ####

struct RuntimeWithoutCompilation <: AbstractAnalysis end
required_metrics(::RuntimeWithoutCompilation) =
RequiredMetricsAny([TimeMetric(),
CompileTimeMetric()])
metric_type(::RuntimeWithoutCompilation) = UInt64
function run_analysis(::RuntimeWithoutCompilation, mod, context, key, extra)
time = fetch_metric(TimeMetric(), mod, context, key, extra)
ctime = fetch_metric(CompileTimeMetric(), mod, context, key, extra)
return time - ctime[1]
end
48 changes: 48 additions & 0 deletions lib/MetricsTracker/src/builtins.jl
@@ -0,0 +1,48 @@
#### Built-in Metrics ####

struct TimeMetric <: AbstractMetric end
metric_applies(::TimeMetric, _) = true
metric_type(::TimeMetric) = UInt64
start_metric(::TimeMetric) = time_ns()
stop_metric(::TimeMetric, last::UInt64) = time_ns() - last

struct ThreadTimeMetric <: AbstractMetric end
metric_applies(::ThreadTimeMetric, _) = true
metric_type(::ThreadTimeMetric) = UInt64
start_metric(::ThreadTimeMetric) = cputhreadtime()
stop_metric(::ThreadTimeMetric, last::UInt64) = cputhreadtime() - last

struct CompileTimeMetric <: AbstractMetric end
metric_applies(::CompileTimeMetric, _) = true
metric_type(::CompileTimeMetric) = Tuple{UInt64, UInt64}
function start_metric(::CompileTimeMetric)
Base.cumulative_compile_timing(true)
return Base.cumulative_compile_time_ns()
end
function stop_metric(::CompileTimeMetric, last::Tuple{UInt64, UInt64})
Base.cumulative_compile_timing(false)
return Base.cumulative_compile_time_ns() .- last
end

struct AllocMetric <: AbstractMetric end
metric_applies(::AllocMetric, _) = true
metric_type(::AllocMetric) = Base.GC_Diff
start_metric(::AllocMetric) = Base.gc_num()
stop_metric(::AllocMetric, last::Base.GC_Num) = Base.GC_Diff(Base.gc_num(), last)

struct ResultShapeMetric <: AbstractMetric end
metric_applies(::ResultShapeMetric, _) = true
metric_type(::ResultShapeMetric) = Union{Dims, Nothing}
is_result_metric(::ResultShapeMetric) = true
result_metric(m::ResultShapeMetric, result) =
result isa AbstractArray ? size(result) : nothing

struct LoadAverageMetric <: AbstractMetric end
metric_applies(::LoadAverageMetric, _) = true
metric_type(::LoadAverageMetric) = Tuple{Float64, Float64, Float64}
start_metric(::LoadAverageMetric) = nothing
stop_metric(::LoadAverageMetric, _) = (Sys.loadavg()...,) ./ Sys.CPU_THREADS

# TODO: Useful metrics to add
# perf performance counters
# BPF probe-collected metrics
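As an illustrative sketch (not part of this diff), a custom metric only needs to implement the same small interface; `HostnameMetric` below is a hypothetical example that records which host the measured region ran on:

    import MetricsTracker: AbstractMetric, metric_applies, metric_type,
                           start_metric, stop_metric

    struct HostnameMetric <: AbstractMetric end
    metric_applies(::HostnameMetric, _) = true            # applies to every context
    metric_type(::HostnameMetric) = String
    start_metric(::HostnameMetric) = nothing               # nothing to snapshot up front
    stop_metric(::HostnameMetric, _) = gethostname()       # captured when the region ends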
233 changes: 233 additions & 0 deletions lib/MetricsTracker/src/decision.jl
@@ -0,0 +1,233 @@
abstract type AbstractDecision end

function required_metrics_to_collect(@nospecialize(m::AbstractDecision), context::Symbol, op::Symbol)
metrics = Dict{Tuple{Symbol,Symbol},Vector{AnalysisOrMetric}}()
to_expand = Tuple{Symbol,Symbol,AnalysisOrMetric}[]
for ((dep_context, dep_op), metrics) in required_metrics_for_all_decisions(m)
append!(to_expand, map(metric->(dep_context, dep_op, metric), metrics))
end

while !isempty(to_expand)
local_context, local_op, metric = pop!(to_expand)
metrics_vec = get!(()->AnalysisOrMetric[], metrics, (local_context, local_op))
if !(metric in metrics_vec)
push!(metrics_vec, metric)
end
for ((dep_context, dep_op), dep_metrics) in required_metrics(metric, Val{local_context}(), Val{local_op}())
append!(to_expand, map(metric->(dep_context, dep_op, metric), dep_metrics))
end
end

if !haskey(metrics, (context, op))
return AbstractMetric[]
end
return filter(dep_m->dep_m isa AbstractMetric, metrics[(context, op)])
end

const DECISION_POINTS = [(:signature, :schedule)]
function required_metrics_for_all_decisions(@nospecialize(m::AbstractDecision))
metrics = Dict{Tuple{Symbol,Symbol},Vector{AnalysisOrMetric}}()
for (context, op) in DECISION_POINTS
decision_metrics = required_metrics(m, Val{context}(), Val{op}())
for ((dec_context, dec_op), dec_metrics) in decision_metrics
append!(get!(Vector{AnalysisOrMetric}, metrics, (dec_context, dec_op)),
dec_metrics)
end
end
return metrics
end

function make_decision(model::AbstractDecision, context::Symbol, op::Symbol, key, args...)
if EAGER_STATE[] === nothing && model != FallbackModel()
# Make default decisions during precompile
return make_decision(FallbackModel(), Val{context}(), Val{op}(), key, args...)
end

@with METRIC_REGION=>(context, op) METRIC_KEY=>Some{Any}(key) begin
return make_decision(model, Val{context}(), Val{op}(), key, args...)
end
end

#### Contexts and Operations ####
# Every metric, analysis, and decision is associated with a combination of a
# "context" and an "operation", where the context is the kind of object being
# operated on, and that object is the key that will later be used to look up
# that metric's value.

#### Available Contexts and Operations ####
# chunk - Scoped to a Chunk object, cached in core scheduler only
# move - When move() is called on a Chunk
# signature - Scoped to any task with a matching function signature, cached in core scheduler only
# execute - When execute!() is called on a matching task
# schedule - When a matching task is being scheduled
# processor - Scoped to a Processor object, cached in core and worker schedulers
# run - When do_task() is called to run a task
# worker - Scoped to a worker, cached in core and worker schedulers
#
# Currently, all worker-cached metric values are returned to the core when each task completes.
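# For example (illustrative), a metric collected at (:signature, :execute) is keyed
# by the task's signature, so a decision model running at (:signature, :schedule)
# can later query it for the same signature, roughly:
#   fetch_metric(SimpleAverageAggregator(ThreadTimeMetric()), :signature, :execute, signature)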

#### Built-in Scheduling Decisions ####

"""
FallbackModel <: AbstractDecision
The default decision model, used during precompilation.
"""
struct FallbackModel <: AbstractDecision end

required_metrics(::FallbackModel, ::Val{:signature}, ::Val{:schedule}) =
RequiredMetrics((:signature, :schedule) => [],
(:processor, :run) => [],
(:signature, :execute) => [])

function make_decision(::FallbackModel, ::Val{:signature}, ::Val{:schedule}, signature, inputs, procs)
return procs
end

"""
SchDefaultModel <: AbstractDecision
The scheduler's default decision model. Estimates the cost of scheduling `task`
on each processor in `procs`. Considers current estimated per-processor compute
pressure, and transfer costs for each `Chunk` argument to `task`. Returns
`procs` sorted in order of ascending estimated cost.
"""
struct SchDefaultModel <: AbstractDecision end

required_metrics(::SchDefaultModel, ::Val{:signature}, ::Val{:schedule}) =
RequiredMetrics((:signature, :schedule) => [NetworkTransferAnalysis()],
(:processor, :run) => [ProcessorTimePressureMetric()],
(:signature, :execute) => [SimpleAverageAggregator(ThreadTimeMetric()),
SimpleAverageAggregator(ResultSizeMetric())])

function make_decision(::SchDefaultModel, ::Val{:signature}, ::Val{:schedule}, signature::Signature, inputs::Vector{Pair{Union{Symbol,Nothing},Any}}, all_procs::Vector{<:Processor})
# TODO: Unused
# run_cost = something(fetch_metric(SimpleAverageAggregator(ThreadTimeMetric()), :signature, :execute, signature)::Union{UInt64,Nothing}, UInt64(0))

# Estimate total cost for each processor
costs = Dict{Processor,UInt64}()
for proc in all_procs
wait_cost = something(fetch_metric(ProcessorTimePressureMetric(), :processor, :run, proc)::Union{UInt64,Nothing}, UInt64(0))
transfer_cost = something(fetch_metric(NetworkTransferAnalysis(), :signature, :schedule, signature, inputs, proc)::Union{UInt64,Nothing}, UInt64(0))
costs[proc] = wait_cost + transfer_cost
end

# Shuffle procs around, so equally-costly procs are equally considered
P = randperm(length(all_procs))
sorted_procs = getindex.(Ref(all_procs), P)

# Sort by lowest cost first
sort!(sorted_procs, by=p->costs[p])

# Move our corresponding ThreadProc to be the last considered
if length(sorted_procs) > 1
sch_threadproc = Dagger.ThreadProc(myid(), Threads.threadid())
sch_thread_idx = findfirst(==(sch_threadproc), sorted_procs)
if sch_thread_idx !== nothing
deleteat!(sorted_procs, sch_thread_idx)
push!(sorted_procs, sch_threadproc)
end
end

return sorted_procs
end

"The scheduler's basic decision model."
struct SchBasicModel <: AbstractDecision end

mutable struct ProcessorCacheEntry
gproc::OSProc
proc::Processor
next::ProcessorCacheEntry

ProcessorCacheEntry(gproc::OSProc, proc::Processor) = new(gproc, proc)
end
Base.isequal(p1::ProcessorCacheEntry, p2::ProcessorCacheEntry) =
p1.proc === p2.proc
function Base.show(io::IO, entry::ProcessorCacheEntry)
entries = 1
next = entry.next
while next !== entry
entries += 1
next = next.next
end
print(io, "ProcessorCacheEntry(pid $(entry.gproc.pid), $(entry.proc), $entries entries)")
end

function make_decision(::SchBasicModel, ::Val{:signature}, ::Val{:schedule}, signature, inputs, all_procs)
# Populate the cache if empty
# FIXME: Implement cache through SchBasicModel?
procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}()
if state.procs_cache_list[] === nothing
current = nothing
for p in map(x->x.pid, procs)
for proc in get_processors(OSProc(p))
next = ProcessorCacheEntry(OSProc(p), proc)
if current === nothing
current = next
current.next = current
state.procs_cache_list[] = current
else
current.next = next
current = next
current.next = state.procs_cache_list[]
end
end
end
end

# Fast fallback algorithm, useful when the smarter cost model algorithm
# would be too expensive
selected_entry = nothing
entry = state.procs_cache_list[]
cap, extra_util = nothing, nothing
procs_found = false
# N.B. if we only have one processor, we need to select it now
can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope)
if can_use
has_cap, est_time_util, est_alloc_util, est_occupancy =
has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
if has_cap
selected_entry = entry
else
procs_found = true
entry = entry.next
end
else
entry = entry.next
end
while selected_entry === nothing
if entry === state.procs_cache_list[]
# Exhausted all procs
if procs_found
push!(failed_scheduling, task)
else
state.cache[task] = SchedulingException("No processors available, try widening scope")
state.errored[task] = true
set_failed!(state, task)
end
return Processor[], Dict()
end

can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope)
if can_use
has_cap, est_time_util, est_alloc_util, est_occupancy =
has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig)
if has_cap
# Select this processor
selected_entry = entry
else
# We could have selected it otherwise
procs_found = true
entry = entry.next
end
else
# Try next processor
entry = entry.next
end
end
@assert selected_entry !== nothing
state.procs_cache_list[] = state.procs_cache_list[].next

return Processor[proc]
end
27 changes: 27 additions & 0 deletions lib/MetricsTracker/src/io.jl
@@ -0,0 +1,27 @@
import Serialization: serialize, deserialize

function load_metrics!(path::String)
loaded_cache = deserialize(path)
global_metrics_cache() do cache
for (mod_context, all_metrics) in loaded_cache
inner_cache = get!(cache, mod_context) do
Dict{Any, Dict{AbstractMetric, Any}}()
end
for (key, keyed_metrics) in all_metrics
inner_keyed_cache = get!(inner_cache, key) do
Dict{AbstractMetric, Any}()
end
for (metric, value) in keyed_metrics
inner_keyed_cache[metric] = value
end
end
end
return cache
end
end
function save_metrics(path::String)
global_metrics_cache() do cache
serialize(path, cache)
return cache
end
end
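Illustrative usage of the two functions above (the path is arbitrary; the cache is written with the Serialization stdlib, so the file is Julia-version-specific):

    import MetricsTracker
    MetricsTracker.save_metrics("metrics.bin")    # snapshot the global metrics cache to disk
    MetricsTracker.load_metrics!("metrics.bin")   # merge a saved snapshot back in a later session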
104 changes: 104 additions & 0 deletions lib/MetricsTracker/src/lookup.jl
@@ -0,0 +1,104 @@
abstract type AbstractLookup end

function lookup_match_metric end
function lookup_match_value end

function cache_lookup(cache::MetricsCache, mod::Module, context::Symbol, key, target_metric::AbstractMetric)
# Check if the cache has the results for this module and context
if !haskey(cache.results, (mod, context))
return nothing
end
inner_cache = cache.results[(mod, context)]

# Check if the cache has the results for this key
if !haskey(inner_cache, key)
return nothing
end
keyed_metrics = inner_cache[key]

# Check if the target metric exists
if !haskey(keyed_metrics, target_metric)
return nothing
end
target_metric_type = metric_type(target_metric)
return keyed_metrics[target_metric]::target_metric_type
end
function cache_lookup(cache::MetricsCache, mod::Module, context::Symbol, target_metric::AbstractMetric, lookup::AbstractLookup)
# Check if the cache has the results for this module and context
if !haskey(cache.results, (mod, context))
return nothing
end
inner_cache = cache.results[(mod, context)]

target_metric_type = metric_type(target_metric)
for (key, keyed_metrics) in inner_cache
for (metric, value) in keyed_metrics
# Check if lookup matches for this key
if lookup_match_metric(lookup, metric) && lookup_match_value(lookup, value)
# Lookup matched, return the target metric if it exists
if !haskey(keyed_metrics, target_metric)
return nothing
end
return keyed_metrics[target_metric]::target_metric_type
end
end
end
return nothing
end
function cache_lookup(cache::MetricsCache, mod::Module, context::Symbol, target_metric::AbstractMetric, lookups::Union{Vector, Tuple})
# Check if the cache has the results for this module and context
if !haskey(cache.results, (mod, context))
return nothing
end
inner_cache = cache.results[(mod, context)]

target_metric_type = metric_type(target_metric)
for (key, keyed_metrics) in inner_cache
# Check if all lookups match for this key
all_lookups_matched = true
for lookup in lookups
lookup_matched = false
for (metric, value) in keyed_metrics
if lookup_match_metric(lookup, metric) && lookup_match_value(lookup, value)
lookup_matched = true
break
end
end
if !lookup_matched
all_lookups_matched = false
break
end
end

if all_lookups_matched
# All lookups matched, return the target metric if it exists
if !haskey(keyed_metrics, target_metric)
return nothing
end
return keyed_metrics[target_metric]::target_metric_type
end
end
return nothing
end

struct LookupExact{M<:AbstractMetric,T} <: AbstractLookup
metric::M
target::T
end
lookup_match_metric(l::LookupExact, metric::AbstractMetric) = l.metric == metric
lookup_match_value(l::LookupExact{M,T}, value::T) where {M,T} = l.target == value

struct LookupSubtype{M<:AbstractMetric,T} <: AbstractLookup
metric::M
supertype::Type{T}
end
lookup_match_metric(l::LookupSubtype, metric::AbstractMetric) = l.metric == metric
lookup_match_value(l::LookupSubtype{M,T}, ::Type{T}) where {M,T} = true
lookup_match_value(l::LookupSubtype{M,T1}, ::Type{T2}) where {M,T1,T2} = false

struct LookupCustom{M<:AbstractMetric,F} <: AbstractLookup
metric::M
func::F
end
lookup_match_metric(l::LookupCustom, metric::AbstractMetric) = l.metric == metric
lookup_match_value(l::LookupCustom{M,F}, value) where {M,F} = l.func(value)
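A self-contained sketch of the lookup API above; the cache contents (`:task1`, the time, the result shape) are made up for illustration:

    import MetricsTracker: MetricsCache, AbstractMetric, TimeMetric, ResultShapeMetric,
                           LookupExact, cache_lookup

    cache = MetricsCache()
    cache[(Main, :execute!)] = Dict{Any, Dict{AbstractMetric, Any}}(
        :task1 => Dict{AbstractMetric, Any}(TimeMetric()        => UInt64(1_000),
                                            ResultShapeMetric() => (10, 10)))

    # Find the recorded time for any key whose result shape was exactly (10, 10)
    cache_lookup(cache, Main, :execute!, TimeMetric(),
                 LookupExact(ResultShapeMetric(), (10, 10)))   # == UInt64(1_000)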
176 changes: 176 additions & 0 deletions lib/MetricsTracker/src/metrics.jl
@@ -0,0 +1,176 @@
export @with_metrics, with_metrics

const METRIC_REGION = ScopedValue{Union{Tuple{Symbol,Symbol},Nothing}}(nothing)
metric_region() = METRIC_REGION[]
const METRIC_KEY = ScopedValue{Union{Some{Any},Nothing}}(nothing)
metric_key() = something(METRIC_KEY[])

metric_applies(::AbstractMetric, _) = false
is_result_metric(::AbstractMetric) = false

function with_metrics(f, ms::MetricsSpec, mod::Module, context::Symbol, key, sync_loc::SyncLocation)
@assert !COLLECTING_METRICS[] "Nested metrics collection not yet supported"
# TODO: Early filter out non-applicable metrics
cross_metric_values = ntuple(length(ms.metrics)) do i
m = ms.metrics[i]
if metric_applies(m, Val{context}()) && !is_result_metric(m)
return start_metric(m)
else
return nothing
end
end

@debug "Starting metrics collection for ($mod, $context) [$key]"
result = nothing
try
return @with COLLECTING_METRICS=>true f()
finally
@debug "Finished metrics collection for ($mod, $context) [$key]"

final_metric_values = reverse(ntuple(length(ms.metrics)) do i
m = ms.metrics[length(ms.metrics) - i + 1]
if metric_applies(m, Val{context}())
if is_result_metric(m) && result !== nothing
return result_metric(m, something(result))
else
return stop_metric(m, cross_metric_values[length(ms.metrics) - i + 1])
end
else
return nothing
end
end)
set_metric_values!(ms, mod, context, key, sync_loc, final_metric_values)
end
end
function with_metrics(f, ms::Tuple, mod::Module, context::Symbol, key, sync_loc::SyncLocation)
    with_metrics(f, MetricsSpec(ms...), mod, context, key, sync_loc)
end
macro with_metrics(ms, context, key, sync_loc, body)
esc(quote
$with_metrics(() -> $body, $ms, $__module__, $context, $key, $sync_loc)
end)
end
macro with_metrics(ms, mod, context, key, sync_loc, body)
esc(quote
$with_metrics(() -> $body, $ms, $mod, $context, $key, $sync_loc)
end)
end

function (wm::WithMetrics)(f, args...; kwargs...)
return @with_metrics wm.spec wm.context wm.key wm.sync_loc f(args...; kwargs...)
end

Base.getindex(cache::MetricsCache, mod_context::Tuple{Module, Symbol}) =
getindex(cache.results, mod_context)
Base.setindex!(cache::MetricsCache, value, mod_context::Tuple{Module, Symbol}) =
setindex!(cache.results, value, mod_context)
Base.get(cache::MetricsCache, mod_context::Tuple{Module, Symbol}, default) =
get(cache.results, mod_context, default)
Base.iterate(cache::MetricsCache) = iterate(cache.results)
Base.iterate(cache::MetricsCache, state) = iterate(cache.results, state)
Base.length(cache::MetricsCache) = length(cache.results)
function Base.show(io::IO, ::MIME"text/plain", cache::MetricsCache)
println("MetricsCache:")
for ((mod, context), metrics) in cache.results
println(io, " Metrics for ($mod, $context):")
for (key, values) in metrics
println(io, " Key: $key")
for (metric, value) in values
println(io, " $metric: $value")
end
end
end
end

# TODO: Add recursive tracking?
const LOCAL_METRICS_CACHE = TaskLocalValue{MetricsCache}(()->MetricsCache())
local_metrics_cache() = LOCAL_METRICS_CACHE[]
local_metrics_cache(mod::Module, context::Symbol) =
get!(local_metrics_cache(), (mod, context)) do
Dict{Any, Dict{AbstractMetric, Any}}()
end
local_metrics_cache(mod::Module, context::Symbol, key) =
get!(local_metrics_cache(mod, context), key) do
Dict{AbstractMetric, Any}()
end

const GLOBAL_METRICS_CACHE = Base.Lockable(MetricsCache())
global_metrics_cache(f) = lock(f, GLOBAL_METRICS_CACHE)
global_metrics_cache(f, mod::Module, context::Symbol, key) = global_metrics_cache() do cache
inner_cache = get!(get!(cache, (mod, context)) do
Dict{Any, Dict{AbstractMetric, Any}}()
end, key) do
Dict{AbstractMetric, Any}()
end
return f(inner_cache)
end

function set_metric_values!(ms::MetricsSpec,
mod::Module, context::Symbol,
key,
::SyncTask,
values::Tuple)
cache = local_metrics_cache(mod, context, key)
sync_results_into!(cache, ms, values)
return
end
function set_metric_values!(ms::MetricsSpec,
mod::Module, context::Symbol,
key,
::SyncGlobal,
values::Tuple)
global_metrics_cache(mod, context, key) do cache
sync_results_into!(cache, ms, values)
end
return
end
function set_metric_values!(ms::MetricsSpec,
mod::Module, context::Symbol,
key,
sync_loc::SyncInto,
values::Tuple)
sync_results_into!(sync_loc.cache, ms, mod, context, key, values)
return
end

function sync_results_into!(cache::MetricsCache,
ms::MetricsSpec,
mod::Module,
context::Symbol,
key,
values::Tuple)
inner_cache = get!(cache.results, (mod, context)) do
Dict{Any, Dict{AbstractMetric, Any}}()
end
keyed_cache = get!(inner_cache, key) do
Dict{AbstractMetric, Any}()
end
sync_results_into!(keyed_cache, ms, values)
return
end
function sync_results_into!(cache::Dict{AbstractMetric, Any},
ms::MetricsSpec,
values::Tuple)
ntuple(length(ms.metrics)) do i
m = ms.metrics[i]
cache[m] = values[i]
return
end
return
end
function sync_results_into!(dest_cache::MetricsCache, src_cache::MetricsCache)
for ((mod, context), metrics) in src_cache.results
dest_inner_cache = get!(dest_cache.results, (mod, context)) do
Dict{Any, Dict{AbstractMetric, Any}}()
end
for (key, keyed_metrics) in metrics
dest_keyed_metrics = get!(dest_inner_cache, key) do
Dict{AbstractMetric, Any}()
end
for (metric, value) in keyed_metrics
dest_keyed_metrics[metric] = value
end
end
end
return
end
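A hedged usage sketch of the collection entry point above, run from the `Main` module; the context `:compute` and key `:demo` are arbitrary names:

    import MetricsTracker: MetricsSpec, TimeMetric, AllocMetric, SyncGlobal,
                           @with_metrics, global_metrics_cache

    ms = MetricsSpec(TimeMetric(), AllocMetric())
    @with_metrics ms :compute :demo SyncGlobal() begin
        sum(rand(1_000))
    end

    global_metrics_cache() do cache
        cache[(Main, :compute)][:demo]   # Dict mapping each metric to its recorded value
    end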
33 changes: 33 additions & 0 deletions lib/MetricsTracker/src/types.jl
@@ -0,0 +1,33 @@
abstract type AbstractMetric end

struct MetricsSpec{M<:Tuple}
metrics::M

function MetricsSpec(m::Vararg{AbstractMetric})
return new{typeof(m)}(m)
end
end

# TODO: Should we flip the order so that each metric has its own type-stable cache?
struct MetricsCache <: AbstractDict{Tuple{Module, Symbol}, Dict{Any, Dict{AbstractMetric, Any}}}
results::Dict{Tuple{Module, Symbol}, Dict{Any, Dict{AbstractMetric, Any}}}

MetricsCache() =
new(Dict{Tuple{Module, Symbol}, Dict{Any, Dict{AbstractMetric, Any}}}())
end

abstract type SyncLocation end
struct SyncTask <: SyncLocation end
struct SyncGlobal <: SyncLocation end
struct SyncInto <: SyncLocation
cache::MetricsCache
end

struct WithMetrics{MS<:MetricsSpec, C, K, S<:SyncLocation}
spec::MS
context::C
key::K
sync_loc::S
end

const COLLECTING_METRICS = ScopedValue{Bool}(false)
51 changes: 51 additions & 0 deletions src/Dagger.jl
@@ -43,6 +43,10 @@ else
import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remotecall, remotecall_wait, remotecall_fetch, check_same_host
end

import MetricsTracker as MT
const reuse_metrics = @load_preference("reuse-metrics", false)
const metrics_path = @load_preference("metrics-path", "metrics.json")

include("lib/util.jl")
include("utils/dagdebug.jl")

@@ -67,6 +71,9 @@ include("submission.jl")
include("chunks.jl")
include("memory-spaces.jl")

# Metrics
include("utils/metrics.jl")

# Task scheduling
include("compute.jl")
include("utils/clock.jl")
@@ -126,6 +133,30 @@ function set_distributed_package!(value)
@info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!"
end

"""
set_reuse_metrics!(value::Bool)
Set a [preference](https://github.com/JuliaPackaging/Preferences.jl) for
enabling or disabling the reuse of collected metrics across Julia sessions.
You will need to restart Julia after setting a new preference.
"""
function set_reuse_metrics!(value::Bool)
@set_preferences!("reuse-metrics" => value)
@info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!"
end

"""
set_metrics_path!(value::String)
Set a [preference](https://github.com/JuliaPackaging/Preferences.jl) for
the path to save and load metrics. You will need to restart Julia after setting
a new preference.
"""
function set_metrics_path!(value::String)
@set_preferences!("metrics-path" => value)
@info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!"
end

# Precompilation
import PrecompileTools: @compile_workload
include("precompile.jl")
@@ -189,6 +220,26 @@ function __init__()
catch err
@warn "Error parsing JULIA_DAGGER_DEBUG" exception=err
end

if reuse_metrics
if isfile(metrics_path)
# Load metrics
@dagdebug nothing :metrics "Loading metrics"
try
MT.load_metrics!(metrics_path)
catch err
@warn "Error loading metrics" exception=(err, catch_backtrace())
end
else
@dagdebug nothing :metrics "Metrics file not found"
end

atexit() do
# Save metrics on exit
@dagdebug nothing :metrics "Saving metrics"
MT.save_metrics(metrics_path)
end
end
end

end # module
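Illustrative: enabling metric reuse across sessions with the new preferences (the path is arbitrary; a Julia restart is required for the change to take effect):

    using Dagger
    Dagger.set_reuse_metrics!(true)
    Dagger.set_metrics_path!("dagger-metrics.bin")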
555 changes: 306 additions & 249 deletions src/datadeps.jl

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions src/memory-spaces.jl
@@ -124,8 +124,10 @@ memory_spans(x, T) = memory_spans(aliasing(x, T))

struct NoAliasing <: AbstractAliasing end
memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[]
equivalent_structure(::NoAliasing, ::NoAliasing) = true
struct UnknownAliasing <: AbstractAliasing end
memory_spans(::UnknownAliasing) = [MemorySpan{CPURAMMemorySpace}(C_NULL, typemax(UInt))]
equivalent_structure(::UnknownAliasing, ::UnknownAliasing) = true

warn_unknown_aliasing(T) =
@warn "Cannot resolve aliasing for object of type $T\nExecution may become sequential"
@@ -141,6 +143,18 @@ function memory_spans(ca::CombinedAliasing)
end
return all_spans
end
function equivalent_structure(ainfo1::CombinedAliasing,
                              ainfo2::CombinedAliasing)
    # Every sub-ainfo in ainfo1 must have a structurally-equivalent counterpart in ainfo2
    for sub_ainfo1 in ainfo1.sub_ainfos
        found = false
        for sub_ainfo2 in ainfo2.sub_ainfos
            if equivalent_structure(sub_ainfo1, sub_ainfo2)
                found = true
                break
            end
        end
        found || return false
    end
    return true
end
Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) =
ca1.sub_ainfos == ca2.sub_ainfos
Base.hash(ca1::CombinedAliasing, h::UInt) =
@@ -161,6 +175,10 @@ function memory_spans(oa::ObjectAliasing)
span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz)
return [span]
end
function equivalent_structure(ainfo1::ObjectAliasing,
ainfo2::ObjectAliasing)
return ainfo1.sz == ainfo2.sz
end

aliasing(x, T) = aliasing(T(x))
function aliasing(x::T) where T
@@ -221,6 +239,10 @@ function aliasing(x::Array{T}) where T
end
aliasing(x::Transpose) = aliasing(parent(x))
aliasing(x::Adjoint) = aliasing(parent(x))
function equivalent_structure(ainfo1::ContiguousAliasing{S},
ainfo2::ContiguousAliasing{S}) where {S}
return ainfo1.span.len == ainfo2.span.len
end

struct StridedAliasing{T,N,S} <: AbstractAliasing
base_ptr::RemotePtr{Cvoid,S}
@@ -279,6 +301,12 @@ function will_alias(x::StridedAliasing{T,N,S}, y::StridedAliasing{T,N,S}) where
return true
end
# FIXME: Upgrade Contiguous/StridedAlising to same number of dims
function equivalent_structure(ainfo1::StridedAliasing{T,N,S},
ainfo2::StridedAliasing{T,N,S}) where {T,N,S}
return ainfo1.base_inds == ainfo2.base_inds &&
ainfo1.lengths == ainfo2.lengths &&
ainfo1.strides == ainfo2.strides
end

struct TriangularAliasing{T,S} <: AbstractAliasing
ptr::RemotePtr{Cvoid,S}
@@ -311,6 +339,12 @@ aliasing(x::UnitUpperTriangular{T}) where T =
TriangularAliasing{T,CPURAMMemorySpace}(pointer(parent(x)), size(parent(x), 1), true, false)
aliasing(x::UnitLowerTriangular{T}) where T =
TriangularAliasing{T,CPURAMMemorySpace}(pointer(parent(x)), size(parent(x), 1), false, false)
function equivalent_structure(ainfo1::TriangularAliasing{T,S},
ainfo2::TriangularAliasing{T,S}) where {T,S}
return ainfo1.stride == ainfo2.stride &&
ainfo1.isupper == ainfo2.isupper &&
ainfo1.diagonal == ainfo2.diagonal
end

struct DiagonalAliasing{T,S} <: AbstractAliasing
ptr::RemotePtr{Cvoid,S}
@@ -331,6 +365,10 @@ function aliasing(x::AbstractMatrix{T}, ::Type{Diagonal}) where T
rptr = RemotePtr{Cvoid}(ptr, S)
return DiagonalAliasing{T,typeof(S)}(rptr, size(parent(x), 1))
end
function equivalent_structure(ainfo1::DiagonalAliasing{T,S},
ainfo2::DiagonalAliasing{T,S}) where {T,S}
return ainfo1.stride == ainfo2.stride
end
# FIXME: Bidiagonal
# FIXME: Tridiagonal

@@ -368,3 +406,11 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan)
y_end = y_span.ptr + y_span.len - 1
return x_span.ptr <= y_end && y_span.ptr <= x_end
end

"""
equivalent_structure(ainfo1::AbstractAliasing, ainfo2::AbstractAliasing) -> Bool
Returns `true` when `ainfo1` and `ainfo2` represent objects with the same
memory structure, ignoring the specific memory addresses; otherwise, `false`.
"""
equivalent_structure(ainfo1::AbstractAliasing, ainfo2::AbstractAliasing) = false
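A small sketch of `equivalent_structure` using only methods added in this file: unit-triangular wrappers of equal-sized parents compare structurally equal no matter where their data lives, while upper vs. lower wrappers do not:

    using Dagger, LinearAlgebra

    A, B = rand(4, 4), rand(4, 4)
    Dagger.equivalent_structure(Dagger.aliasing(UnitUpperTriangular(A)),
                                Dagger.aliasing(UnitUpperTriangular(B)))  # true
    Dagger.equivalent_structure(Dagger.aliasing(UnitUpperTriangular(A)),
                                Dagger.aliasing(UnitLowerTriangular(B)))  # false (isupper differs)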
61 changes: 45 additions & 16 deletions src/sch/Sch.jl
@@ -15,9 +15,12 @@ import Base: @invokelatest

import ..Dagger
import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, LockedObject
import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime
import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain
import ..Dagger: @dagdebug, @safe_lock_spin1
import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek
import ScopedValues: @with

import MetricsTracker as MT

import ..Dagger

@@ -70,8 +73,8 @@ Fields:
- `worker_loadavg::Dict{Int,NTuple{3,Float64}}` - Worker load average
- `worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}` - Communication channels between the scheduler and each worker
- `procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}` - Cached linked list of processors ready to be used
- `signature_time_cost::Dict{Signature,UInt64}` - Cache of estimated CPU time (in nanoseconds) required to compute calls with the given signature
- `signature_alloc_cost::Dict{Signature,UInt64}` - Cache of estimated CPU RAM (in bytes) required to compute calls with the given signature
- `signature_time_cost::Dict{Signature,Dict{Processor,UInt64}}` - Cache of estimated CPU time (in nanoseconds) required to compute calls with the given signature on a given processor
- `signature_alloc_cost::Dict{Signature,Dict{Processor,UInt64}}` - Cache of estimated CPU RAM (in bytes) required to compute calls with the given signature on a given processor
- `transfer_rate::Ref{UInt64}` - Estimate of the network transfer rate in bytes per second
- `halt::Base.Event` - Event indicating that the scheduler is halting
- `lock::ReentrantLock` - Lock around operations which modify the state
@@ -97,8 +100,8 @@ struct ComputeState
worker_loadavg::Dict{Int,NTuple{3,Float64}}
worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}
procs_cache_list::Base.RefValue{Union{ProcessorCacheEntry,Nothing}}
signature_time_cost::Dict{Signature,UInt64}
signature_alloc_cost::Dict{Signature,UInt64}
signature_time_cost::Dict{Signature,Dict{Processor,UInt64}}
signature_alloc_cost::Dict{Signature,Dict{Processor,UInt64}}
transfer_rate::Ref{UInt64}
halt::Base.Event
lock::ReentrantLock
@@ -127,8 +130,8 @@ function start_state(deps::Dict, node_order, chan)
Dict{Int,NTuple{3,Float64}}(),
Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(),
Ref{Union{ProcessorCacheEntry,Nothing}}(nothing),
Dict{Signature,UInt64}(),
Dict{Signature,UInt64}(),
Dict{Signature,Dict{Processor,UInt64}}(),
Dict{Signature,Dict{Processor,UInt64}}(),
Ref{UInt64}(1_000_000),
Base.Event(),
ReentrantLock(),
@@ -587,16 +590,23 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options)
end
node = unwrap_weak_checked(state.thunk_dict[thunk_id])
if metadata !== nothing
# Update metrics
state.worker_time_pressure[pid][proc] = metadata.time_pressure
#to_storage = fetch(node.options.storage)
#state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure
#state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity
state.worker_loadavg[pid] = metadata.loadavg

sig = signature(state, node)
state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2
state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2
time_costs_proc = get!(Dict{Processor,UInt64}, state.signature_time_cost, sig)
time_cost = get(time_costs_proc, proc, UInt64(0))
time_costs_proc[proc] = (metadata.threadtime + time_cost) ÷ UInt64(2)
alloc_costs_proc = get!(Dict{Processor,UInt64}, state.signature_alloc_cost, sig)
alloc_cost = get(alloc_costs_proc, proc, UInt64(0))
alloc_costs_proc[proc] = (metadata.gc_allocd + alloc_cost) ÷ UInt64(2)

if metadata.transfer_rate !== nothing
state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ 2
state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ UInt64(2)
end
end
state.cache[node] = res
@@ -1648,6 +1658,12 @@ function do_task(to_proc, task_desc)
end
end

# Compute signature
@warn "Fix kwargs" maxlog=1
sig = DataType[Tf, map(fetched_args) do x
chunktype(x)
end...]

#= FIXME: If MaxUtilization, stop processors and wait
if (est_time_util isa MaxUtilization) && (real_time_util > 0)
# FIXME: Stop processors
@@ -1660,8 +1676,11 @@ function do_task(to_proc, task_desc)
timespan_start(ctx, :compute, (;thunk_id, processor=to_proc), (;f))
res = nothing

# Start counting time and GC allocations
threadtime_start = cputhreadtime()
# Setup metrics for time monitoring
mspec = MT.MetricsSpec(MT.TimeMetric(), Dagger.SignatureMetric(), Dagger.ProcessorMetric())
local_cache = MT.MetricsCache()

# Start counting GC allocations
# FIXME
#gcnum_start = Base.gc_num()

@@ -1677,9 +1696,13 @@ function do_task(to_proc, task_desc)
cancel_token=Dagger.DTASK_CANCEL_TOKEN[],
))

# Execute
res = Dagger.with_options(propagated) do
# Execute
execute!(to_proc, f, fetched_args...; fetched_kwargs...)
@with Dagger.TASK_SIGNATURE=>sig Dagger.TASK_PROCESSOR=>to_proc begin
MT.@with_metrics mspec Dagger :execute! thunk_id MT.SyncInto(local_cache) begin
execute!(to_proc, f, fetched_args...; fetched_kwargs...)
end
end
end

# Check if result is safe to store
@@ -1705,10 +1728,16 @@ function do_task(to_proc, task_desc)
RemoteException(myid(), CapturedException(ex, bt))
end

threadtime = cputhreadtime() - threadtime_start
lock(MT.GLOBAL_METRICS_CACHE) do global_cache
MT.sync_results_into!(global_cache, local_cache)
end

# FIXME: This is not a realistic measure of max. required memory
#gc_allocd = min(max(UInt64(Base.gc_num().allocd) - UInt64(gcnum_start.allocd), UInt64(0)), UInt64(1024^4))
timespan_finish(ctx, :compute, (;thunk_id, processor=to_proc), (;f, result=result_meta))

threadtime = MT.cache_lookup(local_cache, Dagger, :execute!, thunk_id, MT.TimeMetric())

lock(TASK_SYNC) do
real_time_util[] -= est_time_util
pop!(TASKS_RUNNING, thunk_id)
@@ -1723,7 +1752,7 @@ function do_task(to_proc, task_desc)
storage_pressure=real_alloc_util,
storage_capacity=storage_cap,
loadavg=((Sys.loadavg()...,) ./ Sys.CPU_THREADS),
threadtime=threadtime,
threadtime,
# FIXME: Add runtime allocation tracking
gc_allocd=(isa(result_meta, Chunk) ? result_meta.handle.size : 0),
transfer_rate=(transfer_size[] > 0 && transfer_time[] > 0) ? round(UInt64, transfer_size[] / (transfer_time[] / 10^9)) : nothing,
26 changes: 18 additions & 8 deletions src/sch/util.jl
@@ -325,6 +325,7 @@ function signature(f, args)
end
return sig
end
signature(spec::Dagger.DTaskSpec) = signature(spec.f, spec.args)

function can_use_proc(state, task, gproc, proc, opts, scope)
# Check against proclist
@@ -399,17 +400,24 @@ end

function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig)
T = typeof(p)
# FIXME: MaxUtilization
est_time_util = round(UInt64, if time_util !== nothing && haskey(time_util, T)
time_util[T] * 1000^3

@warn "Use special lookup to use other proc estimates" maxlog=1
est_time_util = if time_util !== nothing && haskey(time_util, T)
round(UInt64, time_util[T] * 1000^3)
elseif haskey(state.signature_time_cost, sig) && haskey(state.signature_time_cost[sig], p)
state.signature_time_cost[sig][p]
else
get(state.signature_time_cost, sig, 1000^3)
end)
UInt64(1000^3)
end

est_alloc_util = if alloc_util !== nothing && haskey(alloc_util, T)
alloc_util[T]
alloc_util[T]::UInt64
elseif haskey(state.signature_alloc_cost, sig) && haskey(state.signature_alloc_cost[sig], p)
state.signature_alloc_cost[sig][p]
else
get(state.signature_alloc_cost, sig, UInt64(0))
end::UInt64
UInt64(0)
end

est_occupancy::UInt32 = typemax(UInt32)
if occupancy !== nothing
occ = nothing
@@ -423,6 +431,7 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig)
est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32))
end
end

#= FIXME: Estimate if cached data can be swapped to storage
storage = storage_resource(p)
real_alloc_util = state.worker_storage_pressure[gp][storage]
@@ -431,6 +440,7 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig)
return false, est_time_util, est_alloc_util
end
=#

return true, est_time_util, est_alloc_util, est_occupancy
end

5 changes: 5 additions & 0 deletions src/submission.jl
@@ -76,13 +76,15 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
end
for (idx, dep) in enumerate(syncdeps)
newdep = if dep isa DTask
@assert dep.uid != uid "Cannot depend on self"
tid = if haskey(id_map, dep.uid)
id_map[dep.uid]
else
uid_to_tid[dep.uid]
end
state.thunk_dict[tid]
elseif dep isa Sch.ThunkID
@assert dep.id != id "Cannot depend on self"
tid = dep.id
state.thunk_dict[tid]
else
@@ -240,6 +242,8 @@ end
chunktype(t::DTask) = t.metadata.return_type

function eager_launch!((spec, task)::Pair{DTaskSpec,DTask})
@assert !istaskstarted(task) "Cannot launch a task that is already started"

# Assign a name, if specified
eager_assign_name!(spec, task)

@@ -261,6 +265,7 @@ function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}})

# Assign a name, if specified
for (spec, task) in specs
@assert !istaskstarted(task) "Cannot launch a task that is already started"
eager_assign_name!(spec, task)
end

15 changes: 15 additions & 0 deletions src/utils/metrics.jl
@@ -0,0 +1,15 @@
const TASK_SIGNATURE = ScopedValue{Union{Vector{DataType}, Nothing}}(nothing)
struct SignatureMetric <: MT.AbstractMetric end
MT.metric_applies(::SignatureMetric, ::Val{:execute!}) = true
MT.metric_type(::SignatureMetric) = Union{Vector{DataType}, Nothing}
MT.start_metric(::SignatureMetric) = nothing
MT.stop_metric(::SignatureMetric, _) = TASK_SIGNATURE[]

const TASK_PROCESSOR = ScopedValue{Union{Processor, Nothing}}(nothing)
struct ProcessorMetric <: MT.AbstractMetric end
MT.metric_applies(::ProcessorMetric, ::Val{:execute!}) = true
MT.metric_type(::ProcessorMetric) = Union{Processor, Nothing}
MT.start_metric(::ProcessorMetric) = nothing
MT.stop_metric(::ProcessorMetric, _) = TASK_PROCESSOR[]

# FIXME: struct TransferTimeMetric <: MT.AbstractMetric end
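These metrics simply capture the scoped values that `do_task` (see the `src/sch/Sch.jl` change above) binds around `execute!`. A minimal sketch of the mechanism with a made-up processor:

    using Dagger
    import ScopedValues: @with
    import MetricsTracker as MT

    proc = Dagger.ThreadProc(1, 1)
    @with Dagger.TASK_PROCESSOR=>proc begin
        MT.stop_metric(Dagger.ProcessorMetric(), nothing)   # returns `proc`
    end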