Skip to content

Emit check for memory intrinsics for WebAssembly #16345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 73 additions & 3 deletions src/codegen/llvm.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8511,6 +8511,14 @@ pub const FuncGen = struct {
const dest_ptr = self.sliceOrArrayPtr(dest_slice, ptr_ty);
const is_volatile = ptr_ty.isVolatilePtr(mod);

// Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
// of the length. This means we need to emit a check where we skip the memset when the length
// is 0 as we allow for undefined pointers in 0-sized slices.
// This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
const intrinsic_len0_traps = o.target.isWasm() and
ptr_ty.isSlice(mod) and
std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);

if (try self.air.value(bin_op.rhs, mod)) |elem_val| {
if (elem_val.isUndefDeep(mod)) {
// Even if safety is disabled, we still emit a memset to undefined since it conveys
Expand All @@ -8521,7 +8529,11 @@ pub const FuncGen = struct {
else
u8_llvm_ty.getUndef();
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
if (intrinsic_len0_traps) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}

if (safety and mod.comp.bin_file.options.valgrind) {
self.valgrindMarkUndef(dest_ptr, len);
Expand All @@ -8539,7 +8551,12 @@ pub const FuncGen = struct {
.val = byte_val,
});
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);

if (intrinsic_len0_traps) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}
return null;
}
}
Expand All @@ -8551,7 +8568,12 @@ pub const FuncGen = struct {
// In this case we can take advantage of LLVM's intrinsic.
const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);

if (intrinsic_len0_traps) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}
return null;
}

Expand Down Expand Up @@ -8622,6 +8644,25 @@ pub const FuncGen = struct {
return null;
}

fn safeWasmMemset(
self: *FuncGen,
dest_ptr: *llvm.Value,
fill_byte: *llvm.Value,
len: *llvm.Value,
dest_ptr_align: u32,
is_volatile: bool,
) !void {
const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
const memset_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapSkip");
const end_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapEnd");
_ = self.builder.buildCondBr(cond, memset_block, end_block);
self.builder.positionBuilderAtEnd(memset_block);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
_ = self.builder.buildBr(end_block);
self.builder.positionBuilderAtEnd(end_block);
}

fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
const o = self.dg.object;
const mod = o.module;
Expand All @@ -8634,6 +8675,35 @@ pub const FuncGen = struct {
const len = self.sliceOrArrayLenInBytes(dest_slice, dest_ptr_ty);
const dest_ptr = self.sliceOrArrayPtr(dest_slice, dest_ptr_ty);
const is_volatile = src_ptr_ty.isVolatilePtr(mod) or dest_ptr_ty.isVolatilePtr(mod);

// When bulk-memory is enabled, this will be lowered to WebAssembly's memory.copy instruction.
// This instruction will trap on an invalid address, regardless of the length.
// For this reason we must add a check for 0-sized slices as its pointer field can be undefined.
// We only have to do this for slices as arrays will have a valid pointer.
// This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
if (o.target.isWasm() and
std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory) and
dest_ptr_ty.isSlice(mod))
{
const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
const memcpy_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapSkip");
const end_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapEnd");
_ = self.builder.buildCondBr(cond, memcpy_block, end_block);
self.builder.positionBuilderAtEnd(memcpy_block);
_ = self.builder.buildMemCpy(
dest_ptr,
dest_ptr_ty.ptrAlignment(mod),
src_ptr,
src_ptr_ty.ptrAlignment(mod),
len,
is_volatile,
);
_ = self.builder.buildBr(end_block);
self.builder.positionBuilderAtEnd(end_block);
return null;
}

_ = self.builder.buildMemCpy(
dest_ptr,
dest_ptr_ty.ptrAlignment(mod),
Expand Down
4 changes: 4 additions & 0 deletions test/standalone.zig
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ pub const build_cases = [_]BuildCase{
.build_root = "test/standalone/cmakedefine",
.import = @import("standalone/cmakedefine/build.zig"),
},
.{
.build_root = "test/standalone/zerolength_check",
.import = @import("standalone/zerolength_check/build.zig"),
},
};

const std = @import("std");
27 changes: 27 additions & 0 deletions test/standalone/zerolength_check/build.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
const std = @import("std");

pub fn build(b: *std.Build) void {
const test_step = b.step("test", "Test it");
b.default_step = test_step;

add(b, test_step, .Debug);
add(b, test_step, .ReleaseFast);
add(b, test_step, .ReleaseSmall);
add(b, test_step, .ReleaseSafe);
}

fn add(b: *std.Build, test_step: *std.Build.Step, optimize: std.builtin.OptimizeMode) void {
const unit_tests = b.addTest(.{
.root_source_file = .{ .path = "src/main.zig" },
.target = .{
.os_tag = .wasi,
.cpu_arch = .wasm32,
.cpu_features_add = std.Target.wasm.featureSet(&.{.bulk_memory}),
},
.optimize = optimize,
});

const run_unit_tests = b.addRunArtifact(unit_tests);
run_unit_tests.skip_foreign_checks = true;
test_step.dependOn(&run_unit_tests.step);
}
23 changes: 23 additions & 0 deletions test/standalone/zerolength_check/src/main.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
const std = @import("std");

test {
var dest = foo();
var source = foo();

@memcpy(dest, source);
@memset(dest, 4);
@memset(dest, undefined);

var dest2 = foo2();
@memset(dest2, 0);
}

fn foo() []u8 {
const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
return @as([*]align(1) u8, @ptrFromInt(ptr))[0..0];
}

fn foo2() []u64 {
const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
return @as([*]align(1) u64, @ptrFromInt(ptr))[0..0];
}