Skip to content

Commit a14d265

Browse files
committed
Test: Make test recording thread safe
In preparation for #53462, ensure that attempting to record test results to a test set from multiple threads does not cause corruption. Note that other parts of `Test` remain non-thread-safe.
1 parent d28a587 commit a14d265

File tree

3 files changed

+68
-45
lines changed

3 files changed

+68
-45
lines changed

stdlib/Test/src/Test.jl

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,17 +1208,27 @@ are any `Fail`s or `Error`s, an exception will be thrown only at the end,
12081208
along with a summary of the test results.
12091209
"""
12101210
mutable struct DefaultTestSet <: AbstractTestSet
1211-
description::String
1212-
results::Vector{Any}
1213-
n_passed::Int
1214-
anynonpass::Bool
1215-
verbose::Bool
1216-
showtiming::Bool
1217-
time_start::Float64
1218-
time_end::Union{Float64,Nothing}
1219-
failfast::Bool
1220-
file::Union{String,Nothing}
1211+
const description::String
1212+
const verbose::Bool
1213+
const showtiming::Bool
1214+
const failfast::Bool
1215+
const file::Union{String,Nothing}
1216+
const time_start::Float64
1217+
1218+
# Warning: Not thread-safe
12211219
rng::Union{Nothing,AbstractRNG}
1220+
1221+
@atomic n_passed::Int
1222+
@atomic time_end::Float64
1223+
1224+
# Memoized test result state over `results` - Computed only once the test set is finished
1225+
# 0x0: Unknown
1226+
# 0x1: All passed
1227+
# 0x2: Some failed
1228+
@atomic anynonpass::UInt8
1229+
1230+
results_lock::ReentrantLock
1231+
results::Vector{Any}
12221232
end
12231233
function DefaultTestSet(desc::AbstractString; verbose::Bool = false, showtiming::Bool = true, failfast::Union{Nothing,Bool} = nothing, source = nothing, rng = nothing)
12241234
if isnothing(failfast)
@@ -1230,7 +1240,9 @@ function DefaultTestSet(desc::AbstractString; verbose::Bool = false, showtiming:
12301240
failfast = false
12311241
end
12321242
end
1233-
return DefaultTestSet(String(desc)::String, [], 0, false, verbose, showtiming, time(), nothing, failfast, extract_file(source), rng)
1243+
return DefaultTestSet(String(desc)::String,
1244+
verbose, showtiming, failfast, extract_file(source),
1245+
time(), rng, 0, 0., 0x00, ReentrantLock(), Any[])
12341246
end
12351247
extract_file(source::LineNumberNode) = extract_file(source.file)
12361248
extract_file(file::Symbol) = string(file)
@@ -1239,15 +1251,15 @@ extract_file(::Nothing) = nothing
12391251
struct FailFastError <: Exception end
12401252

12411253
# For a broken result, simply store the result
1242-
record(ts::DefaultTestSet, t::Broken) = (push!(ts.results, t); t)
1254+
record(ts::DefaultTestSet, t::Broken) = ((@lock ts.results_lock push!(ts.results, t)); t)
12431255
# For a passed result, do not store the result since it uses a lot of memory, unless
12441256
# `record_passes()` is true. i.e. set env var `JULIA_TEST_RECORD_PASSES=true` before running any testsets
12451257
function record(ts::DefaultTestSet, t::Pass)
1246-
ts.n_passed += 1
1258+
@atomic :monotonic ts.n_passed += 1
12471259
if record_passes()
12481260
# throw away the captured data so it can be GC-ed
12491261
t_nodata = Pass(t.test_type, t.orig_expr, nothing, t.value, t.source, t.message_only)
1250-
push!(ts.results, t_nodata)
1262+
@lock ts.results_lock push!(ts.results, t_nodata)
12511263
return t_nodata
12521264
end
12531265
return t
@@ -1268,7 +1280,7 @@ function record(ts::DefaultTestSet, t::Union{Fail, Error}; print_result::Bool=TE
12681280
println()
12691281
end
12701282
end
1271-
push!(ts.results, t)
1283+
@lock ts.results_lock push!(ts.results, t)
12721284
(FAIL_FAST[] || ts.failfast) && throw(FailFastError())
12731285
return t
12741286
end
@@ -1297,7 +1309,7 @@ results(ts::DefaultTestSet) = ts.results
12971309
# When a DefaultTestSet finishes, it records itself to its parent
12981310
# testset, if there is one. This allows for recursive printing of
12991311
# the results at the end of the tests
1300-
record(ts::DefaultTestSet, t::AbstractTestSet) = push!(ts.results, t)
1312+
record(ts::DefaultTestSet, t::AbstractTestSet) = @lock ts.results_lock push!(ts.results, t)
13011313

13021314
@specialize
13031315

@@ -1402,7 +1414,9 @@ const TESTSET_PRINT_ENABLE = Ref(true)
14021414
# Called at the end of a @testset, behaviour depends on whether
14031415
# this is a child of another testset, or the "root" testset
14041416
function finish(ts::DefaultTestSet; print_results::Bool=TESTSET_PRINT_ENABLE[])
1405-
ts.time_end = time()
1417+
if (@atomicswap ts.time_end = time()) !== 0.
1418+
error("Test set was finished more than once")
1419+
end
14061420
# If we are a nested test set, do not print a full summary
14071421
# now - let the parent test set do the printing
14081422
if get_testset_depth() != 0
@@ -1433,24 +1447,6 @@ function finish(ts::DefaultTestSet; print_results::Bool=TESTSET_PRINT_ENABLE[])
14331447
return ts
14341448
end
14351449

1436-
# Recursive function that finds the column that the result counts
1437-
# can begin at by taking into account the width of the descriptions
1438-
# and the amount of indentation. If a test set had no failures, and
1439-
# no failures in child test sets, there is no need to include those
1440-
# in calculating the alignment
1441-
function get_alignment(ts::DefaultTestSet, depth::Int)
1442-
# The minimum width at this depth is
1443-
ts_width = 2*depth + length(ts.description)
1444-
# If not verbose and all passing, no need to look at children
1445-
!ts.verbose && !ts.anynonpass && return ts_width
1446-
# Return the maximum of this width and the minimum width
1447-
# for all children (if they exist)
1448-
isempty(ts.results) && return ts_width
1449-
child_widths = map(t->get_alignment(t, depth+1), ts.results)
1450-
return max(ts_width, maximum(child_widths))
1451-
end
1452-
get_alignment(ts, depth::Int) = 0
1453-
14541450
# Recursive function that fetches backtraces for any and all errors
14551451
# or failures the testset and its children encountered
14561452
function filter_errors(ts::DefaultTestSet)
@@ -1536,7 +1532,7 @@ function get_test_counts(ts::DefaultTestSet)
15361532
passes, fails, errors, broken = ts.n_passed, 0, 0, 0
15371533
# cumulative results
15381534
c_passes, c_fails, c_errors, c_broken = 0, 0, 0, 0
1539-
for t in ts.results
1535+
@lock ts.results_lock for t in ts.results
15401536
isa(t, Fail) && (fails += 1)
15411537
isa(t, Error) && (errors += 1)
15421538
isa(t, Broken) && (broken += 1)
@@ -1549,10 +1545,37 @@ function get_test_counts(ts::DefaultTestSet)
15491545
end
15501546
end
15511547
duration = format_duration(ts)
1552-
ts.anynonpass = (fails + errors + c_fails + c_errors > 0)
1553-
return TestCounts(true, passes, fails, errors, broken, c_passes, c_fails, c_errors, c_broken, duration)
1548+
tc = TestCounts(true, passes, fails, errors, broken, c_passes, c_fails, c_errors, c_broken, duration)
1549+
# Memoize for printing convenience
1550+
@atomic :monotonic ts.anynonpass = (anynonpass(tc) ? 0x02 : 0x01)
1551+
return tc
1552+
end
1553+
anynonpass(tc::TestCounts) = (tc.fails + tc.errors + tc.cumulative_fails + tc.cumulative_errors > 0)
1554+
function anynonpass(ts::DefaultTestSet)
1555+
if (@atomic :monotonic ts.anynonpass) == 0x00
1556+
get_test_counts(ts) # fills in the anynonpass field
1557+
end
1558+
return (@atomic :monotonic ts.anynonpass) != 0x01
15541559
end
15551560

1561+
# Recursive function that finds the column that the result counts
1562+
# can begin at by taking into account the width of the descriptions
1563+
# and the amount of indentation. If a test set had no failures, and
1564+
# no failures in child test sets, there is no need to include those
1565+
# in calculating the alignment
1566+
function get_alignment(ts::DefaultTestSet, depth::Int)
1567+
# The minimum width at this depth is
1568+
ts_width = 2*depth + length(ts.description)
1569+
# If not verbose and all passing, no need to look at children
1570+
!ts.verbose && !anynonpass(ts) && return ts_width
1571+
# Return the maximum of this width and the minimum width
1572+
# for all children (if they exist)
1573+
isempty(ts.results) && return ts_width
1574+
child_widths = map(t->get_alignment(t, depth+1), ts.results)
1575+
return max(ts_width, maximum(child_widths))
1576+
end
1577+
get_alignment(ts, depth::Int) = 0
1578+
15561579
"""
15571580
format_duration(::AbstractTestSet)
15581581
@@ -1564,7 +1587,7 @@ format_duration(::AbstractTestSet) = "?s"
15641587

15651588
function format_duration(ts::DefaultTestSet)
15661589
(; time_start, time_end) = ts
1567-
isnothing(time_end) && return ""
1590+
time_end === 0. && return ""
15681591

15691592
dur_s = time_end - time_start
15701593
if dur_s < 60

stdlib/Test/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ end
584584
@test total_error == 6
585585
@test total_broken == 0
586586
end
587-
ts.anynonpass = false
587+
@atomic ts.anynonpass = false
588588
deleteat!(Test.get_testset().results, 1)
589589
end
590590

test/runtests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,20 +422,20 @@ cd(@__DIR__) do
422422
=#
423423
Test.TESTSET_PRINT_ENABLE[] = false
424424
o_ts = Test.DefaultTestSet("Overall")
425-
o_ts.time_end = o_ts.time_start + o_ts_duration # manually populate the timing
425+
@atomic o_ts.time_end = o_ts.time_start + o_ts_duration # manually populate the timing
426426
BuildkiteTestJSON.write_testset_json_files(@__DIR__, o_ts)
427427
Test.push_testset(o_ts)
428428
completed_tests = Set{String}()
429429
for (testname, (resp,), duration) in results
430430
push!(completed_tests, testname)
431431
if isa(resp, Test.DefaultTestSet)
432-
resp.time_end = resp.time_start + duration
432+
@atomic resp.time_end = resp.time_start + duration
433433
Test.push_testset(resp)
434434
Test.record(o_ts, resp)
435435
Test.pop_testset()
436436
elseif isa(resp, Test.TestSetException)
437437
fake = Test.DefaultTestSet(testname)
438-
fake.time_end = fake.time_start + duration
438+
@atomic fake.time_end = fake.time_start + duration
439439
for i in 1:resp.pass
440440
Test.record(fake, Test.Pass(:test, nothing, nothing, nothing, LineNumberNode(@__LINE__, @__FILE__)))
441441
end
@@ -457,7 +457,7 @@ cd(@__DIR__) do
457457
# the test runner itself had some problem, so we may have hit a segfault,
458458
# deserialization errors or something similar. Record this testset as Errored.
459459
fake = Test.DefaultTestSet(testname)
460-
fake.time_end = fake.time_start + duration
460+
@atomic fake.time_end = fake.time_start + duration
461461
Test.record(fake, Test.Error(:nontest_error, testname, nothing, Base.ExceptionStack(NamedTuple[(;exception = resp, backtrace = [])]), LineNumberNode(1), nothing))
462462
Test.push_testset(fake)
463463
Test.record(o_ts, fake)
@@ -477,7 +477,7 @@ cd(@__DIR__) do
477477
println()
478478
# o_ts.verbose = true # set to true to show all timings when successful
479479
Test.print_test_results(o_ts, 1)
480-
if !o_ts.anynonpass
480+
if !Test.anynonpass(o_ts)
481481
printstyled(" SUCCESS\n"; bold=true, color=:green)
482482
else
483483
printstyled(" FAILURE\n\n"; bold=true, color=:red)

0 commit comments

Comments
 (0)