Skip to content

Commit feab1eb

Browse files
authored
Merge pull request #12878 from gwenzek/ptx
Update Nvptx backend for Zig 0.10
2 parents 65f860b + c289794 commit feab1eb

File tree

9 files changed

+91
-31
lines changed

9 files changed

+91
-31
lines changed

lib/std/builtin.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ pub fn default_panic(msg: []const u8, error_return_trace: ?*StackTrace, ret_addr
833833
// Didn't have boot_services, just fallback to whatever.
834834
std.os.abort();
835835
},
836+
.cuda => std.os.abort(),
836837
else => {
837838
const first_trace_addr = ret_addr orelse @returnAddress();
838839
std.debug.panicImpl(error_return_trace, first_trace_addr, msg);

lib/std/os.zig

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,16 @@ pub fn abort() noreturn {
500500
@breakpoint();
501501
exit(1);
502502
}
503+
if (builtin.os.tag == .cuda) {
504+
// TODO: introduce `@trap` instead of abusing https://github.com/ziglang/zig/issues/2291
505+
@"llvm.trap"();
506+
}
503507

504508
system.abort();
505509
}
506510

511+
extern fn @"llvm.trap"() noreturn;
512+
507513
pub const RaiseError = UnexpectedError;
508514

509515
pub fn raise(sig: u8) RaiseError!void {

lib/std/target.zig

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,13 @@ pub const Target = struct {
951951
};
952952
}
953953

954+
pub fn isNvptx(arch: Arch) bool {
955+
return switch (arch) {
956+
.nvptx, .nvptx64 => true,
957+
else => false,
958+
};
959+
}
960+
954961
pub fn parseCpuModel(arch: Arch, cpu_name: []const u8) !*const Cpu.Model {
955962
for (arch.allCpuModels()) |cpu| {
956963
if (mem.eql(u8, cpu_name, cpu.name)) {

src/Module.zig

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,15 @@ pub const Decl = struct {
720720
var buffer = std.ArrayList(u8).init(mod.gpa);
721721
defer buffer.deinit();
722722
try decl.renderFullyQualifiedName(mod, buffer.writer());
723+
724+
// Sanitize the name for nvptx which is more restrictive.
725+
if (mod.comp.bin_file.options.target.cpu.arch.isNvptx()) {
726+
for (buffer.items) |*byte| switch (byte.*) {
727+
'{', '}', '*', '[', ']', '(', ')', ',', ' ', '\'' => byte.* = '_',
728+
else => {},
729+
};
730+
}
731+
723732
return buffer.toOwnedSliceSentinel(0);
724733
}
725734

src/Sema.zig

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18202,12 +18202,6 @@ fn zirAddrSpaceCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.Inst
1820218202
else
1820318203
dest_ptr_ty;
1820418204

18205-
if (try sema.resolveMaybeUndefVal(block, ptr_src, ptr)) |val| {
18206-
// Pointer value should compatible with both address spaces.
18207-
// TODO: Figure out why this generates an invalid bitcast.
18208-
return sema.addConstant(dest_ty, val);
18209-
}
18210-
1821118205
try sema.requireRuntimeBlock(block, src, ptr_src);
1821218206
// TODO: Address space cast safety?
1821318207

@@ -21397,7 +21391,12 @@ fn validateExternType(
2139721391
},
2139821392
.Fn => {
2139921393
if (position != .other) return false;
21400-
return !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention());
21394+
return switch (ty.fnCallingConvention()) {
21395+
// For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
21396+
// The goal is to experiment with more integrated CPU/GPU code.
21397+
.PtxKernel => true,
21398+
else => !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention()),
21399+
};
2140121400
},
2140221401
.Enum => {
2140321402
var buf: Type.Payload.Bits = undefined;

src/link/NvPtx.zig

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ pub fn createEmpty(gpa: Allocator, options: link.Options) !*NvPtx {
2828
if (!build_options.have_llvm) return error.PtxArchNotSupported;
2929
if (!options.use_llvm) return error.PtxArchNotSupported;
3030

31-
switch (options.target.cpu.arch) {
32-
.nvptx, .nvptx64 => {},
33-
else => return error.PtxArchNotSupported,
34-
}
31+
if (!options.target.cpu.arch.isNvptx()) return error.PtxArchNotSupported;
3532

3633
switch (options.target.os.tag) {
3734
// TODO: does it also work with nvcl ?
@@ -59,9 +56,8 @@ pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Option
5956
if (!options.use_llvm) return error.PtxArchNotSupported;
6057
assert(options.target.ofmt == .nvptx);
6158

62-
const nvptx = try createEmpty(allocator, options);
63-
log.info("Opening .ptx target file {s}", .{sub_path});
64-
return nvptx;
59+
log.debug("Opening .ptx target file {s}", .{sub_path});
60+
return createEmpty(allocator, options);
6561
}
6662

6763
pub fn deinit(self: *NvPtx) void {
@@ -109,13 +105,19 @@ pub fn flushModule(self: *NvPtx, comp: *Compilation, prog_node: *std.Progress.No
109105
const tracy = trace(@src());
110106
defer tracy.end();
111107

112-
var hack_comp = comp;
113-
if (comp.bin_file.options.emit) |emit| {
114-
hack_comp.emit_asm = .{
115-
.directory = emit.directory,
116-
.basename = comp.bin_file.intermediary_basename.?,
117-
};
118-
hack_comp.bin_file.options.emit = null;
108+
const outfile = comp.bin_file.options.emit.?;
109+
// We modify 'comp' before passing it to LLVM, but restore value afterwards.
110+
// We tell LLVM to not try to build a .o, only an "assembly" file.
111+
// This is required by the LLVM PTX backend.
112+
comp.bin_file.options.emit = null;
113+
comp.emit_asm = .{
114+
.directory = outfile.directory,
115+
.basename = comp.bin_file.intermediary_basename.?,
116+
};
117+
defer {
118+
comp.bin_file.options.emit = outfile;
119+
comp.emit_asm = null;
119120
}
120-
return try self.llvm_object.flushModule(hack_comp, prog_node);
121+
122+
try self.llvm_object.flushModule(comp, prog_node);
121123
}

src/target.zig

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,11 @@ pub fn classifyCompilerRtLibName(target: std.Target, name: []const u8) CompilerR
411411
}
412412

413413
pub fn hasDebugInfo(target: std.Target) bool {
414-
_ = target;
414+
if (target.cpu.arch.isNvptx()) {
415+
// TODO: not sure how to test "ptx >= 7.5" with featureset
416+
return std.Target.nvptx.featureSetHas(target.cpu.features, .ptx75);
417+
}
418+
415419
return true;
416420
}
417421

@@ -651,7 +655,7 @@ pub fn addrSpaceCastIsValid(
651655
const arch = target.cpu.arch;
652656
switch (arch) {
653657
.x86_64, .i386 => return arch.supportsAddressSpace(from) and arch.supportsAddressSpace(to),
654-
.amdgcn => {
658+
.nvptx64, .nvptx, .amdgcn => {
655659
const to_generic = arch.supportsAddressSpace(from) and to == .generic;
656660
const from_generic = arch.supportsAddressSpace(to) and from == .generic;
657661
return to_generic or from_generic;

test/cases.zig

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ const TestContext = @import("../src/test.zig").TestContext;
44
pub fn addCases(ctx: *TestContext) !void {
55
try @import("compile_errors.zig").addCases(ctx);
66
try @import("stage2/cbe.zig").addCases(ctx);
7-
// https://github.com/ziglang/zig/issues/10968
8-
//try @import("stage2/nvptx.zig").addCases(ctx);
7+
try @import("stage2/nvptx.zig").addCases(ctx);
98
}

test/stage2/nvptx.zig

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ pub fn addCases(ctx: *TestContext) !void {
2323
var case = addPtx(ctx, "nvptx: read special registers");
2424

2525
case.compiles(
26-
\\fn threadIdX() usize {
27-
\\ var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
28-
\\ : [ret] "=r" (-> u32),
29-
\\ );
30-
\\ return @as(usize, tid);
26+
\\fn threadIdX() u32 {
27+
\\ return asm ("mov.u32 \t%[r], %tid.x;"
28+
\\ : [r] "=r" (-> u32),
29+
\\ );
3130
\\}
3231
\\
3332
\\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
@@ -49,6 +48,38 @@ pub fn addCases(ctx: *TestContext) !void {
4948
\\}
5049
);
5150
}
51+
52+
{
53+
var case = addPtx(ctx, "nvptx: reduce in shared mem");
54+
case.compiles(
55+
\\fn threadIdX() u32 {
56+
\\ return asm ("mov.u32 \t%[r], %tid.x;"
57+
\\ : [r] "=r" (-> u32),
58+
\\ );
59+
\\}
60+
\\
61+
\\ var _sdata: [1024]f32 addrspace(.shared) = undefined;
62+
\\ pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.PtxKernel) void {
63+
\\ var sdata = @addrSpaceCast(.generic, &_sdata);
64+
\\ const tid: u32 = threadIdX();
65+
\\ var sum = d_x[tid];
66+
\\ sdata[tid] = sum;
67+
\\ asm volatile ("bar.sync \t0;");
68+
\\ var s: u32 = 512;
69+
\\ while (s > 0) : (s = s >> 1) {
70+
\\ if (tid < s) {
71+
\\ sum += sdata[tid + s];
72+
\\ sdata[tid] = sum;
73+
\\ }
74+
\\ asm volatile ("bar.sync \t0;");
75+
\\ }
76+
\\
77+
\\ if (tid == 0) {
78+
\\ out.* = sum;
79+
\\ }
80+
\\ }
81+
);
82+
}
5283
}
5384

5485
const nvptx_target = std.zig.CrossTarget{
@@ -68,6 +99,8 @@ pub fn addPtx(
6899
.files = std.ArrayList(TestContext.File).init(ctx.cases.allocator),
69100
.link_libc = false,
70101
.backend = .llvm,
102+
// Bug in Debug mode
103+
.optimize_mode = .ReleaseSafe,
71104
}) catch @panic("out of memory");
72105
return &ctx.cases.items[ctx.cases.items.len - 1];
73106
}

0 commit comments

Comments
 (0)