Skip to content

stage1: Implement @reduce builtin for vector types #6558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/std/builtin.zig
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ pub const AtomicOrder = enum {
SeqCst,
};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const ReduceOp = enum {
And,
Or,
Xor,
Min,
Max,
Comment on lines +104 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lower case tag names for new enums please. I normally don't want to care about style stuff but I also don't want to have to make unnecessary breaking changes later.

};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const AtomicRmwOp = enum {
Expand Down
26 changes: 26 additions & 0 deletions src/stage1/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,7 @@ enum BuiltinFnId {
BuiltinFnIdWasmMemorySize,
BuiltinFnIdWasmMemoryGrow,
BuiltinFnIdSrc,
BuiltinFnIdReduce,
};

struct BuiltinFnEntry {
Expand Down Expand Up @@ -2441,6 +2442,15 @@ enum AtomicOrder {
AtomicOrderSeqCst,
};

// Must stay synchronized with the code in define_builtin_compile_vars.
// The enumerators are listed in the same order as the tags of `ReduceOp` in
// lib/std/builtin.zig (And, Or, Xor, Min, Max); the numeric values are written
// out explicitly so that the correspondence is visible at a glance.
enum ReduceOp {
    ReduceOp_and = 0,
    ReduceOp_or = 1,
    ReduceOp_xor = 2,
    ReduceOp_min = 3,
    ReduceOp_max = 4,
};

// synchronized with the code in define_builtin_compile_vars
enum AtomicRmwOp {
AtomicRmwOp_xchg,
Expand Down Expand Up @@ -2550,6 +2560,7 @@ enum IrInstSrcId {
IrInstSrcIdEmbedFile,
IrInstSrcIdCmpxchg,
IrInstSrcIdFence,
IrInstSrcIdReduce,
IrInstSrcIdTruncate,
IrInstSrcIdIntCast,
IrInstSrcIdFloatCast,
Expand Down Expand Up @@ -2672,6 +2683,7 @@ enum IrInstGenId {
IrInstGenIdErrName,
IrInstGenIdCmpxchg,
IrInstGenIdFence,
IrInstGenIdReduce,
IrInstGenIdTruncate,
IrInstGenIdShuffleVector,
IrInstGenIdSplat,
Expand Down Expand Up @@ -3521,6 +3533,20 @@ struct IrInstGenFence {
AtomicOrder order;
};

// Source-side IR instruction for the `reduce` builtin.
// Both operands are still IrInstSrc instructions here; the operator only
// becomes a concrete ReduceOp value in the IrInstGenReduce counterpart.
// NOTE(review): presumably `op` and `value` map to the builtin's first and
// second argument respectively -- confirm against the IR builder.
struct IrInstSrcReduce {
// Common instruction header, first member as in the neighbouring IrInstSrc* structs.
IrInstSrc base;

// Instruction that yields the reduction operator.
IrInstSrc *op;
// Instruction that yields the operand to be reduced.
IrInstSrc *value;
};

// Gen-side IR instruction for the `reduce` builtin; this is what
// ir_render_reduce consumes when emitting LLVM IR.
struct IrInstGenReduce {
// Common instruction header. ir_render_instruction casts an IrInstGen * to
// IrInstGenReduce *, so this must remain the first member.
IrInstGen base;

// Reduction operator, already resolved to a concrete enum value.
ReduceOp op;
// Operand to reduce; ir_render_reduce asserts that its type is a vector.
IrInstGen *value;
};

struct IrInstSrcTruncate {
IrInstSrc base;

Expand Down
95 changes: 56 additions & 39 deletions src/stage1/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}

// Selects how scalarize_cmp_result folds a vector of i1 values into a
// single i1.
enum class ScalarizePredicate {
// Returns true iff all the elements in the vector are 1.
// Equivalent to folding all the bits with `and`.
All,
// Returns true iff there's at least one element in the vector that is 1.
// Equivalent to folding all the bits with `or`.
Any,
};

// Folds a <N x i1> vector into one i1 according to `predicate`.
//
// The vector is bit-cast to a single N-bit integer, which is then compared
// against a constant:
//   Any -> result is true when the integer is not zero
//   All -> result is true when the integer has every bit set
//
// `val` must have an LLVM vector type (asserted).
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
    LLVMTypeRef vector_type = LLVMTypeOf(val);
    assert(LLVMGetTypeKind(vector_type) == LLVMVectorTypeKind);

    // One bit per lane: reinterpret the whole vector as an iN scalar.
    LLVMTypeRef packed_type = LLVMIntType(LLVMGetVectorSize(vector_type));
    LLVMValueRef packed = LLVMBuildBitCast(g->builder, val, packed_type, "");

    if (predicate == ScalarizePredicate::Any) {
        LLVMValueRef no_bits_set = LLVMConstNull(packed_type);
        return LLVMBuildICmp(g->builder, LLVMIntNE, packed, no_bits_set, "");
    }
    if (predicate == ScalarizePredicate::All) {
        LLVMValueRef every_bit_set = LLVMConstAllOnes(packed_type);
        return LLVMBuildICmp(g->builder, LLVMIntEQ, packed, every_bit_set, "");
    }

    zig_unreachable();
}


static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMValueRef val1, LLVMValueRef val2)
{
Expand All @@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

Expand Down Expand Up @@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

Expand Down Expand Up @@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}

if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}

LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
Expand All @@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);

Expand All @@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

Expand All @@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);

Expand Down Expand Up @@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

Expand Down Expand Up @@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}

if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}

LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
Expand Down Expand Up @@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
if (rhs_type->id == ZigTypeIdVector) {
less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
}
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);

Expand Down Expand Up @@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
return result_loc;
}

// Emits LLVM IR for a reduce instruction and returns the resulting value.
//
// The operand must be a vector (asserted). The actual reduction is delegated
// to the ZigLLVMBuild*Reduce helper matching `instruction->op`:
//   and / or / xor -> element type must be an integer or bool (asserted)
//   min / max      -> integer elements use the Int helper, with signedness
//                     taken from the element type; float elements use the FP
//                     helper; anything else is unreachable
// An operator outside the ReduceOp enumerators is unreachable as well.
static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
    LLVMValueRef operand = ir_llvm_value(g, instruction->value);

    ZigType *operand_type = instruction->value->value->type;
    assert(operand_type->id == ZigTypeIdVector);
    ZigType *elem_type = operand_type->data.vector.elem_type;

    switch (instruction->op) {
        case ReduceOp_and:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildAndReduce(g->builder, operand);

        case ReduceOp_or:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildOrReduce(g->builder, operand);

        case ReduceOp_xor:
            assert(elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdBool);
            return ZigLLVMBuildXorReduce(g->builder, operand);

        case ReduceOp_min:
            if (elem_type->id == ZigTypeIdInt) {
                return ZigLLVMBuildIntMinReduce(g->builder, operand, elem_type->data.integral.is_signed);
            }
            if (elem_type->id == ZigTypeIdFloat) {
                return ZigLLVMBuildFPMinReduce(g->builder, operand);
            }
            zig_unreachable();

        case ReduceOp_max:
            if (elem_type->id == ZigTypeIdInt) {
                return ZigLLVMBuildIntMaxReduce(g->builder, operand, elem_type->data.integral.is_signed);
            }
            if (elem_type->id == ZigTypeIdFloat) {
                return ZigLLVMBuildFPMaxReduce(g->builder, operand);
            }
            zig_unreachable();
    }

    // Reached only for an operator value outside the enumerators above.
    zig_unreachable();
}

static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
LLVMBuildFence(g->builder, atomic_order, false, "");
Expand Down Expand Up @@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
case IrInstGenIdFence:
return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
case IrInstGenIdReduce:
return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
case IrInstGenIdTruncate:
return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
case IrInstGenIdBoolNot:
Expand Down Expand Up @@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
}

static const char *bool_to_str(bool b) {
Expand Down
Loading