From 8c837a4f5aff3963f58dcfdfb73c090faab69bc4 Mon Sep 17 00:00:00 2001 From: Travis Staloch Date: Tue, 8 Nov 2022 00:32:59 -0800 Subject: [PATCH] @mulCarryless() - add a builtin to llvm backend addresses #9631. only works with llvm backend/x86 so far. allows new test/behavior/mul_carryless.zig to pass with -Denable-llvm. doesn't do any backend/arch/cpu-feature testing. --- src/Air.zig | 17 ++++++++- src/AstGen.zig | 10 +++++- src/BuiltinFn.zig | 8 +++++ src/Liveness.zig | 12 +++++++ src/Sema.zig | 64 +++++++++++++++++++++++++++++++++ src/Zir.zig | 11 ++++++ src/arch/aarch64/CodeGen.zig | 2 ++ src/arch/arm/CodeGen.zig | 2 ++ src/arch/riscv64/CodeGen.zig | 2 ++ src/arch/sparc64/CodeGen.zig | 2 ++ src/arch/wasm/CodeGen.zig | 2 ++ src/arch/x86_64/CodeGen.zig | 2 ++ src/codegen/c.zig | 2 ++ src/codegen/llvm.zig | 13 +++++++ src/codegen/llvm/bindings.zig | 9 +++++ src/print_air.zig | 15 ++++++++ src/print_zir.zig | 14 ++++++++ src/zig_llvm.cpp | 8 +++++ src/zig_llvm.h | 1 + test/behavior/mul_carryless.zig | 18 ++++++++++ 20 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 test/behavior/mul_carryless.zig diff --git a/src/Air.zig b/src/Air.zig index 3bcbdb8e98ab..8d09746b1742 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -178,6 +178,11 @@ pub const Inst = struct { /// and if an overflow happens, ov is 1. Otherwise ov is 0. /// Uses the `ty_pl` field. Payload is `Bin`. shl_with_overflow, + /// Carryless multiplication. Both operands are guaranteed to be the same type, + /// Result type is the same as both operands. + /// Uses the `ty_pl` field. Payload is `MulCarryless`. + /// Uses the `ty` field. + mul_carryless, /// Allocates stack local memory. /// Uses the `ty` field. alloc, @@ -897,6 +902,12 @@ pub const Shuffle = struct { mask_len: u32, }; +pub const MulCarryless = struct { + a: Inst.Ref, + b: Inst.Ref, + imm: Inst.Ref, +}; + pub const VectorCmp = struct { lhs: Inst.Ref, rhs: Inst.Ref, @@ -1222,7 +1233,11 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { const extra = air.extraData(Air.Bin, datas[inst].pl_op.payload).data; return air.typeOf(extra.lhs); }, - + .mul_carryless => { + // const extra = air.extraData(Air.MulCarryless, datas[inst].ty_pl.payload).data; + // return air.typeOf(extra.a); + return air.getRefType(datas[inst].ty_pl.ty); + }, .@"try" => { const err_union_ty = air.typeOf(datas[inst].pl_op.operand); return err_union_ty.errorUnionPayload(); diff --git a/src/AstGen.zig b/src/AstGen.zig index 021990883ac0..1d4d0fd760af 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -8250,7 +8250,15 @@ fn builtinCall( }); return rvalue(gz, ri, result, node); }, - + .mul_carryless => { + const result = try gz.addExtendedPayload(.mul_carryless, Zir.Inst.MulCarryless{ + .node = gz.nodeIndexToRelative(node), + .a = try expr(gz, scope, .{ .rl = .none }, params[0]), + .b = try expr(gz, scope, .{ .rl = .none }, params[1]), + .imm = try expr(gz, scope, .{ .rl = .none }, params[2]), + }); + return rvalue(gz, ri, result, node); + }, .atomic_load => { const result = try gz.addPlNode(.atomic_load, node, Zir.Inst.AtomicLoad{ // zig fmt: off diff --git a/src/BuiltinFn.zig b/src/BuiltinFn.zig index 24625dc10a16..96f66861ff5e 100644 --- a/src/BuiltinFn.zig +++ b/src/BuiltinFn.zig @@ -66,6 +66,7 @@ pub const Tag = enum { wasm_memory_grow, mod, mul_with_overflow, + mul_carryless, panic, pop_count, prefetch, @@ -611,6 +612,13 @@ pub const list = list: { .param_count = 4, }, }, + .{ + "@mulCarryless", + .{ + .tag = .mul_carryless, + .param_count = 3, + }, + }, .{ "@panic", .{ diff --git a/src/Liveness.zig b/src/Liveness.zig index a1a7e6b2154e..86b074193e28 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -405,6 +405,13 @@ pub fn categorizeOperand( if (extra.b == operand_ref) return matchOperandSmallIndex(l, inst, 1, .none); return .none; }, + .mul_carryless => { + const extra = air.extraData(Air.MulCarryless, air_datas[inst].ty_pl.payload).data; + if (extra.a == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none); + if (extra.b == operand_ref) return matchOperandSmallIndex(l, inst, 1, .none); + if (extra.imm == operand_ref) return matchOperandSmallIndex(l, inst, 2, .none); + return .none; + }, .reduce, .reduce_optimized => { const reduce = air_datas[inst].reduce; if (reduce.operand == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none); @@ -905,6 +912,11 @@ fn analyzeInst( const extra = a.air.extraData(Air.Bin, ty_pl.payload).data; return trackOperands(a, new_set, inst, main_tomb, .{ extra.lhs, extra.rhs, .none }); }, + .mul_carryless => { + const ty_pl = inst_datas[inst].ty_pl; + const extra = a.air.extraData(Air.MulCarryless, ty_pl.payload).data; + return trackOperands(a, new_set, inst, main_tomb, .{ extra.a, extra.b, extra.imm }); + }, .dbg_var_ptr, .dbg_var_val, diff --git a/src/Sema.zig b/src/Sema.zig index 3bad0b2b5f6b..888d2db22236 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1132,6 +1132,7 @@ fn analyzeBodyInner( .sub_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode), .mul_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode), .shl_with_overflow => try sema.zirOverflowArithmetic(block, extended, extended.opcode), + .mul_carryless => try sema.zirMulCarryless( block, extended), .c_undef => try sema.zirCUndef( block, extended), .c_include => try sema.zirCInclude( block, extended), .c_define => try sema.zirCDefine( block, extended), @@ -13606,6 +13607,69 @@ fn zirRem(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins return block.addBinOp(air_tag, casted_lhs, casted_rhs); } +fn zirMulCarryless( + sema: *Sema, + block: *Block, + extended: Zir.Inst.Extended.InstData, +) CompileError!Air.Inst.Ref { + const tracy = trace(@src()); + defer tracy.end(); + + const extra = sema.code.extraData(Zir.Inst.MulCarryless, extended.operand).data; + const src = LazySrcLoc.nodeOffset(extra.node); + sema.src = src; + const a_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node }; + const b_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = extra.node }; + const imm_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = extra.node }; + + const a = try sema.resolveInst(extra.a); + const b = try sema.resolveInst(extra.b); + const imm = try sema.resolveInst(extra.imm); + + const a_ty = sema.typeOf(a); + const b_ty = sema.typeOf(b); + const imm_ty = sema.typeOf(imm); + const mod = sema.mod; + const target = mod.getTarget(); + + try sema.checkVectorizableBinaryOperands(block, src, a_ty, b_ty, a_src, b_src); + + if (a_ty.scalarType().zigTypeTag() != .Int) + return sema.fail(block, a_src, "expected vector of 64 bit integers, found '{}'", .{a_ty.fmt(mod)}); + const a_bit_size = try a_ty.scalarType().bitSizeAdvanced(target, sema.kit(block, src)); + if (a_bit_size != 64) + return sema.fail(block, a_src, "expected vector of 64 bit integers, found '{}'", .{a_ty.fmt(mod)}); + + if (!imm_ty.isInt()) + return sema.fail(block, imm_src, "imm must be a comptime known 8 bit integer", .{}); + const bit_size = try imm_ty.bitSizeAdvanced(target, sema.kit(block, src)); + if (bit_size != 8) + return sema.fail(block, imm_src, "imm must be a comptime known 8 bit integer", .{}); + if (!try sema.isComptimeKnown(block, imm_src, imm)) + return sema.fail(block, imm_src, "imm must be a comptime known 8 bit integer", .{}); + + const instructions = &[_]Air.Inst.Ref{ a, b }; + const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{ + .override = &[_]LazySrcLoc{ a_src, b_src }, + }); + const casted_a = try sema.coerce(block, resolved_type, a, a_src); + const casted_b = try sema.coerce(block, resolved_type, b, b_src); + + return block.addInst(.{ + .tag = .mul_carryless, + .data = .{ + .ty_pl = .{ + .ty = try block.sema.addType(sema.typeOf(casted_a)), + .payload = try block.sema.addExtra(Air.MulCarryless{ + .a = casted_a, + .b = casted_b, + .imm = imm, + }), + }, + }, + }); +} + fn zirOverflowArithmetic( sema: *Sema, block: *Block, diff --git a/src/Zir.zig b/src/Zir.zig index ed425ea73e1e..1e735f611e05 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -1927,6 +1927,10 @@ pub const Inst = struct { /// `operand` is payload index to `OverflowArithmetic`. /// `small` is unused. mul_with_overflow, + /// Implements the `@mulCarryless` builtin. + /// `operand` is payload index to `MulCarryless`. + /// `small` is unused. + mul_carryless, /// Implements the `@shlWithOverflow` builtin. /// `operand` is payload index to `OverflowArithmetic`. /// `small` is unused. @@ -3425,6 +3429,13 @@ pub const Inst = struct { ptr: Ref, }; + pub const MulCarryless = struct { + node: i32, + a: Ref, + b: Ref, + imm: Ref, + }; + pub const Cmpxchg = struct { node: i32, ptr: Ref, diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index 2edc6cb7f9c2..9d71d9f210e9 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -709,6 +709,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .mul_with_overflow => try self.airMulWithOverflow(inst), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .cmp_lt => try self.airCmp(inst, .lt), .cmp_lte => try self.airCmp(inst, .lte), .cmp_eq => try self.airCmp(inst, .eq), diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 6125ef191462..a16a688618ba 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -619,6 +619,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .mul_with_overflow => try self.airMulWithOverflow(inst), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .cmp_lt => try self.airCmp(inst, .lt), .cmp_lte => try self.airCmp(inst, .lte), .cmp_eq => try self.airCmp(inst, .eq), diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 6a54ffeea26b..2aeff08bdeed 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -524,6 +524,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .mul_with_overflow => try self.airMulWithOverflow(inst), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .div_float, .div_trunc, .div_floor, .div_exact => try self.airDiv(inst), .cmp_lt => try self.airCmp(inst, .lt), diff --git a/src/arch/sparc64/CodeGen.zig b/src/arch/sparc64/CodeGen.zig index 604fd3e69fdf..1321cccdd357 100644 --- a/src/arch/sparc64/CodeGen.zig +++ b/src/arch/sparc64/CodeGen.zig @@ -548,6 +548,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .mul_with_overflow => try self.airMulWithOverflow(inst), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .div_float, .div_trunc, .div_floor, .div_exact => try self.airDiv(inst), .cmp_lt => try self.airCmp(inst, .lt), diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index faed432a384d..8438b3266d18 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1717,6 +1717,8 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .shl_with_overflow => func.airShlWithOverflow(inst), .mul_with_overflow => func.airMulWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .clz => func.airClz(inst), .ctz => func.airCtz(inst), diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index cd36642b03f0..384722935d1c 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -619,6 +619,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .mul_with_overflow => try self.airMulWithOverflow(inst), .shl_with_overflow => try self.airAddSubShlWithOverflow(inst), + .mul_carryless => @panic("TODO"), + .div_float, .div_trunc, .div_floor, .div_exact => try self.airMulDivBinOp(inst), .cmp_lt => try self.airCmp(inst, .lt), diff --git a/src/codegen/c.zig b/src/codegen/c.zig index b78c9e9b1ad0..12ccb1541d96 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -2732,6 +2732,8 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .mul_with_overflow => try airOverflow(f, inst, "mul", .Bits), .shl_with_overflow => try airOverflow(f, inst, "shl", .Bits), + .mul_carryless => return f.fail("TODO: C backend: implement @mulCarryless()", .{}), + .min => try airMinMax(f, inst, '<', "fmin"), .max => try airMinMax(f, inst, '>', "fmax"), diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 7d70a5166664..04f0604873d7 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -4495,6 +4495,8 @@ pub const FuncGen = struct { .mul_with_overflow => try self.airOverflow(inst, "llvm.smul.with.overflow", "llvm.umul.with.overflow"), .shl_with_overflow => try self.airShlWithOverflow(inst), + .mul_carryless => try self.airMulCarryless(inst), + .bit_and, .bool_and => try self.airAnd(inst), .bit_or, .bool_or => try self.airOr(inst), .xor => try self.airXor(inst), @@ -8900,6 +8902,17 @@ pub const FuncGen = struct { return self.builder.buildVectorSplat(len, scalar, ""); } + fn airMulCarryless(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value { + if (self.liveness.isUnused(inst)) return null; + + const ty_pl = self.air.instructions.items(.data)[inst].ty_pl; + const extra = self.air.extraData(Air.MulCarryless, ty_pl.payload).data; + const a = try self.resolveInst(extra.a); + const b = try self.resolveInst(extra.b); + const imm = try self.resolveInst(extra.imm); + return self.builder.buildMulcl(a, b, imm, ""); + } + fn airSelect(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value { if (self.liveness.isUnused(inst)) return null; diff --git a/src/codegen/llvm/bindings.zig b/src/codegen/llvm/bindings.zig index 90d0f51c7b36..f9677d3ce3e3 100644 --- a/src/codegen/llvm/bindings.zig +++ b/src/codegen/llvm/bindings.zig @@ -736,6 +736,15 @@ pub const Builder = opaque { Name: [*:0]const u8, ) *Value; + pub const buildMulcl = ZigLLVMBuildMulcl; + extern fn ZigLLVMBuildMulcl( + *Builder, + A: *Value, + B: *Value, + IMM: *Value, + Name: [*:0]const u8, + ) *Value; + pub const buildPtrToInt = LLVMBuildPtrToInt; extern fn LLVMBuildPtrToInt( *Builder, diff --git a/src/print_air.zig b/src/print_air.zig index 671f781e5ed7..6e43a64b6dc6 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -306,6 +306,7 @@ const Writer = struct { .shuffle => try w.writeShuffle(s, inst), .reduce, .reduce_optimized => try w.writeReduce(s, inst), .cmp_vector, .cmp_vector_optimized => try w.writeCmpVector(s, inst), + .mul_carryless => try w.writeMulCarryless(s, inst), .dbg_block_begin, .dbg_block_end => {}, } @@ -461,6 +462,20 @@ const Writer = struct { try w.writeOperand(s, inst, 2, extra.rhs); } + fn writeMulCarryless(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { + const ty_pl = w.air.instructions.items(.data)[inst].ty_pl; + const extra = w.air.extraData(Air.MulCarryless, ty_pl.payload).data; + + const ty = w.air.getRefType(ty_pl.ty); + try w.writeType(s, ty); + try s.writeAll(", "); + try w.writeOperand(s, inst, 1, extra.a); + try s.writeAll(", "); + try w.writeOperand(s, inst, 2, extra.b); + try s.writeAll(", "); + try w.writeOperand(s, inst, 3, extra.imm); + } + fn writeReduce(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { const reduce = w.air.instructions.items(.data)[inst].reduce; diff --git a/src/print_zir.zig b/src/print_zir.zig index 542f0e977d6e..87fc3810c504 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -539,6 +539,7 @@ const Writer = struct { }, .builtin_async_call => try self.writeBuiltinAsyncCall(stream, extended), .cmpxchg => try self.writeCmpxchg(stream, extended), + .mul_carryless => try self.writeMulCarryless(stream, extended), } } @@ -1161,6 +1162,19 @@ const Writer = struct { try self.writeSrc(stream, src); } + fn writeMulCarryless(self: *Writer, stream: anytype, extended: Zir.Inst.Extended.InstData) !void { + const extra = self.code.extraData(Zir.Inst.MulCarryless, extended.operand).data; + const src = LazySrcLoc.nodeOffset(extra.node); + + try self.writeInstRef(stream, extra.a); + try stream.writeAll(", "); + try self.writeInstRef(stream, extra.b); + try stream.writeAll(", "); + try self.writeInstRef(stream, extra.imm); + try stream.writeAll(")) "); + try self.writeSrc(stream, src); + } + fn writeCall(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void { const inst_data = self.code.instructions.items(.data)[inst].pl_node; const extra = self.code.extraData(Zir.Inst.Call, inst_data.payload_index); diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index c38e311f6790..e4251d12533c 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -555,6 +556,13 @@ LLVMValueRef LLVMBuildVectorSplat(LLVMBuilderRef B, unsigned elem_count, LLVMVal return wrap(unwrap(B)->CreateVectorSplat(elem_count, unwrap(V), Name)); } +LLVMValueRef ZigLLVMBuildMulcl(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, LLVMValueRef IMM, const char *name) { + llvm::Value* values[3] = {unwrap(LHS), unwrap(RHS), unwrap(IMM)}; + + CallInst *call_inst = unwrap(B)->CreateIntrinsic(Intrinsic::X86Intrinsics::x86_pclmulqdq, llvm::None, values, nullptr, name); + return wrap(call_inst); +} + void ZigLLVMFnSetSubprogram(LLVMValueRef fn, ZigLLVMDISubprogram *subprogram) { assert( isa(unwrap(fn)) ); Function *unwrapped_function = reinterpret_cast(unwrap(fn)); diff --git a/src/zig_llvm.h b/src/zig_llvm.h index 7f9bd0a1619d..f001b57af89a 100644 --- a/src/zig_llvm.h +++ b/src/zig_llvm.h @@ -153,6 +153,7 @@ ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUMulFixSat(LLVMBuilderRef B, LLVMValueRef ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUShlSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSShlSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); ZIG_EXTERN_C LLVMValueRef LLVMBuildVectorSplat(LLVMBuilderRef B, unsigned elem_count, LLVMValueRef V, const char *Name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildMulcl(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, LLVMValueRef IMM, const char *name); ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, diff --git a/test/behavior/mul_carryless.zig b/test/behavior/mul_carryless.zig new file mode 100644 index 000000000000..4d2aa561789e --- /dev/null +++ b/test/behavior/mul_carryless.zig @@ -0,0 +1,18 @@ +const std = @import("std"); +const u64x2 = std.meta.Vector(2, u64); + +test "carryless mul" { + const S = struct { + fn doTheTest() !void { + const a = 0b10100010; + const b = 0b10010110; + const expected = @as(u64, 0b101100011101100); + const av: u64x2 = .{ a, 0 }; + const bv: u64x2 = .{ b, 0 }; + const r = @mulCarryless(av, bv, @as(u8, 0)); + try std.testing.expectEqual(expected, r[0]); + } + }; + try S.doTheTest(); + // comptime try S.doTheTest(); +}