Skip to content

Update Nvptx backend for Zig 0.10 #12878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/std/builtin.zig
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ pub fn default_panic(msg: []const u8, error_return_trace: ?*StackTrace, ret_addr
// Didn't have boot_services, just fallback to whatever.
std.os.abort();
},
.cuda => std.os.abort(),
else => {
const first_trace_addr = ret_addr orelse @returnAddress();
std.debug.panicImpl(error_return_trace, first_trace_addr, msg);
Expand Down
6 changes: 6 additions & 0 deletions lib/std/os.zig
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,16 @@ pub fn abort() noreturn {
@breakpoint();
exit(1);
}
if (builtin.os.tag == .cuda) {
// TODO: introduce `@trap` instead of abusing https://github.com/ziglang/zig/issues/2291
@"llvm.trap"();
}

system.abort();
}

/// Binding for the external symbol `llvm.trap`, which never returns.
/// `abort` calls it when `builtin.os.tag == .cuda`.
/// NOTE(review): presumably resolved by LLVM as its trap intrinsic — confirm.
extern fn @"llvm.trap"() noreturn;

pub const RaiseError = UnexpectedError;

pub fn raise(sig: u8) RaiseError!void {
Expand Down
7 changes: 7 additions & 0 deletions lib/std/target.zig
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,13 @@ pub const Target = struct {
};
}

/// Returns `true` when `arch` is `.nvptx` or `.nvptx64`, and `false` for
/// every other architecture.
pub fn isNvptx(arch: Arch) bool {
    const is_32_bit_ptx = arch == .nvptx;
    const is_64_bit_ptx = arch == .nvptx64;
    return is_32_bit_ptx or is_64_bit_ptx;
}

pub fn parseCpuModel(arch: Arch, cpu_name: []const u8) !*const Cpu.Model {
for (arch.allCpuModels()) |cpu| {
if (mem.eql(u8, cpu_name, cpu.name)) {
Expand Down
9 changes: 9 additions & 0 deletions src/Module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,15 @@ pub const Decl = struct {
var buffer = std.ArrayList(u8).init(mod.gpa);
defer buffer.deinit();
try decl.renderFullyQualifiedName(mod, buffer.writer());

// Sanitize the name for nvptx which is more restrictive.
if (mod.comp.bin_file.options.target.cpu.arch.isNvptx()) {
for (buffer.items) |*byte| switch (byte.*) {
'{', '}', '*', '[', ']', '(', ')', ',', ' ', '\'' => byte.* = '_',
else => {},
};
}
Comment on lines +725 to +730
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause linker errors when two symbols map to the same sanitized name? For example, functions @"A()" and @"A{}" would both become `A__`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's true. I don't know how robust this should be. I could sanitize the name and append a hash of the original name. Would that be better?


return buffer.toOwnedSliceSentinel(0);
}

Expand Down
13 changes: 6 additions & 7 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18202,12 +18202,6 @@ fn zirAddrSpaceCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.Inst
else
dest_ptr_ty;

if (try sema.resolveMaybeUndefVal(block, ptr_src, ptr)) |val| {
// Pointer value should be compatible with both address spaces.
// TODO: Figure out why this generates an invalid bitcast.
return sema.addConstant(dest_ty, val);
}

try sema.requireRuntimeBlock(block, src, ptr_src);
// TODO: Address space cast safety?

Expand Down Expand Up @@ -21397,7 +21391,12 @@ fn validateExternType(
},
.Fn => {
if (position != .other) return false;
return !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention());
return switch (ty.fnCallingConvention()) {
// For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
// The goal is to experiment with more integrated CPU/GPU code.
.PtxKernel => true,
else => !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention()),
};
},
.Enum => {
var buf: Type.Payload.Bits = undefined;
Expand Down
32 changes: 17 additions & 15 deletions src/link/NvPtx.zig
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
if (!build_options.have_llvm) return error.PtxArchNotSupported;
if (!options.use_llvm) return error.PtxArchNotSupported;

switch (options.target.cpu.arch) {
.nvptx, .nvptx64 => {},
else => return error.PtxArchNotSupported,
}
if (!options.target.cpu.arch.isNvptx()) return error.PtxArchNotSupported;

switch (options.target.os.tag) {
// TODO: does it also work with nvcl ?
Expand Down Expand Up @@ -59,9 +56,8 @@ pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Option
if (!options.use_llvm) return error.PtxArchNotSupported;
assert(options.target.ofmt == .nvptx);

const nvptx = try createEmpty(allocator, options);
log.info("Opening .ptx target file {s}", .{sub_path});
return nvptx;
log.debug("Opening .ptx target file {s}", .{sub_path});
return createEmpty(allocator, options);
}

pub fn deinit(self: *NvPtx) void {
Expand Down Expand Up @@ -109,13 +105,19 @@ pub fn flushModule(self: *NvPtx, comp: *Compilation, prog_node: *std.Progress.No
const tracy = trace(@src());
defer tracy.end();

var hack_comp = comp;
if (comp.bin_file.options.emit) |emit| {
hack_comp.emit_asm = .{
.directory = emit.directory,
.basename = comp.bin_file.intermediary_basename.?,
};
hack_comp.bin_file.options.emit = null;
const outfile = comp.bin_file.options.emit.?;
// We modify 'comp' before passing it to LLVM, but restore value afterwards.
// We tell LLVM to not try to build a .o, only an "assembly" file.
// This is required by the LLVM PTX backend.
comp.bin_file.options.emit = null;
comp.emit_asm = .{
.directory = outfile.directory,
.basename = comp.bin_file.intermediary_basename.?,
};
defer {
comp.bin_file.options.emit = outfile;
comp.emit_asm = null;
}
return try self.llvm_object.flushModule(hack_comp, prog_node);

try self.llvm_object.flushModule(comp, prog_node);
}
8 changes: 6 additions & 2 deletions src/target.zig
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,11 @@ pub fn classifyCompilerRtLibName(target: std.Target, name: []const u8) CompilerR
}

/// Returns whether debug info is reported as available for `target`.
///
/// For NVPTX architectures the answer is `true` only when the `ptx75` CPU
/// feature is present in `target.cpu.features`; every other architecture
/// yields `true`.
pub fn hasDebugInfo(target: std.Target) bool {
    // `target` is used below, so it must not be discarded with `_ = target;`
    // (Zig rejects discarding a parameter that is subsequently used).
    if (target.cpu.arch.isNvptx()) {
        // TODO: not sure how to test "ptx >= 7.5" with featureset; for now
        // only the exact `ptx75` feature is checked.
        return std.Target.nvptx.featureSetHas(target.cpu.features, .ptx75);
    }

    return true;
}

Expand Down Expand Up @@ -651,7 +655,7 @@ pub fn addrSpaceCastIsValid(
const arch = target.cpu.arch;
switch (arch) {
.x86_64, .i386 => return arch.supportsAddressSpace(from) and arch.supportsAddressSpace(to),
.amdgcn => {
.nvptx64, .nvptx, .amdgcn => {
const to_generic = arch.supportsAddressSpace(from) and to == .generic;
const from_generic = arch.supportsAddressSpace(to) and from == .generic;
return to_generic or from_generic;
Expand Down
3 changes: 1 addition & 2 deletions test/cases.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ const TestContext = @import("../src/test.zig").TestContext;
/// Registers the test-case groups below with `ctx`, in order.
/// Any error returned by a group's `addCases` is propagated to the caller.
pub fn addCases(ctx: *TestContext) !void {
try @import("compile_errors.zig").addCases(ctx);
try @import("stage2/cbe.zig").addCases(ctx);
// The nvptx cases were previously commented out, with a reference to
// https://github.com/ziglang/zig/issues/10968; they are registered again here.
try @import("stage2/nvptx.zig").addCases(ctx);
}
43 changes: 38 additions & 5 deletions test/stage2/nvptx.zig
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ pub fn addCases(ctx: *TestContext) !void {
var case = addPtx(ctx, "nvptx: read special registers");

case.compiles(
\\fn threadIdX() usize {
\\ var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
\\ : [ret] "=r" (-> u32),
\\ );
\\ return @as(usize, tid);
\\fn threadIdX() u32 {
\\ return asm ("mov.u32 \t%[r], %tid.x;"
\\ : [r] "=r" (-> u32),
\\ );
\\}
\\
\\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
Expand All @@ -49,6 +48,38 @@ pub fn addCases(ctx: *TestContext) !void {
\\}
);
}

{
var case = addPtx(ctx, "nvptx: reduce in shared mem");
case.compiles(
\\fn threadIdX() u32 {
\\ return asm ("mov.u32 \t%[r], %tid.x;"
\\ : [r] "=r" (-> u32),
\\ );
\\}
\\
\\ var _sdata: [1024]f32 addrspace(.shared) = undefined;
\\ pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.PtxKernel) void {
\\ var sdata = @addrSpaceCast(.generic, &_sdata);
\\ const tid: u32 = threadIdX();
\\ var sum = d_x[tid];
\\ sdata[tid] = sum;
\\ asm volatile ("bar.sync \t0;");
\\ var s: u32 = 512;
\\ while (s > 0) : (s = s >> 1) {
\\ if (tid < s) {
\\ sum += sdata[tid + s];
\\ sdata[tid] = sum;
\\ }
\\ asm volatile ("bar.sync \t0;");
\\ }
\\
\\ if (tid == 0) {
\\ out.* = sum;
\\ }
\\ }
);
}
}

const nvptx_target = std.zig.CrossTarget{
Expand All @@ -68,6 +99,8 @@ pub fn addPtx(
.files = std.ArrayList(TestContext.File).init(ctx.cases.allocator),
.link_libc = false,
.backend = .llvm,
// Bug in Debug mode
.optimize_mode = .ReleaseSafe,
}) catch @panic("out of memory");
return &ctx.cases.items[ctx.cases.items.len - 1];
}