Skip to content

rework fuzzing API to accept a function pointer parameter #21370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 77 additions & 32 deletions lib/compiler/test_runner.zig
Original file line number Diff line number Diff line change
Expand Up @@ -145,31 +145,23 @@ fn mainServer() !void {
.start_fuzzing => {
if (!builtin.fuzz) unreachable;
const index = try server.receiveBody_u32();
var first = true;
const test_fn = builtin.test_functions[index];
while (true) {
testing.allocator_instance = .{};
defer if (testing.allocator_instance.deinit() == .leak) std.process.exit(1);
log_err_count = 0;
is_fuzz_test = false;
test_fn.func() catch |err| switch (err) {
error.SkipZigTest => continue,
else => {
if (@errorReturnTrace()) |trace| {
std.debug.dumpStackTrace(trace.*);
}
std.debug.print("failed with error.{s}\n", .{@errorName(err)});
std.process.exit(1);
},
};
if (!is_fuzz_test) @panic("missed call to std.testing.fuzzInput");
if (log_err_count != 0) @panic("error logs detected");
if (first) {
first = false;
const entry_addr = @intFromPtr(test_fn.func);
try server.serveU64Message(.fuzz_start_addr, entry_addr);
}
}
const entry_addr = @intFromPtr(test_fn.func);
try server.serveU64Message(.fuzz_start_addr, entry_addr);
defer if (testing.allocator_instance.deinit() == .leak) std.process.exit(1);
is_fuzz_test = false;
test_fn.func() catch |err| switch (err) {
error.SkipZigTest => return,
else => {
if (@errorReturnTrace()) |trace| {
std.debug.dumpStackTrace(trace.*);
}
std.debug.print("failed with error.{s}\n", .{@errorName(err)});
std.process.exit(1);
},
};
if (!is_fuzz_test) @panic("missed call to std.testing.fuzz");
if (log_err_count != 0) @panic("error logs detected");
},

else => {
Expand Down Expand Up @@ -349,19 +341,72 @@ const FuzzerSlice = extern struct {

var is_fuzz_test: bool = undefined;

extern fn fuzzer_next() FuzzerSlice;
extern fn fuzzer_start(testOne: *const fn ([*]const u8, usize) callconv(.C) void) void;
extern fn fuzzer_init(cache_dir: FuzzerSlice) void;
extern fn fuzzer_coverage_id() u64;

pub fn fuzzInput(options: testing.FuzzInputOptions) []const u8 {
pub fn fuzz(
comptime testOne: fn ([]const u8) anyerror!void,
options: testing.FuzzInputOptions,
) anyerror!void {
// Prevent this function from confusing the fuzzer by omitting its own code
// coverage from being considered.
@disableInstrumentation();
if (crippled) return "";

// Some compiler backends are not capable of handling fuzz testing yet but
// we still want CI test coverage enabled.
if (crippled) return;

// Smoke test to ensure the test did not use conditional compilation to
// contradict itself by making it not actually be a fuzz test when the test
// is built in fuzz mode.
is_fuzz_test = true;

// Ensure no test failure occurred before starting fuzzing.
if (log_err_count != 0) @panic("error logs detected");

// libfuzzer is in a separate compilation unit so that its own code can be
// excluded from code coverage instrumentation. It needs a function pointer
// it can call for checking exactly one input. Inside this function we do
// our standard unit test checks such as memory leaks, and interaction with
// error logs.
const global = struct {
fn fuzzer_one(input_ptr: [*]const u8, input_len: usize) callconv(.C) void {
@disableInstrumentation();
testing.allocator_instance = .{};
defer if (testing.allocator_instance.deinit() == .leak) std.process.exit(1);
log_err_count = 0;
testOne(input_ptr[0..input_len]) catch |err| switch (err) {
error.SkipZigTest => return,
else => {
std.debug.lockStdErr();
if (@errorReturnTrace()) |trace| std.debug.dumpStackTrace(trace.*);
std.debug.print("failed with error.{s}\n", .{@errorName(err)});
std.process.exit(1);
},
};
if (log_err_count != 0) {
std.debug.lockStdErr();
std.debug.print("error logs detected\n", .{});
std.process.exit(1);
}
}
};
if (builtin.fuzz) {
return fuzzer_next().toSlice();
const prev_allocator_state = testing.allocator_instance;
testing.allocator_instance = .{};
fuzzer_start(&global.fuzzer_one);
testing.allocator_instance = prev_allocator_state;
return;
}
if (options.corpus.len == 0) return "";
var prng = std.Random.DefaultPrng.init(testing.random_seed);
const random = prng.random();
return options.corpus[random.uintLessThan(usize, options.corpus.len)];

// When the unit test executable is not built in fuzz mode, only run the
// provided corpus.
for (options.corpus) |input| {
try testOne(input);
}

// In case there is no provided corpus, also use an empty
// string as a smoke test.
try testOne("");
}
89 changes: 48 additions & 41 deletions lib/fuzzer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ fn logOverride(
f.writer().print(prefix1 ++ prefix2 ++ format ++ "\n", args) catch @panic("failed to write to fuzzer log");
}

export threadlocal var __sancov_lowest_stack: usize = std.math.maxInt(usize);
/// Helps determine run uniqueness in the face of recursion.
export threadlocal var __sancov_lowest_stack: usize = 0;

export fn __sanitizer_cov_trace_const_cmp1(arg1: u8, arg2: u8) void {
handleCmp(@returnAddress(), arg1, arg2);
Expand Down Expand Up @@ -220,7 +221,6 @@ const Fuzzer = struct {
.n_runs = 0,
.unique_runs = 0,
.pcs_len = pcs.len,
.lowest_stack = std.math.maxInt(usize),
};
f.seen_pcs.appendSliceAssumeCapacity(std.mem.asBytes(&header));
f.seen_pcs.appendNTimesAssumeCapacity(0, n_bitset_elems * @sizeOf(usize));
Expand All @@ -235,22 +235,41 @@ const Fuzzer = struct {
};
}

fn next(f: *Fuzzer) ![]const u8 {
fn start(f: *Fuzzer) !void {
const gpa = f.gpa;
const rng = fuzzer.rng.random();

if (f.recent_cases.entries.len == 0) {
// Prepare initial input.
try f.recent_cases.ensureUnusedCapacity(gpa, 100);
const len = rng.uintLessThanBiased(usize, 80);
try f.input.resize(gpa, len);
rng.bytes(f.input.items);
f.recent_cases.putAssumeCapacity(.{
.id = 0,
.input = try gpa.dupe(u8, f.input.items),
.score = 0,
}, {});
} else {
// Prepare initial input.
assert(f.recent_cases.entries.len == 0);
assert(f.n_runs == 0);
try f.recent_cases.ensureUnusedCapacity(gpa, 100);
const len = rng.uintLessThanBiased(usize, 80);
try f.input.resize(gpa, len);
rng.bytes(f.input.items);
f.recent_cases.putAssumeCapacity(.{
.id = 0,
.input = try gpa.dupe(u8, f.input.items),
.score = 0,
}, {});

const header: *volatile SeenPcsHeader = @ptrCast(f.seen_pcs.items[0..@sizeOf(SeenPcsHeader)]);

while (true) {
const chosen_index = rng.uintLessThanBiased(usize, f.recent_cases.entries.len);
const run = &f.recent_cases.keys()[chosen_index];
f.input.clearRetainingCapacity();
f.input.appendSliceAssumeCapacity(run.input);
try f.mutate();

@memset(f.pc_counters, 0);
__sancov_lowest_stack = std.math.maxInt(usize);
f.coverage.reset();

fuzzer_one(f.input.items.ptr, f.input.items.len);

f.n_runs += 1;
_ = @atomicRmw(usize, &header.n_runs, .Add, 1, .monotonic);

if (f.n_runs % 10000 == 0) f.dumpStats();

const analysis = f.analyzeLastRun();
Expand Down Expand Up @@ -301,7 +320,6 @@ const Fuzzer = struct {
}
}

const header: *volatile SeenPcsHeader = @ptrCast(f.seen_pcs.items[0..@sizeOf(SeenPcsHeader)]);
_ = @atomicRmw(usize, &header.unique_runs, .Add, 1, .monotonic);
}

Expand All @@ -317,26 +335,12 @@ const Fuzzer = struct {
// This has to be done before deinitializing the deleted items.
const doomed_runs = f.recent_cases.keys()[cap..];
f.recent_cases.shrinkRetainingCapacity(cap);
for (doomed_runs) |*run| {
std.log.info("culling score={d} id={d}", .{ run.score, run.id });
run.deinit(gpa);
for (doomed_runs) |*doomed_run| {
std.log.info("culling score={d} id={d}", .{ doomed_run.score, doomed_run.id });
doomed_run.deinit(gpa);
}
}
}

const chosen_index = rng.uintLessThanBiased(usize, f.recent_cases.entries.len);
const run = &f.recent_cases.keys()[chosen_index];
f.input.clearRetainingCapacity();
f.input.appendSliceAssumeCapacity(run.input);
try f.mutate();

f.n_runs += 1;
const header: *volatile SeenPcsHeader = @ptrCast(f.seen_pcs.items[0..@sizeOf(SeenPcsHeader)]);
_ = @atomicRmw(usize, &header.n_runs, .Add, 1, .monotonic);
_ = @atomicRmw(usize, &header.lowest_stack, .Min, __sancov_lowest_stack, .monotonic);
@memset(f.pc_counters, 0);
f.coverage.reset();
return f.input.items;
}

fn visitPc(f: *Fuzzer, pc: usize) void {
Expand Down Expand Up @@ -419,10 +423,13 @@ export fn fuzzer_coverage_id() u64 {
return fuzzer.coverage_id;
}

export fn fuzzer_next() Fuzzer.Slice {
return Fuzzer.Slice.fromZig(fuzzer.next() catch |err| switch (err) {
error.OutOfMemory => @panic("out of memory"),
});
var fuzzer_one: *const fn (input_ptr: [*]const u8, input_len: usize) callconv(.C) void = undefined;

export fn fuzzer_start(testOne: @TypeOf(fuzzer_one)) void {
fuzzer_one = testOne;
fuzzer.start() catch |err| switch (err) {
error.OutOfMemory => fatal("out of memory", .{}),
};
}

export fn fuzzer_init(cache_dir_struct: Fuzzer.Slice) void {
Expand All @@ -432,24 +439,24 @@ export fn fuzzer_init(cache_dir_struct: Fuzzer.Slice) void {
const pc_counters_start = @extern([*]u8, .{
.name = "__start___sancov_cntrs",
.linkage = .weak,
}) orelse fatal("missing __start___sancov_cntrs symbol");
}) orelse fatal("missing __start___sancov_cntrs symbol", .{});

const pc_counters_end = @extern([*]u8, .{
.name = "__stop___sancov_cntrs",
.linkage = .weak,
}) orelse fatal("missing __stop___sancov_cntrs symbol");
}) orelse fatal("missing __stop___sancov_cntrs symbol", .{});

const pc_counters = pc_counters_start[0 .. pc_counters_end - pc_counters_start];

const pcs_start = @extern([*]usize, .{
.name = "__start___sancov_pcs1",
.linkage = .weak,
}) orelse fatal("missing __start___sancov_pcs1 symbol");
}) orelse fatal("missing __start___sancov_pcs1 symbol", .{});

const pcs_end = @extern([*]usize, .{
.name = "__stop___sancov_pcs1",
.linkage = .weak,
}) orelse fatal("missing __stop___sancov_pcs1 symbol");
}) orelse fatal("missing __stop___sancov_pcs1 symbol", .{});

const pcs = pcs_start[0 .. pcs_end - pcs_start];

Expand Down
2 changes: 1 addition & 1 deletion lib/fuzzer/index.html → lib/fuzzer/web/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@
<ul>
<li>Total Runs: <span id="statTotalRuns"></span></li>
<li>Unique Runs: <span id="statUniqueRuns"></span></li>
<li>Speed (Runs/Second): <span id="statSpeed"></span></li>
<li>Coverage: <span id="statCoverage"></span></li>
<li>Lowest Stack: <span id="statLowestStack"></span></li>
<li>Entry Points: <ul id="entryPointsList"></ul></li>
</ul>
</div>
Expand Down
7 changes: 5 additions & 2 deletions lib/fuzzer/main.js → lib/fuzzer/web/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
const domSourceText = document.getElementById("sourceText");
const domStatTotalRuns = document.getElementById("statTotalRuns");
const domStatUniqueRuns = document.getElementById("statUniqueRuns");
const domStatSpeed = document.getElementById("statSpeed");
const domStatCoverage = document.getElementById("statCoverage");
const domStatLowestStack = document.getElementById("statLowestStack");
const domEntryPointsList = document.getElementById("entryPointsList");

let wasm_promise = fetch("main.wasm");
Expand All @@ -32,6 +32,9 @@
const msg = decodeString(ptr, len);
throw new Error("panic: " + msg);
},
timestamp: function () {
return BigInt(new Date());
},
emitSourceIndexChange: onSourceIndexChange,
emitCoverageUpdate: onCoverageUpdate,
emitEntryPointsUpdate: renderStats,
Expand Down Expand Up @@ -158,7 +161,7 @@
domStatTotalRuns.innerText = totalRuns;
domStatUniqueRuns.innerText = uniqueRuns + " (" + percent(uniqueRuns, totalRuns) + "%)";
domStatCoverage.innerText = coveredSourceLocations + " / " + totalSourceLocations + " (" + percent(coveredSourceLocations, totalSourceLocations) + "%)";
domStatLowestStack.innerText = unwrapString(wasm_exports.lowestStack());
domStatSpeed.innerText = wasm_exports.totalRunsPerSecond().toFixed(0);

const entryPoints = unwrapInt32Array(wasm_exports.entryPoints());
resizeDomList(domEntryPointsList, entryPoints.length, "<li></li>");
Expand Down
Loading