Skip to content

Commit afc2f11

Browse files
[Mono] Add SIMD intrinsic for Vector64/128 comparisons (#65128)
* Add vector comparison intrinsics
* Add EqualsAll and EqualsAny intrinsics
* Remove broken EqualsAny
* Fix EqualsAny
* Enable xequal also for arm64
* Fix xzero type
* Fix bad merge
* Add guards for invalid types
* Revert unrelated change
* Extract duplicate code blocks to a new function
* Fix EqualsAny
* Fix typo + code improvements
1 parent 9012b23 commit afc2f11

File tree

3 files changed

+178
-85
lines changed

3 files changed

+178
-85
lines changed

src/mono/mono/mini/mini-llvm.c

Lines changed: 76 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9465,79 +9465,6 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
94659465
values [ins->dreg] = LLVMBuildSExt (builder, cmp, LLVMTypeOf (lhs), "");
94669466
break;
94679467
}
9468-
case OP_XEQUAL: {
9469-
LLVMTypeRef t;
9470-
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
9471-
int nelems;
9472-
9473-
#if defined(TARGET_WASM)
9474-
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
9475-
LLVMValueRef val;
9476-
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
9477-
val = LLVMIsNull (lhs) ? rhs : lhs;
9478-
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));
9479-
9480-
IntrinsicId intrins = (IntrinsicId)0;
9481-
switch (nelems) {
9482-
case 16:
9483-
intrins = INTRINS_WASM_ANYTRUE_V16;
9484-
break;
9485-
case 8:
9486-
intrins = INTRINS_WASM_ANYTRUE_V8;
9487-
break;
9488-
case 4:
9489-
intrins = INTRINS_WASM_ANYTRUE_V4;
9490-
break;
9491-
case 2:
9492-
intrins = INTRINS_WASM_ANYTRUE_V2;
9493-
break;
9494-
default:
9495-
g_assert_not_reached ();
9496-
}
9497-
/* res = !wasm.anytrue (val) */
9498-
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
9499-
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
9500-
break;
9501-
}
9502-
#endif
9503-
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));
9504-
9505-
//%c = icmp sgt <16 x i8> %a0, %a1
9506-
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
9507-
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
9508-
else
9509-
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
9510-
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));
9511-
9512-
LLVMTypeRef elemt;
9513-
if (srcelemt == LLVMDoubleType ())
9514-
elemt = LLVMInt64Type ();
9515-
else if (srcelemt == LLVMFloatType ())
9516-
elemt = LLVMInt32Type ();
9517-
else
9518-
elemt = srcelemt;
9519-
9520-
t = LLVMVectorType (elemt, nelems);
9521-
cmp = LLVMBuildSExt (builder, cmp, t, "");
9522-
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
9523-
int half = nelems / 2;
9524-
while (half >= 1) {
9525-
// AND the top and bottom halfes into the bottom half
9526-
for (int i = 0; i < half; ++i)
9527-
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
9528-
for (int i = half; i < nelems; ++i)
9529-
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
9530-
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
9531-
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
9532-
half = half / 2;
9533-
}
9534-
// Extract [0]
9535-
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
9536-
// convert to 0/1
9537-
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
9538-
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
9539-
break;
9540-
}
95419468
case OP_POPCNT32:
95429469
values [ins->dreg] = call_intrins (ctx, INTRINS_CTPOP_I32, &lhs, "");
95439470
break;
@@ -9616,6 +9543,82 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
96169543
}
96179544
#endif
96189545

9546+
#if defined(TARGET_ARM64) || defined(TARGET_X86) || defined(TARGET_AMD64) || defined(TARGET_WASM)
9547+
case OP_XEQUAL: {
9548+
LLVMTypeRef t;
9549+
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
9550+
int nelems;
9551+
9552+
#if defined(TARGET_WASM)
9553+
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
9554+
LLVMValueRef val;
9555+
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
9556+
val = LLVMIsNull (lhs) ? rhs : lhs;
9557+
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));
9558+
9559+
IntrinsicId intrins = (IntrinsicId)0;
9560+
switch (nelems) {
9561+
case 16:
9562+
intrins = INTRINS_WASM_ANYTRUE_V16;
9563+
break;
9564+
case 8:
9565+
intrins = INTRINS_WASM_ANYTRUE_V8;
9566+
break;
9567+
case 4:
9568+
intrins = INTRINS_WASM_ANYTRUE_V4;
9569+
break;
9570+
case 2:
9571+
intrins = INTRINS_WASM_ANYTRUE_V2;
9572+
break;
9573+
default:
9574+
g_assert_not_reached ();
9575+
}
9576+
/* res = !wasm.anytrue (val) */
9577+
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
9578+
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
9579+
break;
9580+
}
9581+
#endif
9582+
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));
9583+
9584+
//%c = icmp sgt <16 x i8> %a0, %a1
9585+
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
9586+
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
9587+
else
9588+
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
9589+
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));
9590+
9591+
LLVMTypeRef elemt;
9592+
if (srcelemt == LLVMDoubleType ())
9593+
elemt = LLVMInt64Type ();
9594+
else if (srcelemt == LLVMFloatType ())
9595+
elemt = LLVMInt32Type ();
9596+
else
9597+
elemt = srcelemt;
9598+
9599+
t = LLVMVectorType (elemt, nelems);
9600+
cmp = LLVMBuildSExt (builder, cmp, t, "");
9601+
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
9602+
int half = nelems / 2;
9603+
while (half >= 1) {
9604+
// AND the top and bottom halfes into the bottom half
9605+
for (int i = 0; i < half; ++i)
9606+
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
9607+
for (int i = half; i < nelems; ++i)
9608+
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
9609+
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
9610+
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
9611+
half = half / 2;
9612+
}
9613+
// Extract [0]
9614+
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
9615+
// convert to 0/1
9616+
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
9617+
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
9618+
break;
9619+
}
9620+
#endif
9621+
96199622
#if defined(TARGET_ARM64)
96209623

96219624
case OP_XOP_I4_I4:

src/mono/mono/mini/simd-intrinsics.c

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,29 @@ emit_xcompare (MonoCompile *cfg, MonoClass *klass, MonoTypeEnum etype, MonoInst
260260
return ins;
261261
}
262262

263+
static MonoInst*
264+
emit_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
265+
{
266+
return emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
267+
}
268+
269+
static MonoInst*
270+
emit_not_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
271+
{
272+
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
273+
int sreg = ins->dreg;
274+
int dreg = alloc_ireg (cfg);
275+
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
276+
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
277+
return ins;
278+
}
279+
280+
static MonoInst*
281+
emit_xzero (MonoCompile *cfg, MonoClass *klass)
282+
{
283+
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
284+
}
285+
263286
static gboolean
264287
is_intrinsics_vector_type (MonoType *vector_type)
265288
{
@@ -492,7 +515,7 @@ emit_vector_create_elementwise (
492515
{
493516
int op = type_to_insert_op (etype);
494517
MonoClass *vklass = mono_class_from_mono_type_internal (vtype);
495-
MonoInst *ins = emit_simd_ins (cfg, vklass, OP_XZERO, -1, -1);
518+
MonoInst *ins = emit_xzero (cfg, vklass);
496519
for (int i = 0; i < fsig->param_count; ++i) {
497520
ins = emit_simd_ins (cfg, vklass, op, ins->dreg, args [i]->dreg);
498521
ins->inst_c0 = i;
@@ -590,10 +613,17 @@ static guint16 sri_vector_methods [] = {
590613
SN_CreateScalar,
591614
SN_CreateScalarUnsafe,
592615
SN_Divide,
616+
SN_Equals,
617+
SN_EqualsAll,
618+
SN_EqualsAny,
593619
SN_Floor,
594620
SN_GetElement,
595621
SN_GetLower,
596622
SN_GetUpper,
623+
SN_GreaterThan,
624+
SN_GreaterThanOrEqual,
625+
SN_LessThan,
626+
SN_LessThanOrEqual,
597627
SN_Max,
598628
SN_Min,
599629
SN_Multiply,
@@ -788,6 +818,27 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
788818
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR, -1, arg0_type, fsig, args);
789819
case SN_CreateScalarUnsafe:
790820
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR_UNSAFE, -1, arg0_type, fsig, args);
821+
case SN_Equals:
822+
case SN_EqualsAll:
823+
case SN_EqualsAny: {
824+
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
825+
if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
826+
return NULL;
827+
828+
switch (id) {
829+
case SN_Equals:
830+
return emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
831+
case SN_EqualsAll:
832+
return emit_xequal (cfg, klass, args [0], args [1]);
833+
case SN_EqualsAny: {
834+
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
835+
MonoInst *cmp_eq = emit_xcompare (cfg, arg_class, arg0_type, args [0], args [1]);
836+
MonoInst *zero = emit_xzero (cfg, arg_class);
837+
return emit_not_xequal (cfg, arg_class, cmp_eq, zero);
838+
}
839+
default: g_assert_not_reached ();
840+
}
841+
}
791842
case SN_GetElement: {
792843
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
793844
MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0];
@@ -809,6 +860,34 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
809860
int op = id == SN_GetLower ? OP_XLOWER : OP_XUPPER;
810861
return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
811862
}
863+
case SN_GreaterThan:
864+
case SN_GreaterThanOrEqual:
865+
case SN_LessThan:
866+
case SN_LessThanOrEqual: {
867+
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
868+
if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
869+
return NULL;
870+
871+
gboolean is_unsigned = type_is_unsigned (fsig->params [0]);
872+
MonoInst *ins = emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
873+
switch (id) {
874+
case SN_GreaterThan:
875+
ins->inst_c0 = is_unsigned ? CMP_GT_UN : CMP_GT;
876+
break;
877+
case SN_GreaterThanOrEqual:
878+
ins->inst_c0 = is_unsigned ? CMP_GE_UN : CMP_GE;
879+
break;
880+
case SN_LessThan:
881+
ins->inst_c0 = is_unsigned ? CMP_LT_UN : CMP_LT;
882+
break;
883+
case SN_LessThanOrEqual:
884+
ins->inst_c0 = is_unsigned ? CMP_LE_UN : CMP_LE;
885+
break;
886+
default:
887+
g_assert_not_reached ();
888+
}
889+
return ins;
890+
}
812891
case SN_Negate:
813892
case SN_OnesComplement: {
814893
#ifdef TARGET_ARM64
@@ -879,6 +958,8 @@ static guint16 vector64_vector128_t_methods [] = {
879958
SN_get_Count,
880959
SN_get_IsSupported,
881960
SN_get_Zero,
961+
SN_op_Equality,
962+
SN_op_Inequality,
882963
};
883964

884965
static MonoInst*
@@ -928,10 +1009,10 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
9281009
return ins;
9291010
}
9301011
case SN_get_Zero: {
931-
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
1012+
return emit_xzero (cfg, klass);
9321013
}
9331014
case SN_get_AllBitsSet: {
934-
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
1015+
MonoInst *ins = emit_xzero (cfg, klass);
9351016
return emit_xcompare (cfg, klass, etype->type, ins, ins);
9361017
}
9371018
case SN_Equals: {
@@ -941,6 +1022,16 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
9411022
}
9421023
break;
9431024
}
1025+
case SN_op_Equality:
1026+
case SN_op_Inequality:
1027+
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
1028+
mono_metadata_type_equal (fsig->params [0], type) &&
1029+
mono_metadata_type_equal (fsig->params [1], type));
1030+
switch (id) {
1031+
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
1032+
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
1033+
default: g_assert_not_reached ();
1034+
}
9441035
default:
9451036
break;
9461037
}
@@ -1086,7 +1177,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
10861177
return ins;
10871178
case SN_get_Zero:
10881179
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
1089-
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
1180+
return emit_xzero (cfg, klass);
10901181
case SN_get_One: {
10911182
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
10921183
MonoInst *one = NULL;
@@ -1115,7 +1206,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
11151206
}
11161207
case SN_get_AllBitsSet: {
11171208
/* Compare a zero vector with itself */
1118-
ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
1209+
ins = emit_xzero (cfg, klass);
11191210
return emit_xcompare (cfg, klass, etype->type, ins, ins);
11201211
}
11211212
case SN_get_Item: {
@@ -1222,14 +1313,11 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
12221313
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
12231314
mono_metadata_type_equal (fsig->params [0], type) &&
12241315
mono_metadata_type_equal (fsig->params [1], type));
1225-
ins = emit_simd_ins (cfg, klass, OP_XEQUAL, args [0]->dreg, args [1]->dreg);
1226-
if (id == SN_op_Inequality) {
1227-
int sreg = ins->dreg;
1228-
int dreg = alloc_ireg (cfg);
1229-
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
1230-
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
1316+
switch (id) {
1317+
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
1318+
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
1319+
default: g_assert_not_reached ();
12311320
}
1232-
return ins;
12331321
case SN_GreaterThan:
12341322
case SN_GreaterThanOrEqual:
12351323
case SN_LessThan:

src/mono/mono/mini/simd-methods.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ METHOD(Create)
6262
METHOD(CreateScalar)
6363
METHOD(CreateScalarUnsafe)
6464
METHOD(ConditionalSelect)
65+
METHOD(EqualsAll)
66+
METHOD(EqualsAny)
6567
METHOD(GetElement)
6668
METHOD(GetLower)
6769
METHOD(GetUpper)

0 commit comments

Comments
 (0)