Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mono] Add SIMD intrinsic for Vector64/128 comparisons #65128

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 76 additions & 73 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -9478,79 +9478,6 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
values [ins->dreg] = LLVMBuildSExt (builder, cmp, LLVMTypeOf (lhs), "");
break;
}
case OP_XEQUAL: {
LLVMTypeRef t;
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
int nelems;

#if defined(TARGET_WASM)
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
LLVMValueRef val;
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
val = LLVMIsNull (lhs) ? rhs : lhs;
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));

IntrinsicId intrins = (IntrinsicId)0;
switch (nelems) {
case 16:
intrins = INTRINS_WASM_ANYTRUE_V16;
break;
case 8:
intrins = INTRINS_WASM_ANYTRUE_V8;
break;
case 4:
intrins = INTRINS_WASM_ANYTRUE_V4;
break;
case 2:
intrins = INTRINS_WASM_ANYTRUE_V2;
break;
default:
g_assert_not_reached ();
}
/* res = !wasm.anytrue (val) */
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
break;
}
#endif
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));

//%c = icmp sgt <16 x i8> %a0, %a1
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
else
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));

LLVMTypeRef elemt;
if (srcelemt == LLVMDoubleType ())
elemt = LLVMInt64Type ();
else if (srcelemt == LLVMFloatType ())
elemt = LLVMInt32Type ();
else
elemt = srcelemt;

t = LLVMVectorType (elemt, nelems);
cmp = LLVMBuildSExt (builder, cmp, t, "");
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
int half = nelems / 2;
while (half >= 1) {
// AND the top and bottom halfes into the bottom half
for (int i = 0; i < half; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
for (int i = half; i < nelems; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
half = half / 2;
}
// Extract [0]
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
// convert to 0/1
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
break;
}
case OP_POPCNT32:
values [ins->dreg] = call_intrins (ctx, INTRINS_CTPOP_I32, &lhs, "");
break;
Expand Down Expand Up @@ -9629,6 +9556,82 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
}
#endif

#if defined(TARGET_ARM64) || defined(TARGET_X86) || defined(TARGET_AMD64) || defined(TARGET_WASM)
case OP_XEQUAL: {
LLVMTypeRef t;
LLVMValueRef cmp, mask [MAX_VECTOR_ELEMS], shuffle;
int nelems;

#if defined(TARGET_WASM)
/* The wasm code generator doesn't understand the shuffle/and code sequence below */
LLVMValueRef val;
if (LLVMIsNull (lhs) || LLVMIsNull (rhs)) {
val = LLVMIsNull (lhs) ? rhs : lhs;
nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));

IntrinsicId intrins = (IntrinsicId)0;
switch (nelems) {
case 16:
intrins = INTRINS_WASM_ANYTRUE_V16;
break;
case 8:
intrins = INTRINS_WASM_ANYTRUE_V8;
break;
case 4:
intrins = INTRINS_WASM_ANYTRUE_V4;
break;
case 2:
intrins = INTRINS_WASM_ANYTRUE_V2;
break;
default:
g_assert_not_reached ();
}
/* res = !wasm.anytrue (val) */
values [ins->dreg] = call_intrins (ctx, intrins, &val, "");
values [ins->dreg] = LLVMBuildZExt (builder, LLVMBuildICmp (builder, LLVMIntEQ, values [ins->dreg], LLVMConstInt (LLVMInt32Type (), 0, FALSE), ""), LLVMInt32Type (), dname);
break;
}
#endif
LLVMTypeRef srcelemt = LLVMGetElementType (LLVMTypeOf (lhs));

//%c = icmp sgt <16 x i8> %a0, %a1
if (srcelemt == LLVMDoubleType () || srcelemt == LLVMFloatType ())
cmp = LLVMBuildFCmp (builder, LLVMRealOEQ, lhs, rhs, "");
else
cmp = LLVMBuildICmp (builder, LLVMIntEQ, lhs, rhs, "");
nelems = LLVMGetVectorSize (LLVMTypeOf (cmp));

LLVMTypeRef elemt;
if (srcelemt == LLVMDoubleType ())
elemt = LLVMInt64Type ();
else if (srcelemt == LLVMFloatType ())
elemt = LLVMInt32Type ();
else
elemt = srcelemt;

t = LLVMVectorType (elemt, nelems);
cmp = LLVMBuildSExt (builder, cmp, t, "");
// cmp is a <nelems x elemt> vector, each element is either 0xff... or 0
int half = nelems / 2;
while (half >= 1) {
// AND the top and bottom halfes into the bottom half
for (int i = 0; i < half; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), half + i, FALSE);
for (int i = half; i < nelems; ++i)
mask [i] = LLVMConstInt (LLVMInt32Type (), 0, FALSE);
shuffle = LLVMBuildShuffleVector (builder, cmp, LLVMGetUndef (t), LLVMConstVector (mask, LLVMGetVectorSize (t)), "");
cmp = LLVMBuildAnd (builder, cmp, shuffle, "");
half = half / 2;
}
// Extract [0]
LLVMValueRef first_elem = LLVMBuildExtractElement (builder, cmp, LLVMConstInt (LLVMInt32Type (), 0, FALSE), "");
// convert to 0/1
LLVMValueRef cmp_zero = LLVMBuildICmp (builder, LLVMIntNE, first_elem, LLVMConstInt (elemt, 0, FALSE), "");
values [ins->dreg] = LLVMBuildZExt (builder, cmp_zero, LLVMInt8Type (), "");
break;
}
#endif

#if defined(TARGET_ARM64)

case OP_XOP_I4_I4:
Expand Down
112 changes: 100 additions & 12 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,29 @@ emit_xcompare (MonoCompile *cfg, MonoClass *klass, MonoTypeEnum etype, MonoInst
return ins;
}

static MonoInst*
emit_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
{
return emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
}

static MonoInst*
emit_not_xequal (MonoCompile *cfg, MonoClass *klass, MonoInst *arg1, MonoInst *arg2)
{
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XEQUAL, arg1->dreg, arg2->dreg);
int sreg = ins->dreg;
int dreg = alloc_ireg (cfg);
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
return ins;
}

static MonoInst*
emit_xzero (MonoCompile *cfg, MonoClass *klass)
{
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
}

static gboolean
is_intrinsics_vector_type (MonoType *vector_type)
{
Expand Down Expand Up @@ -492,7 +515,7 @@ emit_vector_create_elementwise (
{
int op = type_to_insert_op (etype);
MonoClass *vklass = mono_class_from_mono_type_internal (vtype);
MonoInst *ins = emit_simd_ins (cfg, vklass, OP_XZERO, -1, -1);
MonoInst *ins = emit_xzero (cfg, vklass);
for (int i = 0; i < fsig->param_count; ++i) {
ins = emit_simd_ins (cfg, vklass, op, ins->dreg, args [i]->dreg);
ins->inst_c0 = i;
Expand Down Expand Up @@ -590,10 +613,17 @@ static guint16 sri_vector_methods [] = {
SN_CreateScalar,
SN_CreateScalarUnsafe,
SN_Divide,
SN_Equals,
SN_EqualsAll,
SN_EqualsAny,
SN_Floor,
SN_GetElement,
SN_GetLower,
SN_GetUpper,
SN_GreaterThan,
SN_GreaterThanOrEqual,
SN_LessThan,
SN_LessThanOrEqual,
SN_Max,
SN_Min,
SN_Multiply,
Expand Down Expand Up @@ -788,6 +818,27 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR, -1, arg0_type, fsig, args);
case SN_CreateScalarUnsafe:
return emit_simd_ins_for_sig (cfg, klass, OP_CREATE_SCALAR_UNSAFE, -1, arg0_type, fsig, args);
case SN_Equals:
case SN_EqualsAll:
case SN_EqualsAny: {
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should start to add comprehensive vector element type checks for each case. (Both input args and the return.) I have a PR up to refactor the code of type checks for this function. (#65486) You could merge this PR like this and fix it with your next PR.

if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
return NULL;

switch (id) {
case SN_Equals:
return emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
case SN_EqualsAll:
return emit_xequal (cfg, klass, args [0], args [1]);
case SN_EqualsAny: {
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
MonoInst *cmp_eq = emit_xcompare (cfg, arg_class, arg0_type, args [0], args [1]);
MonoInst *zero = emit_xzero (cfg, arg_class);
return emit_not_xequal (cfg, arg_class, cmp_eq, zero);
}
default: g_assert_not_reached ();
}
}
case SN_GetElement: {
MonoClass *arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
MonoType *etype = mono_class_get_context (arg_class)->class_inst->type_argv [0];
Expand All @@ -809,6 +860,34 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
int op = id == SN_GetLower ? OP_XLOWER : OP_XUPPER;
return emit_simd_ins_for_sig (cfg, klass, op, 0, arg0_type, fsig, args);
}
case SN_GreaterThan:
case SN_GreaterThanOrEqual:
case SN_LessThan:
case SN_LessThanOrEqual: {
MonoType *arg_type = get_vector_t_elem_type (fsig->params [0]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

if (!MONO_TYPE_IS_INTRINSICS_VECTOR_PRIMITIVE (arg_type))
return NULL;

gboolean is_unsigned = type_is_unsigned (fsig->params [0]);
MonoInst *ins = emit_xcompare (cfg, klass, arg0_type, args [0], args [1]);
switch (id) {
case SN_GreaterThan:
ins->inst_c0 = is_unsigned ? CMP_GT_UN : CMP_GT;
break;
case SN_GreaterThanOrEqual:
ins->inst_c0 = is_unsigned ? CMP_GE_UN : CMP_GE;
break;
case SN_LessThan:
ins->inst_c0 = is_unsigned ? CMP_LT_UN : CMP_LT;
break;
case SN_LessThanOrEqual:
ins->inst_c0 = is_unsigned ? CMP_LE_UN : CMP_LE;
break;
default:
g_assert_not_reached ();
}
return ins;
}
case SN_Negate:
case SN_OnesComplement: {
#ifdef TARGET_ARM64
Expand Down Expand Up @@ -879,6 +958,8 @@ static guint16 vector64_vector128_t_methods [] = {
SN_get_Count,
SN_get_IsSupported,
SN_get_Zero,
SN_op_Equality,
SN_op_Inequality,
};

static MonoInst*
Expand Down Expand Up @@ -928,10 +1009,10 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
return ins;
}
case SN_get_Zero: {
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
return emit_xzero (cfg, klass);
}
case SN_get_AllBitsSet: {
MonoInst *ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
MonoInst *ins = emit_xzero (cfg, klass);
return emit_xcompare (cfg, klass, etype->type, ins, ins);
}
case SN_Equals: {
Expand All @@ -941,6 +1022,16 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
}
break;
}
case SN_op_Equality:
case SN_op_Inequality:
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
mono_metadata_type_equal (fsig->params [0], type) &&
mono_metadata_type_equal (fsig->params [1], type));
switch (id) {
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
default: g_assert_not_reached ();
}
default:
break;
}
Expand Down Expand Up @@ -1086,7 +1177,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
return ins;
case SN_get_Zero:
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
return emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
return emit_xzero (cfg, klass);
case SN_get_One: {
g_assert (fsig->param_count == 0 && mono_metadata_type_equal (fsig->ret, type));
MonoInst *one = NULL;
Expand Down Expand Up @@ -1115,7 +1206,7 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
}
case SN_get_AllBitsSet: {
/* Compare a zero vector with itself */
ins = emit_simd_ins (cfg, klass, OP_XZERO, -1, -1);
ins = emit_xzero (cfg, klass);
return emit_xcompare (cfg, klass, etype->type, ins, ins);
}
case SN_get_Item: {
Expand Down Expand Up @@ -1222,14 +1313,11 @@ emit_sys_numerics_vector_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSig
g_assert (fsig->param_count == 2 && fsig->ret->type == MONO_TYPE_BOOLEAN &&
mono_metadata_type_equal (fsig->params [0], type) &&
mono_metadata_type_equal (fsig->params [1], type));
ins = emit_simd_ins (cfg, klass, OP_XEQUAL, args [0]->dreg, args [1]->dreg);
if (id == SN_op_Inequality) {
int sreg = ins->dreg;
int dreg = alloc_ireg (cfg);
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, sreg, 0);
EMIT_NEW_UNALU (cfg, ins, OP_CEQ, dreg, -1);
switch (id) {
case SN_op_Equality: return emit_xequal (cfg, klass, args [0], args [1]);
case SN_op_Inequality: return emit_not_xequal (cfg, klass, args [0], args [1]);
default: g_assert_not_reached ();
}
return ins;
case SN_GreaterThan:
case SN_GreaterThanOrEqual:
case SN_LessThan:
Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/mini/simd-methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ METHOD(Create)
METHOD(CreateScalar)
METHOD(CreateScalarUnsafe)
METHOD(ConditionalSelect)
METHOD(EqualsAll)
METHOD(EqualsAny)
METHOD(GetElement)
METHOD(GetLower)
METHOD(GetUpper)
Expand Down