From 2d751cac7311b344c237df3b3c63b33434b52217 Mon Sep 17 00:00:00 2001 From: Ruihan-Yin <107431934+Ruihan-Yin@users.noreply.github.com> Date: Fri, 19 Jan 2024 09:39:55 -0800 Subject: [PATCH] Enable EVEX feature: Embedded Rounding for Avx512F.Add() (#94684) * some workaround with embedded rounding in compiler backend. * extend _idEvexbContext to 2bit to distinguish embedded broadcast and embedded rounding * Expose APIs with rounding mode. * Apply format patch * Do not include the third parameter in Avx512.Add(left, right) * split _idEvexbContext bits and made an explicit convert function from uint8_t to insOpts for embedded rounding mode. * Remove unexpected comment-out * Fix unexpected deletion * resolve comments: removed redundant bits in instDesc for EVEX.b context. Introduced `emitDispEmbRounding` to display the embedded rounding feature in the disassembly. * bug fix: fix unneeded assertion check. * Apply format patch. * Resolve comments: merge INS_OPTS_EVEX_b and INS_OPTS_EVEX_er_rd Do a pre-check for embedded rounding before lowering. * Add a helper function to generalize the logic when lowering the embedded rounding intrinsics. * Resolve comments: 1. fix typo in comments 2. Add SetEvexBroadcastIfNeeded 3. Added mask in insOpts * 1. Add unit case for non-default rounding mode 2. removed round-to-even, the default option from InsOpts as it will be handled on the default path. * formatting * 1. Create a fallback jump table for embedded rounding APIs when control byte is not constant. 2. Create a template to generate the unit tests for embedded rounding APIs. 3. nit: fix naming. * remove hand-written unit tests for embedded rounding. * formatting * Resolve comments. * formatting * revert changes: let SetEmbRoundingMode accept unexpected values to accommodate the jump table generation logics. 
--- src/coreclr/jit/codegen.h | 1 + src/coreclr/jit/emit.h | 53 ++- src/coreclr/jit/emitxarch.cpp | 128 +++++-- src/coreclr/jit/emitxarch.h | 37 +- src/coreclr/jit/gentree.h | 44 +++ src/coreclr/jit/hwintrinsic.h | 26 +- src/coreclr/jit/hwintrinsiccodegenxarch.cpp | 59 ++++ src/coreclr/jit/hwintrinsiclistxarch.h | 2 +- src/coreclr/jit/hwintrinsicxarch.cpp | 10 + src/coreclr/jit/instr.cpp | 4 +- src/coreclr/jit/instr.h | 10 +- src/coreclr/jit/lowerxarch.cpp | 62 ++++ src/coreclr/jit/lsraxarch.cpp | 7 + .../X86/Avx512F.PlatformNotSupported.cs | 5 + .../System/Runtime/Intrinsics/X86/Avx512F.cs | 5 + .../System/Runtime/Intrinsics/X86/Enums.cs | 20 ++ .../ref/System.Runtime.Intrinsics.cs | 8 + .../GenerateHWIntrinsicTests_X86.cs | 9 + .../Shared/SimpleBinOpEmbRounding.template | 315 ++++++++++++++++++ 19 files changed, 760 insertions(+), 45 deletions(-) create mode 100644 src/tests/JIT/HardwareIntrinsics/X86/Shared/SimpleBinOpEmbRounding.template diff --git a/src/coreclr/jit/codegen.h b/src/coreclr/jit/codegen.h index 22f4e58b094a2c..e1af485a0eb4f6 100644 --- a/src/coreclr/jit/codegen.h +++ b/src/coreclr/jit/codegen.h @@ -971,6 +971,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX void genHWIntrinsic_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp); void genHWIntrinsic_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival); void genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr); + void genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival); void genHWIntrinsic_R_R_RM( GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, GenTree* op2); void genHWIntrinsic_R_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival); diff --git a/src/coreclr/jit/emit.h b/src/coreclr/jit/emit.h index 51ed6b72a0c520..08d91386d8bb22 100644 --- 
a/src/coreclr/jit/emit.h +++ b/src/coreclr/jit/emit.h @@ -774,8 +774,12 @@ class emitter unsigned _idCallAddr : 1; // IL indirect calls: can make a direct call to iiaAddr unsigned _idNoGC : 1; // Some helpers don't get recorded in GC tables #if defined(TARGET_XARCH) - unsigned _idEvexbContext : 1; // does EVEX.b need to be set. -#endif // TARGET_XARCH + // EVEX.b can indicate several context: embedded broadcast, embedded rounding. + // For normal and embedded broadcast intrinsics, EVEX.L'L has the same semantic, vector length. + // For embedded rounding, EVEX.L'L semantic changes to indicate the rounding mode. + // Multiple bits in _idEvexbContext are used to inform emitter to specially handle the EVEX.L'L bits. + unsigned _idEvexbContext : 2; +#endif // TARGET_XARCH #ifdef TARGET_ARM64 @@ -808,8 +812,8 @@ class emitter //////////////////////////////////////////////////////////////////////// // Space taken up to here: - // x86: 47 bits - // amd64: 47 bits + // x86: 48 bits + // amd64: 48 bits // arm: 48 bits // arm64: 53 bits // loongarch64: 46 bits @@ -828,7 +832,7 @@ class emitter #elif defined(TARGET_LOONGARCH64) || defined(TARGET_RISCV64) #define ID_EXTRA_BITFIELD_BITS (14) #elif defined(TARGET_XARCH) -#define ID_EXTRA_BITFIELD_BITS (15) +#define ID_EXTRA_BITFIELD_BITS (16) #else #error Unsupported or unset target architecture #endif @@ -863,8 +867,8 @@ class emitter //////////////////////////////////////////////////////////////////////// // Space taken up to here (with/without prev offset, assuming host==target): - // x86: 53/49 bits - // amd64: 54/49 bits + // x86: 54/50 bits + // amd64: 55/50 bits // arm: 54/50 bits // arm64: 60/55 bits // loongarch64: 53/48 bits @@ -880,8 +884,8 @@ class emitter //////////////////////////////////////////////////////////////////////// // Small constant size (with/without prev offset, assuming host==target): - // x86: 11/15 bits - // amd64: 10/15 bits + // x86: 10/14 bits + // amd64: 9/14 bits // arm: 10/14 bits // arm64: 4/9 
bits // loongarch64: 11/16 bits @@ -1578,15 +1582,35 @@ class emitter } #ifdef TARGET_XARCH - bool idIsEvexbContext() const + bool idIsEvexbContextSet() const { return _idEvexbContext != 0; } - void idSetEvexbContext() + + void idSetEvexbContext(insOpts instOptions) { assert(_idEvexbContext == 0); - _idEvexbContext = 1; - assert(_idEvexbContext == 1); + if (instOptions == INS_OPTS_EVEX_eb_er_rd) + { + _idEvexbContext = 1; + } + else if (instOptions == INS_OPTS_EVEX_er_ru) + { + _idEvexbContext = 2; + } + else if (instOptions == INS_OPTS_EVEX_er_rz) + { + _idEvexbContext = 3; + } + else + { + unreached(); + } + } + + unsigned idGetEvexbContext() const + { + return _idEvexbContext; } #endif @@ -2166,6 +2190,7 @@ class emitter void emitDispInsOffs(unsigned offs, bool doffs); void emitDispInsHex(instrDesc* id, BYTE* code, size_t sz); void emitDispEmbBroadcastCount(instrDesc* id); + void emitDispEmbRounding(instrDesc* id); void emitDispIns(instrDesc* id, bool isNew, bool doffs, @@ -3814,7 +3839,7 @@ inline unsigned emitter::emitGetInsCIargs(instrDesc* id) // emitAttr emitter::emitGetMemOpSize(instrDesc* id) const { - if (id->idIsEvexbContext()) + if (id->idIsEvexbContextSet()) { // should have the assumption that Evex.b now stands for the embedded broadcast context. // reference: Section 2.7.5 in Intel 64 and ia-32 architectures software developer's manual volume 2. diff --git a/src/coreclr/jit/emitxarch.cpp b/src/coreclr/jit/emitxarch.cpp index 9bcb574578cc6e..c281b2d08c3cb3 100644 --- a/src/coreclr/jit/emitxarch.cpp +++ b/src/coreclr/jit/emitxarch.cpp @@ -1139,6 +1139,30 @@ static bool isLowSimdReg(regNumber reg) #endif } +//------------------------------------------------------------------------ +// GetEmbRoundingMode: Get the rounding mode for embedded rounding +// +// Arguments: +// mode -- the flag from the corresponding GenTree node indicating the mode. +// +// Return Value: +// the instruction option carrying the rounding mode information. 
+// +insOpts emitter::GetEmbRoundingMode(uint8_t mode) const +{ + switch (mode) + { + case 1: + return INS_OPTS_EVEX_eb_er_rd; + case 2: + return INS_OPTS_EVEX_er_ru; + case 3: + return INS_OPTS_EVEX_er_rz; + default: + unreached(); + } +} + //------------------------------------------------------------------------ // encodeRegAsIval: Encodes a register as an ival for use by a SIMD instruction // @@ -1309,18 +1333,50 @@ emitter::code_t emitter::AddEvexPrefix(const instrDesc* id, code_t code, emitAtt if (attr == EA_32BYTE) { - // Set L bit to 1 in case of instructions that operate on 256-bits. + // Set EVEX.L'L bits to 01 in case of instructions that operate on 256-bits. code |= LBIT_IN_BYTE_EVEX_PREFIX; } else if (attr == EA_64BYTE) { - // Set L' bits to 11 in case of instructions that operate on 512-bits. + // Set EVEX.L'L bits to 10 in case of instructions that operate on 512-bits. code |= LPRIMEBIT_IN_BYTE_EVEX_PREFIX; } - if (id->idIsEvexbContext()) + if (id->idIsEvexbContextSet()) { code |= EVEX_B_BIT; + + if (!id->idHasMem()) + { + // embedded rounding case. + unsigned roundingMode = id->idGetEvexbContext(); + if (roundingMode == 1) + { + // {rd-sae} + code &= ~(LPRIMEBIT_IN_BYTE_EVEX_PREFIX); + code |= LBIT_IN_BYTE_EVEX_PREFIX; + } + else if (roundingMode == 2) + { + // {ru-sae} + code |= LPRIMEBIT_IN_BYTE_EVEX_PREFIX; + code &= ~(LBIT_IN_BYTE_EVEX_PREFIX); + } + else if (roundingMode == 3) + { + // {rz-sae} + code |= LPRIMEBIT_IN_BYTE_EVEX_PREFIX; + code |= LBIT_IN_BYTE_EVEX_PREFIX; + } + else + { + unreached(); + } + } + else + { + assert(id->idGetEvexbContext() == 1); + } } regNumber maskReg = REG_NA; @@ -6742,11 +6798,7 @@ void emitter::emitIns_R_R_A( id->idIns(ins); id->idReg1(reg1); id->idReg2(reg2); - if (instOptions == INS_OPTS_EVEX_b) - { - assert(UseEvexEncoding()); - id->idSetEvexbContext(); - } + SetEvexBroadcastIfNeeded(id, instOptions); emitHandleMemOp(indir, id, (ins == INS_mulx) ? 
IF_RWR_RWR_ARD : emitInsModeFormat(ins, IF_RRD_RRD_ARD), ins); @@ -6871,11 +6923,7 @@ void emitter::emitIns_R_R_C(instruction ins, id->idReg1(reg1); id->idReg2(reg2); id->idAddr()->iiaFieldHnd = fldHnd; - if (instOptions == INS_OPTS_EVEX_b) - { - assert(UseEvexEncoding()); - id->idSetEvexbContext(); - } + SetEvexBroadcastIfNeeded(id, instOptions); UNATIVE_OFFSET sz = emitInsSizeCV(id, insCodeRM(ins)); id->idCodeSize(sz); @@ -6889,7 +6937,8 @@ void emitter::emitIns_R_R_C(instruction ins, * Add an instruction with three register operands. */ -void emitter::emitIns_R_R_R(instruction ins, emitAttr attr, regNumber targetReg, regNumber reg1, regNumber reg2) +void emitter::emitIns_R_R_R( + instruction ins, emitAttr attr, regNumber targetReg, regNumber reg1, regNumber reg2, insOpts instOptions) { assert(IsAvx512OrPriorInstruction(ins)); assert(IsThreeOperandAVXInstruction(ins) || IsKInstruction(ins)); @@ -6901,6 +6950,13 @@ void emitter::emitIns_R_R_R(instruction ins, emitAttr attr, regNumber targetReg, id->idReg2(reg1); id->idReg3(reg2); + if ((instOptions & INS_OPTS_b_MASK) != INS_OPTS_NONE) + { + // if EVEX.b needs to be set in this path, then it should be embedded rounding. 
+ assert(UseEvexEncoding()); + id->idSetEvexbContext(instOptions); + } + UNATIVE_OFFSET sz = emitInsSizeRR(id, insCodeRM(ins)); id->idCodeSize(sz); @@ -6921,12 +6977,8 @@ void emitter::emitIns_R_R_S( id->idReg1(reg1); id->idReg2(reg2); id->idAddr()->iiaLclVar.initLclVarAddr(varx, offs); + SetEvexBroadcastIfNeeded(id, instOptions); - if (instOptions == INS_OPTS_EVEX_b) - { - assert(UseEvexEncoding()); - id->idSetEvexbContext(); - } #ifdef DEBUG id->idDebugOnlyInfo()->idVarRefOffs = emitVarRefOffs; #endif @@ -8224,11 +8276,11 @@ void emitter::emitIns_SIMD_R_R_C(instruction ins, // op2Reg -- The register of the second operand // void emitter::emitIns_SIMD_R_R_R( - instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg) + instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, insOpts instOptions) { if (UseSimdEncoding()) { - emitIns_R_R_R(ins, attr, targetReg, op1Reg, op2Reg); + emitIns_R_R_R(ins, attr, targetReg, op1Reg, op2Reg, instOptions); } else { @@ -10656,7 +10708,7 @@ void emitter::emitDispInsHex(instrDesc* id, BYTE* code, size_t sz) // void emitter::emitDispEmbBroadcastCount(instrDesc* id) { - if (!id->idIsEvexbContext()) + if (!id->idIsEvexbContextSet()) { return; } @@ -10665,6 +10717,37 @@ void emitter::emitDispEmbBroadcastCount(instrDesc* id) printf(" {1to%d}", vectorSize / baseSize); } +// emitDispEmbRounding: Display the tag where embedded rounding is activated +// +// Arguments: +// id - The instruction descriptor +// +void emitter::emitDispEmbRounding(instrDesc* id) +{ + if (!id->idIsEvexbContextSet()) + { + return; + } + assert(!id->idHasMem()); + unsigned roundingMode = id->idGetEvexbContext(); + if (roundingMode == 1) + { + printf(" {rd-sae}"); + } + else if (roundingMode == 2) + { + printf(" {ru-sae}"); + } + else if (roundingMode == 3) + { + printf(" {rz-sae}"); + } + else + { + unreached(); + } +} + //-------------------------------------------------------------------- // 
emitDispIns: Dump the given instruction to jitstdout. // @@ -11533,6 +11616,7 @@ void emitter::emitDispIns( printf("%s, ", emitRegName(id->idReg1(), attr)); printf("%s, ", emitRegName(reg2, attr)); printf("%s", emitRegName(reg3, attr)); + emitDispEmbRounding(id); break; } diff --git a/src/coreclr/jit/emitxarch.h b/src/coreclr/jit/emitxarch.h index 82a83dd0f2e53c..05c18bdbff2fae 100644 --- a/src/coreclr/jit/emitxarch.h +++ b/src/coreclr/jit/emitxarch.h @@ -157,6 +157,8 @@ bool IsRedundantCmp(emitAttr size, regNumber reg1, regNumber reg2); bool AreFlagsSetToZeroCmp(regNumber reg, emitAttr opSize, GenCondition cond); bool AreFlagsSetForSignJumpOpt(regNumber reg, emitAttr opSize, GenCondition cond); +insOpts GetEmbRoundingMode(uint8_t mode) const; + bool hasRexPrefix(code_t code) { #ifdef TARGET_AMD64 @@ -335,6 +337,25 @@ code_t AddSimdPrefixIfNeeded(const instrDesc* id, code_t code, emitAttr size) return code; } +//------------------------------------------------------------------------ +// SetEvexBroadcastIfNeeded: set embedded broadcast if needed. +// +// Arguments: +// id - instruction descriptor +// instOptions - emit options +void SetEvexBroadcastIfNeeded(instrDesc* id, insOpts instOptions) +{ + if ((instOptions & INS_OPTS_b_MASK) == INS_OPTS_EVEX_eb_er_rd) + { + assert(UseEvexEncoding()); + id->idSetEvexbContext(instOptions); + } + else + { + assert(instOptions == 0); + } +} + //------------------------------------------------------------------------ // AddSimdPrefixIfNeeded: Add the correct SIMD prefix. // Check if the prefix already exists befpre adding. 
@@ -627,7 +648,12 @@ void emitIns_R_R_S(instruction ins, int offs, insOpts instOptions = INS_OPTS_NONE); -void emitIns_R_R_R(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2, regNumber reg3); +void emitIns_R_R_R(instruction ins, + emitAttr attr, + regNumber reg1, + regNumber reg2, + regNumber reg3, + insOpts instOptions = INS_OPTS_NONE); void emitIns_R_R_A_I( instruction ins, emitAttr attr, regNumber reg1, regNumber reg2, GenTreeIndir* indir, int ival, insFormat fmt); @@ -738,7 +764,12 @@ void emitIns_SIMD_R_R_C(instruction ins, CORINFO_FIELD_HANDLE fldHnd, int offs, insOpts instOptions = INS_OPTS_NONE); -void emitIns_SIMD_R_R_R(instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg); +void emitIns_SIMD_R_R_R(instruction ins, + emitAttr attr, + regNumber targetReg, + regNumber op1Reg, + regNumber op2Reg, + insOpts instOptions = INS_OPTS_NONE); void emitIns_SIMD_R_R_S(instruction ins, emitAttr attr, regNumber targetReg, @@ -897,7 +928,7 @@ inline bool emitIsUncondJump(instrDesc* jmp) // inline bool HasEmbeddedBroadcast(const instrDesc* id) const { - return id->idIsEvexbContext(); + return id->idIsEvexbContextSet(); } inline bool HasHighSIMDReg(const instrDesc* id) const; diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 0902c00c1650c1..ac4514c6382036 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -556,6 +556,12 @@ enum GenTreeFlags : unsigned int GTF_MDARRLOWERBOUND_NONFAULTING = 0x20000000, // GT_MDARR_LOWER_BOUND -- An MD array lower bound operation that cannot fault. Same as GT_IND_NONFAULTING. 
+ GTF_HW_ER_MASK = 0x30000000, // Bits used by handle types below + GTF_HW_ER_TO_EVEN = 0x00000000, // GT_HWINTRINSIC -- embedded rounding mode: FloatRoundingMode = ToEven (Default) "{rn-sae}" + GTF_HW_ER_TO_NEGATIVEINFINITY = 0x10000000, // GT_HWINTRINSIC -- embedded rounding mode: FloatRoundingMode = ToNegativeInfinity "{rd-sae}" + GTF_HW_ER_TO_POSITIVEINFINITY = 0x20000000, // GT_HWINTRINSIC -- embedded rounding mode: FloatRoundingMode = ToPositiveInfinity "{ru-sae}" + GTF_HW_ER_TO_ZERO = 0x30000000, // GT_HWINTRINSIC -- embedded rounding mode: FloatRoundingMode = ToZero "{rz-sae}" + }; inline constexpr GenTreeFlags operator ~(GenTreeFlags a) @@ -2225,6 +2231,43 @@ struct GenTree return (gtOper == GT_CNS_INT) ? (gtFlags & GTF_ICON_HDL_MASK) : GTF_EMPTY; } +#ifdef FEATURE_HW_INTRINSICS + + void ClearEmbRoundingMode() + { + assert(gtOper == GT_HWINTRINSIC); + gtFlags &= ~GTF_HW_ER_MASK; + } + // Set GenTreeFlags on HardwareIntrinsic node to specify the FloatRoundingMode. + // mode can be one of the values from System.Runtime.Intrinsics.X86.FloatRoundingMode. + void SetEmbRoundingMode(uint8_t mode) + { + assert(gtOper == GT_HWINTRINSIC); + ClearEmbRoundingMode(); + switch (mode) + { + case 0x09: + gtFlags |= GTF_HW_ER_TO_NEGATIVEINFINITY; + break; + case 0x0A: + gtFlags |= GTF_HW_ER_TO_POSITIVEINFINITY; + break; + case 0x0B: + gtFlags |= GTF_HW_ER_TO_ZERO; + break; + default: + break; + } + } + + uint8_t GetEmbRoundingMode() + { + assert(gtOper == GT_HWINTRINSIC); + return (uint8_t)((gtFlags & GTF_HW_ER_MASK) >> 28); + } + +#endif // FEATURE_HW_INTRINSICS + // Mark this node as no longer being a handle; clear its GTF_ICON_*_HDL bits. 
void ClearIconHandleMask() { @@ -6310,6 +6353,7 @@ struct GenTreeJitIntrinsic : public GenTreeMultiOp }; #ifdef FEATURE_HW_INTRINSICS + struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic { GenTreeHWIntrinsic(var_types type, diff --git a/src/coreclr/jit/hwintrinsic.h b/src/coreclr/jit/hwintrinsic.h index d3dc426133d36d..dcd5c86129b74d 100644 --- a/src/coreclr/jit/hwintrinsic.h +++ b/src/coreclr/jit/hwintrinsic.h @@ -201,8 +201,11 @@ enum HWIntrinsicFlag : unsigned int // The intrinsic is a PermuteVar2x intrinsic HW_Flag_PermuteVar2x = 0x4000000, - // The intrinsic is an embedded broadcast compatiable intrinsic + // The intrinsic is an embedded broadcast compatible intrinsic HW_Flag_EmbBroadcastCompatible = 0x8000000, + + // The intrinsic is an embedded rounding compatible intrinsic + HW_Flag_EmbRoundingCompatible = 0x10000000 #endif // TARGET_XARCH }; @@ -587,6 +590,27 @@ struct HWIntrinsicInfo HWIntrinsicFlag flags = lookupFlags(id); return (flags & HW_Flag_EmbBroadcastCompatible) != 0; } + + static bool IsEmbRoundingCompatible(NamedIntrinsic id) + { + HWIntrinsicFlag flags = lookupFlags(id); + return (flags & HW_Flag_EmbRoundingCompatible) != 0; + } + + static size_t EmbRoundingArgPos(NamedIntrinsic id) + { + // This helper function returns the expected position, + // where the embedded rounding control argument should be. 
+ assert(IsEmbRoundingCompatible(id)); + switch (id) + { + case NI_AVX512F_Add: + return 3; + + default: + unreached(); + } + } #endif // TARGET_XARCH static bool IsMaybeCommutative(NamedIntrinsic id) diff --git a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp index accf1fc62552d5..00342a2d820f35 100644 --- a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp @@ -333,6 +333,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node) emit->emitIns_R_R(ins, simdSize, op1Reg, op2Reg); } } + else if (HWIntrinsicInfo::IsEmbRoundingCompatible(intrinsicId) && !op3->IsCnsIntOrI()) + { + auto emitSwCase = [&](int8_t i) { genHWIntrinsic_R_R_RM(node, ins, simdSize, i); }; + regNumber baseReg = node->ExtractTempReg(); + regNumber offsReg = node->GetSingleTempReg(); + genHWIntrinsicJumpTableFallback(intrinsicId, op3Reg, baseReg, offsReg, emitSwCase); + } else { switch (intrinsicId) @@ -708,6 +715,31 @@ void CodeGen::genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, e genHWIntrinsic_R_R_RM(node, ins, attr, targetReg, op1Reg, op2); } +//------------------------------------------------------------------------ +// genHWIntrinsic_R_R_RM: Generates the code for a hardware intrinsic node that takes a register operand, a +// register/memory operand, and that returns a value in register +// +// Arguments: +// node - The hardware intrinsic node +// ins - The instruction being generated +// attr - The emit attribute for the instruction being generated +// ival - a "fake" immediate to indicate the rounding mode +// +void CodeGen::genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival) +{ + regNumber targetReg = node->GetRegNum(); + GenTree* op1 = node->Op(1); + GenTree* op2 = node->Op(2); + regNumber op1Reg = op1->GetRegNum(); + + assert(targetReg != REG_NA); + assert(op1Reg != REG_NA); + + node->SetEmbRoundingMode((uint8_t)ival); + + 
genHWIntrinsic_R_R_RM(node, ins, attr, targetReg, op1Reg, op2); +} + //------------------------------------------------------------------------ // genHWIntrinsic_R_R_RM: Generates the code for a hardware intrinsic node that takes a register operand, a // register/memory operand, and that returns a value in register @@ -733,6 +765,33 @@ void CodeGen::genHWIntrinsic_R_R_RM( } bool isRMW = node->isRMWHWIntrinsic(compiler); + + if (node->GetEmbRoundingMode() != 0) + { + // As embedded rounding only appies in R_R_R case, we can skip other checks for different paths. + OperandDesc op2Desc = genOperandDesc(op2); + assert(op2Desc.GetKind() == OperandKind::Reg); + regNumber op2Reg = op2Desc.GetReg(); + + if ((op1Reg != targetReg) && (op2Reg == targetReg) && isRMW) + { + // We have "reg2 = reg1 op reg2" where "reg1 != reg2" on a RMW instruction. + // + // For non-commutative instructions, we should have ensured that op2 was marked + // delay free in order to prevent it from getting assigned the same register + // as target. However, for commutative instructions, we can just swap the operands + // in order to have "reg2 = reg2 op reg1" which will end up producing the right code. 
+ + op2Reg = op1Reg; + op1Reg = targetReg; + } + + uint8_t mode = node->GetEmbRoundingMode(); + insOpts instOptions = GetEmitter()->GetEmbRoundingMode(mode); + GetEmitter()->emitIns_SIMD_R_R_R(ins, attr, targetReg, op1Reg, op2Reg, instOptions); + return; + } + inst_RV_RV_TT(ins, attr, targetReg, op1Reg, op2, isRMW); } diff --git a/src/coreclr/jit/hwintrinsiclistxarch.h b/src/coreclr/jit/hwintrinsiclistxarch.h index 8f6ed5f07c0d2a..893c9d011cf4b8 100644 --- a/src/coreclr/jit/hwintrinsiclistxarch.h +++ b/src/coreclr/jit/hwintrinsiclistxarch.h @@ -829,7 +829,7 @@ HARDWARE_INTRINSIC(AVX2, Xor, // *************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************** // AVX512F Intrinsics HARDWARE_INTRINSIC(AVX512F, Abs, 64, 1, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_pabsd, INS_invalid, INS_vpabsq, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_BaseTypeFromFirstArg) -HARDWARE_INTRINSIC(AVX512F, Add, 64, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_paddd, INS_paddd, INS_paddq, INS_paddq, INS_addps, INS_addpd}, HW_Category_SimpleSIMD, HW_Flag_Commutative|HW_Flag_EmbBroadcastCompatible) +HARDWARE_INTRINSIC(AVX512F, Add, 64, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_paddd, INS_paddd, INS_paddq, INS_paddq, INS_addps, INS_addpd}, HW_Category_SimpleSIMD, HW_Flag_Commutative|HW_Flag_EmbBroadcastCompatible|HW_Flag_EmbRoundingCompatible) HARDWARE_INTRINSIC(AVX512F, AlignRight32, 64, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_valignd, INS_valignd, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, 
HW_Flag_FullRangeIMM) HARDWARE_INTRINSIC(AVX512F, AlignRight64, 64, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_valignq, INS_valignq, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM) HARDWARE_INTRINSIC(AVX512F, And, 64, 2, true, {INS_pand, INS_pand, INS_pand, INS_pand, INS_pand, INS_pand, INS_vpandq, INS_vpandq, INS_andps, INS_andpd}, HW_Category_SimpleSIMD, HW_Flag_Commutative|HW_Flag_EmbBroadcastCompatible) diff --git a/src/coreclr/jit/hwintrinsicxarch.cpp b/src/coreclr/jit/hwintrinsicxarch.cpp index 97cb490052bb00..bfa71ddc491a54 100644 --- a/src/coreclr/jit/hwintrinsicxarch.cpp +++ b/src/coreclr/jit/hwintrinsicxarch.cpp @@ -285,6 +285,16 @@ CORINFO_InstructionSet HWIntrinsicInfo::lookupIsa(const char* className, const c // int HWIntrinsicInfo::lookupImmUpperBound(NamedIntrinsic id) { + if (HWIntrinsicInfo::IsEmbRoundingCompatible(id)) + { + // The only case this branch should be hit is that JIT is generating a jump table fallback when the + // FloatRoundingMode is not a compile-time constant. + // Although the expected FloatRoundingMode values are 8, 9, 10, 11, but in the generated jump table, results for + // entries within [0, 11] are all calculated, + // Any unexpected value, say [0, 7] should be blocked by the managed code. 
+ return 11; + } + assert(HWIntrinsicInfo::lookupCategory(id) == HW_Category_IMM); switch (id) diff --git a/src/coreclr/jit/instr.cpp b/src/coreclr/jit/instr.cpp index e03c68df7c10fe..473e38df0a8706 100644 --- a/src/coreclr/jit/instr.cpp +++ b/src/coreclr/jit/instr.cpp @@ -1239,7 +1239,7 @@ void CodeGen::inst_RV_RV_TT( bool IsEmbBroadcast = CodeGenInterface::IsEmbeddedBroadcastEnabled(ins, op2); if (IsEmbBroadcast) { - instOptions = INS_OPTS_EVEX_b; + instOptions = INS_OPTS_EVEX_eb_er_rd; if (emitter::IsBitwiseInstruction(ins) && varTypeIsLong(op2->AsHWIntrinsic()->GetSimdBaseType())) { switch (ins) @@ -1306,7 +1306,7 @@ void CodeGen::inst_RV_RV_TT( op1Reg = targetReg; } - emit->emitIns_SIMD_R_R_R(ins, size, targetReg, op1Reg, op2Reg); + emit->emitIns_SIMD_R_R_R(ins, size, targetReg, op1Reg, op2Reg, instOptions); } break; diff --git a/src/coreclr/jit/instr.h b/src/coreclr/jit/instr.h index 5fd9dd456d65cd..e9d8b461bf96d8 100644 --- a/src/coreclr/jit/instr.h +++ b/src/coreclr/jit/instr.h @@ -203,9 +203,15 @@ enum insFlags : uint64_t enum insOpts: unsigned { - INS_OPTS_NONE, + INS_OPTS_NONE = 0, + + INS_OPTS_EVEX_eb_er_rd = 1, // Embedded Broadcast or Round down + + INS_OPTS_EVEX_er_ru = 2, // Round up + + INS_OPTS_EVEX_er_rz = 3, // Round towards zero - INS_OPTS_EVEX_b + INS_OPTS_b_MASK = (INS_OPTS_EVEX_eb_er_rd | INS_OPTS_EVEX_er_ru | INS_OPTS_EVEX_er_rz), // mask for Evex.b related features. 
}; #elif defined(TARGET_ARM) || defined(TARGET_ARM64) || defined(TARGET_LOONGARCH64) || defined(TARGET_RISCV64) diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index 31e2a55e1b4862..0ff7658c2470fe 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -1067,6 +1067,68 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node) NamedIntrinsic intrinsicId = node->GetHWIntrinsicId(); + if (HWIntrinsicInfo::IsEmbRoundingCompatible(intrinsicId)) + { + size_t numArgs = node->GetOperandCount(); + size_t expectedArgNum = HWIntrinsicInfo::EmbRoundingArgPos(intrinsicId); + if (numArgs == expectedArgNum) + { + CorInfoType simdBaseJitType = node->GetSimdBaseJitType(); + uint32_t simdSize = node->GetSimdSize(); + var_types simdType = node->TypeGet(); + GenTree* lastOp = node->Op(numArgs); + + if (lastOp->IsCnsIntOrI()) + { + uint8_t mode = static_cast(lastOp->AsIntCon()->IconValue()); + + GenTreeHWIntrinsic* embRoundingNode; + switch (numArgs) + { + case 3: + embRoundingNode = comp->gtNewSimdHWIntrinsicNode(simdType, node->Op(1), node->Op(2), + intrinsicId, simdBaseJitType, simdSize); + break; + case 2: + embRoundingNode = comp->gtNewSimdHWIntrinsicNode(simdType, node->Op(1), intrinsicId, + simdBaseJitType, simdSize); + break; + + default: + unreached(); + } + + embRoundingNode->SetEmbRoundingMode(mode); + BlockRange().InsertAfter(lastOp, embRoundingNode); + LIR::Use use; + if (BlockRange().TryGetUse(node, &use)) + { + use.ReplaceWith(embRoundingNode); + } + else + { + embRoundingNode->SetUnusedValue(); + } + BlockRange().Remove(node); + BlockRange().Remove(lastOp); + node = embRoundingNode; + if (mode != 0x08) + { + // As embedded rounding can only work under register-to-register form, we can skip contain check at + // this point. + return node->gtNext; + } + } + else + { + // If the control byte is not constant, generate a jump table fallback when emitting the code. 
+ assert(!lastOp->IsCnsIntOrI()); + node->SetEmbRoundingMode(0x08); + return node->gtNext; + } + } + } + switch (intrinsicId) { case NI_Vector128_ConditionalSelect: diff --git a/src/coreclr/jit/lsraxarch.cpp b/src/coreclr/jit/lsraxarch.cpp index 534c260a9a62cf..9912598f19553c 100644 --- a/src/coreclr/jit/lsraxarch.cpp +++ b/src/coreclr/jit/lsraxarch.cpp @@ -2142,6 +2142,13 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou } } + if (HWIntrinsicInfo::IsEmbRoundingCompatible(intrinsicId) && + numArgs == HWIntrinsicInfo::EmbRoundingArgPos(intrinsicId) && !lastOp->IsCnsIntOrI()) + { + buildInternalIntRegisterDefForNode(intrinsicTree); + buildInternalIntRegisterDefForNode(intrinsicTree); + } + // Determine whether this is an RMW operation where op2+ must be marked delayFree so that it // is not allocated the same register as the target. bool isRMW = intrinsicTree->isRMWHWIntrinsic(compiler); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.PlatformNotSupported.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.PlatformNotSupported.cs index 86a472f71802a1..5e6bc6e2023ec4 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.PlatformNotSupported.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.PlatformNotSupported.cs @@ -1338,6 +1338,11 @@ internal X64() { } /// public static Vector512 Add(Vector512 left, Vector512 right) { throw new PlatformNotSupportedException(); } /// + /// __m512d _mm512_add_pd (__m512d a, __m512d b) + /// VADDPD zmm1 {k1}{z}, zmm2, zmm3/m512/m64bcst{er} + /// + public static Vector512 Add(Vector512 left, Vector512 right, [ConstantExpected(Max = FloatRoundingMode.ToZero)] FloatRoundingMode mode) { throw new PlatformNotSupportedException(); } + /// /// __m512 _mm512_add_ps (__m512 a, __m512 b) /// VADDPS zmm1 {k1}{z}, zmm2, zmm3/m512/m32bcst{er} /// diff --git 
a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.cs index 534120a3351561..fa0ebeaa816230 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Avx512F.cs @@ -1339,6 +1339,11 @@ internal X64() { } /// public static Vector512 Add(Vector512 left, Vector512 right) => Add(left, right); /// + /// __m512d _mm512_add_pd (__m512d a, __m512d b) + /// VADDPD zmm1 {k1}{z}, zmm2, zmm3/m512/m64bcst{er} + /// + public static Vector512 Add(Vector512 left, Vector512 right, [ConstantExpected(Max = FloatRoundingMode.ToZero)] FloatRoundingMode mode) => Add(left, right, mode); + /// /// __m512 _mm512_add_ps (__m512 a, __m512 b) /// VADDPS zmm1 {k1}{z}, zmm2, zmm3/m512/m32bcst{er} /// diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Enums.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Enums.cs index 527635bb58e992..ef5471d6ba43ca 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Enums.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/X86/Enums.cs @@ -165,4 +165,24 @@ public enum FloatComparisonMode : byte /// UnorderedTrueSignaling = 31, } + + public enum FloatRoundingMode : byte + { + /// + /// _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC + /// + ToEven = 0x08, + /// + /// _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC + /// + ToNegativeInfinity = 0x09, + /// + /// _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC + /// + ToPositiveInfinity = 0x0A, + /// + /// _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC + /// + ToZero = 0x0B, + } } diff --git a/src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs b/src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs index 00318b5d15d389..b3db125c2c7ece 100644 --- 
a/src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs +++ b/src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs @@ -5192,6 +5192,7 @@ internal Avx512F() { } public static System.Runtime.Intrinsics.Vector512 Abs(System.Runtime.Intrinsics.Vector512 value) { throw null; } public static System.Runtime.Intrinsics.Vector512 Abs(System.Runtime.Intrinsics.Vector512 value) { throw null; } public static System.Runtime.Intrinsics.Vector512 Add(System.Runtime.Intrinsics.Vector512 left, System.Runtime.Intrinsics.Vector512 right) { throw null; } + public static System.Runtime.Intrinsics.Vector512 Add(System.Runtime.Intrinsics.Vector512 left, System.Runtime.Intrinsics.Vector512 right, [System.Diagnostics.CodeAnalysis.ConstantExpected(Max = System.Runtime.Intrinsics.X86.FloatRoundingMode.ToZero)] System.Runtime.Intrinsics.X86.FloatRoundingMode mode) { throw null; } public static System.Runtime.Intrinsics.Vector512 Add(System.Runtime.Intrinsics.Vector512 left, System.Runtime.Intrinsics.Vector512 right) { throw null; } public static System.Runtime.Intrinsics.Vector512 Add(System.Runtime.Intrinsics.Vector512 left, System.Runtime.Intrinsics.Vector512 right) { throw null; } public static System.Runtime.Intrinsics.Vector512 Add(System.Runtime.Intrinsics.Vector512 left, System.Runtime.Intrinsics.Vector512 right) { throw null; } @@ -6026,6 +6027,13 @@ public enum FloatComparisonMode : byte OrderedGreaterThanNonSignaling = (byte)30, UnorderedTrueSignaling = (byte)31, } + public enum FloatRoundingMode : byte + { + ToEven = 0x08, + ToNegativeInfinity = 0x09, + ToPositiveInfinity = 0x0A, + ToZero = 0x0B, + } [System.CLSCompliantAttribute(false)] public abstract partial class Fma : System.Runtime.Intrinsics.X86.Avx { diff --git a/src/tests/Common/GenerateHWIntrinsicTests/GenerateHWIntrinsicTests_X86.cs b/src/tests/Common/GenerateHWIntrinsicTests/GenerateHWIntrinsicTests_X86.cs index 6343e12513b7c5..6e284032934299 100644 --- 
a/src/tests/Common/GenerateHWIntrinsicTests/GenerateHWIntrinsicTests_X86.cs +++ b/src/tests/Common/GenerateHWIntrinsicTests/GenerateHWIntrinsicTests_X86.cs @@ -1519,6 +1519,9 @@ ("SimpleBinOpTest.template", new Dictionary { ["Isa"] = "Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Xor", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "UInt16", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "UInt16", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "UInt16", ["LargestVectorSize"] = "64", ["NextValueOp1"] = "TestLibrary.Generator.GetUInt16()", ["NextValueOp2"] = "TestLibrary.Generator.GetUInt16()", ["ValidateFirstResult"] = "(ushort)(left[0] ^ right[0]) != result[0]", ["ValidateRemainingResults"] = "(ushort)(left[i] ^ right[i]) != result[i]"}), ("SimpleBinOpTest.template", new Dictionary { ["Isa"] = "Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Xor", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "UInt32", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "UInt32", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "UInt32", ["LargestVectorSize"] = "64", ["NextValueOp1"] = "TestLibrary.Generator.GetUInt32()", ["NextValueOp2"] = "TestLibrary.Generator.GetUInt32()", ["ValidateFirstResult"] = "(uint)(left[0] ^ right[0]) != result[0]", ["ValidateRemainingResults"] = "(uint)(left[i] ^ right[i]) != result[i]"}), ("SimpleBinOpTest.template", new Dictionary { ["Isa"] = "Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Xor", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "UInt64", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "UInt64", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "UInt64", ["LargestVectorSize"] = "64", ["NextValueOp1"] = "TestLibrary.Generator.GetUInt64()", ["NextValueOp2"] = "TestLibrary.Generator.GetUInt64()", ["ValidateFirstResult"] = "(ulong)(left[0] ^ right[0]) != result[0]", ["ValidateRemainingResults"] = "(ulong)(left[i] ^ right[i]) != result[i]"}), + ("SimpleBinOpEmbRounding.template", new Dictionary { ["Isa"] = 
"Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Add", ["RoundingMode"] = "ToNegativeInfinity", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "Double", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "Double", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "Double", ["LargestVectorSize"] = "64", ["CastingMethod"] = "DoubleToUInt64Bits", ["FixedInput1"] = "0.05", ["FixedInput2"] = "0.45"}), + ("SimpleBinOpEmbRounding.template", new Dictionary { ["Isa"] = "Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Add", ["RoundingMode"] = "ToPositiveInfinity", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "Double", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "Double", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "Double", ["LargestVectorSize"] = "64", ["CastingMethod"] = "DoubleToUInt64Bits", ["FixedInput1"] = "0.05", ["FixedInput2"] = "0.45"}), + ("SimpleBinOpEmbRounding.template", new Dictionary { ["Isa"] = "Avx512F", ["LoadIsa"] = "Avx512F", ["Method"] = "Add", ["RoundingMode"] = "ToZero", ["RetVectorType"] = "Vector512", ["RetBaseType"] = "Double", ["Op1VectorType"] = "Vector512", ["Op1BaseType"] = "Double", ["Op2VectorType"] = "Vector512", ["Op2BaseType"] = "Double", ["LargestVectorSize"] = "64", ["CastingMethod"] = "DoubleToUInt64Bits", ["FixedInput1"] = "0.05", ["FixedInput2"] = "0.45"}), }; (string templateFileName, Dictionary templateData)[] Avx512F_ScalarUpperInputs = new [] @@ -2631,6 +2634,12 @@ void ProcessInput(StreamWriter testListFile, string groupName, (string templateF testName += ".Tuple3Op"; suffix += "Tuple3Op"; } + else if (input.templateFileName == "SimpleBinOpEmbRounding.template") + { + testName += ".EmbeddedRounding"; + testName += $".{input.templateData["RoundingMode"]}"; + suffix += "EmbeddedRounding"; + } var fileName = Path.Combine(outputDirectory, $"{testName}.cs"); diff --git a/src/tests/JIT/HardwareIntrinsics/X86/Shared/SimpleBinOpEmbRounding.template 
b/src/tests/JIT/HardwareIntrinsics/X86/Shared/SimpleBinOpEmbRounding.template new file mode 100644 index 00000000000000..da78721c596d01 --- /dev/null +++ b/src/tests/JIT/HardwareIntrinsics/X86/Shared/SimpleBinOpEmbRounding.template @@ -0,0 +1,315 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +/****************************************************************************** + * This file is auto-generated from a template file by the GenerateTests.csx * + * script in tests\src\JIT\HardwareIntrinsics\X86\Shared. In order to make * + * changes, please update the corresponding template and run according to the * + * directions listed in the file. * + ******************************************************************************/ + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Collections.Generic; +using System.Runtime.Intrinsics.X86; +using Xunit; + +namespace JIT.HardwareIntrinsics.X86 +{ + public static partial class Program + { + [Fact] + public static void {Method}{RetBaseType}{RoundingMode}() + { + var test = new BinaryOpTest__{Method}{RetBaseType}{RoundingMode}(); + + if (test.IsSupported) + { + // Validates basic functionality works, using Unsafe.Read + test.RunBasicScenario_UnsafeRead(); + + if ({LoadIsa}.IsSupported) + { + // Validates basic functionality works, using Load + test.RunBasicScenario_Load(); + + // Validates basic functionality works, using LoadAligned + test.RunBasicScenario_LoadAligned(); + } + + // Validates calling via reflection works, using Unsafe.Read + test.RunReflectionScenario_UnsafeRead(); + + // Validates passing a local works, using Unsafe.Read + test.RunLclVarScenario_UnsafeRead(); + + // Validates passing an instance member of a class works + test.RunClassFldScenario(); + + // Validates passing the field of a local struct works + 
test.RunStructLclFldScenario(); + + // Validates passing an instance member of a struct works + test.RunStructFldScenario(); + } + else + { + // Validates we throw on unsupported hardware + test.RunUnsupportedScenario(); + } + + if (!test.Succeeded) + { + throw new Exception("One or more scenarios did not complete as expected."); + } + } + } + + public sealed unsafe class BinaryOpTest__{Method}{RetBaseType}{RoundingMode} + { + private struct TestStruct + { + public {Op1VectorType}<{Op1BaseType}> _fld1; + public {Op2VectorType}<{Op2BaseType}> _fld2; + + public static TestStruct Create() + { + var testStruct = new TestStruct(); + + for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = ({Op1BaseType}){FixedInput1}; } + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op1VectorType}<{Op1BaseType}>, byte>(ref testStruct._fld1), ref Unsafe.As<{Op1BaseType}, byte>(ref _data1[0]), (uint)Unsafe.SizeOf<{Op1VectorType}<{Op1BaseType}>>()); + for (var i = 0; i < Op2ElementCount; i++) { _data2[i] = ({Op2BaseType}){FixedInput2}; } + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op2VectorType}<{Op2BaseType}>, byte>(ref testStruct._fld2), ref Unsafe.As<{Op2BaseType}, byte>(ref _data2[0]), (uint)Unsafe.SizeOf<{Op2VectorType}<{Op2BaseType}>>()); + + return testStruct; + } + + public void RunStructFldScenario(BinaryOpTest__{Method}{RetBaseType}{RoundingMode} testClass) + { + var result = {Isa}.{Method}(_fld1, _fld2, FloatRoundingMode.{RoundingMode}); + + Unsafe.Write(testClass._dataTable.outArrayPtr, result); + testClass.ValidateResult(_fld1, _fld2, testClass._dataTable.outArrayPtr); + } + } + + private static readonly int LargestVectorSize = {LargestVectorSize}; + + private static readonly int Op1ElementCount = Unsafe.SizeOf<{Op1VectorType}<{Op1BaseType}>>() / sizeof({Op1BaseType}); + private static readonly int Op2ElementCount = Unsafe.SizeOf<{Op2VectorType}<{Op2BaseType}>>() / sizeof({Op2BaseType}); + private static readonly int RetElementCount = 
Unsafe.SizeOf<{RetVectorType}<{RetBaseType}>>() / sizeof({RetBaseType}); + + private static {Op1BaseType}[] _data1 = new {Op1BaseType}[Op1ElementCount]; + private static {Op2BaseType}[] _data2 = new {Op2BaseType}[Op2ElementCount]; + + private {Op1VectorType}<{Op1BaseType}> _fld1; + private {Op2VectorType}<{Op2BaseType}> _fld2; + + private SimpleBinaryOpTest__DataTable<{RetBaseType}, {Op1BaseType}, {Op2BaseType}> _dataTable; + + public BinaryOpTest__{Method}{RetBaseType}{RoundingMode}() + { + Succeeded = true; + + for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = ({Op1BaseType}){FixedInput1}; } + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op1VectorType}<{Op1BaseType}>, byte>(ref _fld1), ref Unsafe.As<{Op1BaseType}, byte>(ref _data1[0]), (uint)Unsafe.SizeOf<{Op1VectorType}<{Op1BaseType}>>()); + for (var i = 0; i < Op2ElementCount; i++) { _data2[i] = ({Op2BaseType}){FixedInput2}; } + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op2VectorType}<{Op2BaseType}>, byte>(ref _fld2), ref Unsafe.As<{Op2BaseType}, byte>(ref _data2[0]), (uint)Unsafe.SizeOf<{Op2VectorType}<{Op2BaseType}>>()); + + for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = ({Op1BaseType}){FixedInput1}; } + for (var i = 0; i < Op2ElementCount; i++) { _data2[i] = ({Op2BaseType}){FixedInput2}; } + _dataTable = new SimpleBinaryOpTest__DataTable<{RetBaseType}, {Op1BaseType}, {Op2BaseType}>(_data1, _data2, new {RetBaseType}[RetElementCount], LargestVectorSize); + } + + public bool IsSupported => {Isa}.IsSupported; + + public bool Succeeded { get; set; } + + public void RunBasicScenario_UnsafeRead() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunBasicScenario_UnsafeRead)); + + var result = {Isa}.{Method}( + Unsafe.Read<{Op1VectorType}<{Op1BaseType}>>(_dataTable.inArray1Ptr), + Unsafe.Read<{Op2VectorType}<{Op2BaseType}>>(_dataTable.inArray2Ptr), + FloatRoundingMode.{RoundingMode} + ); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(_dataTable.inArray1Ptr, _dataTable.inArray2Ptr, 
_dataTable.outArrayPtr); + } + + public void RunBasicScenario_Load() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunBasicScenario_Load)); + + var result = {Isa}.{Method}( + {LoadIsa}.Load{Op1VectorType}(({Op1BaseType}*)(_dataTable.inArray1Ptr)), + {LoadIsa}.Load{Op2VectorType}(({Op2BaseType}*)(_dataTable.inArray2Ptr)), + FloatRoundingMode.{RoundingMode} + ); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(_dataTable.inArray1Ptr, _dataTable.inArray2Ptr, _dataTable.outArrayPtr); + } + + public void RunBasicScenario_LoadAligned() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunBasicScenario_LoadAligned)); + + var result = {Isa}.{Method}( + {LoadIsa}.LoadAligned{Op1VectorType}(({Op1BaseType}*)(_dataTable.inArray1Ptr)), + {LoadIsa}.LoadAligned{Op2VectorType}(({Op2BaseType}*)(_dataTable.inArray2Ptr)), + FloatRoundingMode.{RoundingMode} + ); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(_dataTable.inArray1Ptr, _dataTable.inArray2Ptr, _dataTable.outArrayPtr); + } + + public void RunReflectionScenario_UnsafeRead() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunReflectionScenario_UnsafeRead)); + + var result = typeof({Isa}).GetMethod(nameof({Isa}.{Method}), new Type[] { typeof({Op1VectorType}<{Op1BaseType}>), typeof({Op2VectorType}<{Op2BaseType}>) , typeof(FloatRoundingMode)}) + .Invoke(null, new object[] { + Unsafe.Read<{Op1VectorType}<{Op1BaseType}>>(_dataTable.inArray1Ptr), + Unsafe.Read<{Op2VectorType}<{Op2BaseType}>>(_dataTable.inArray2Ptr), + FloatRoundingMode.{RoundingMode} + }); + + Unsafe.Write(_dataTable.outArrayPtr, ({RetVectorType}<{RetBaseType}>)(result)); + ValidateResult(_dataTable.inArray1Ptr, _dataTable.inArray2Ptr, _dataTable.outArrayPtr); + } + + public void RunLclVarScenario_UnsafeRead() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunLclVarScenario_UnsafeRead)); + + var op1 = Unsafe.Read<{Op1VectorType}<{Op1BaseType}>>(_dataTable.inArray1Ptr); + var op2 = 
Unsafe.Read<{Op2VectorType}<{Op2BaseType}>>(_dataTable.inArray2Ptr); + var result = {Isa}.{Method}(op1, op2, FloatRoundingMode.{RoundingMode}); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(op1, op2, _dataTable.outArrayPtr); + } + + public void RunClassFldScenario() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunClassFldScenario)); + + var result = {Isa}.{Method}(_fld1, _fld2, FloatRoundingMode.{RoundingMode}); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(_fld1, _fld2, _dataTable.outArrayPtr); + } + + public void RunStructLclFldScenario() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunStructLclFldScenario)); + + var test = TestStruct.Create(); + var result = {Isa}.{Method}(test._fld1, test._fld2, FloatRoundingMode.{RoundingMode}); + + Unsafe.Write(_dataTable.outArrayPtr, result); + ValidateResult(test._fld1, test._fld2, _dataTable.outArrayPtr); + } + + public void RunStructFldScenario() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunStructFldScenario)); + + var test = TestStruct.Create(); + test.RunStructFldScenario(this); + } + + public void RunUnsupportedScenario() + { + TestLibrary.TestFramework.BeginScenario(nameof(RunUnsupportedScenario)); + + bool succeeded = false; + + try + { + RunBasicScenario_UnsafeRead(); + } + catch (PlatformNotSupportedException) + { + succeeded = true; + } + + if (!succeeded) + { + Succeeded = false; + } + } + + private void ValidateResult({Op1VectorType}<{Op1BaseType}> op1, {Op2VectorType}<{Op2BaseType}> op2, void* result, [CallerMemberName] string method = "") + { + {Op1BaseType}[] inArray1 = new {Op1BaseType}[Op1ElementCount]; + {Op2BaseType}[] inArray2 = new {Op2BaseType}[Op2ElementCount]; + {RetBaseType}[] outArray = new {RetBaseType}[RetElementCount]; + + Unsafe.WriteUnaligned(ref Unsafe.As<{Op1BaseType}, byte>(ref inArray1[0]), op1); + Unsafe.WriteUnaligned(ref Unsafe.As<{Op2BaseType}, byte>(ref inArray2[0]), op2); + Unsafe.CopyBlockUnaligned(ref 
Unsafe.As<{RetBaseType}, byte>(ref outArray[0]), ref Unsafe.AsRef(result), (uint)Unsafe.SizeOf<{RetVectorType}<{RetBaseType}>>()); + + ValidateResult(inArray1, inArray2, outArray, method); + } + + private void ValidateResult(void* op1, void* op2, void* result, [CallerMemberName] string method = "") + { + {Op1BaseType}[] inArray1 = new {Op1BaseType}[Op1ElementCount]; + {Op2BaseType}[] inArray2 = new {Op2BaseType}[Op2ElementCount]; + {RetBaseType}[] outArray = new {RetBaseType}[RetElementCount]; + + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op1BaseType}, byte>(ref inArray1[0]), ref Unsafe.AsRef(op1), (uint)Unsafe.SizeOf<{Op1VectorType}<{Op1BaseType}>>()); + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{Op2BaseType}, byte>(ref inArray2[0]), ref Unsafe.AsRef(op2), (uint)Unsafe.SizeOf<{Op2VectorType}<{Op2BaseType}>>()); + Unsafe.CopyBlockUnaligned(ref Unsafe.As<{RetBaseType}, byte>(ref outArray[0]), ref Unsafe.AsRef(result), (uint)Unsafe.SizeOf<{RetVectorType}<{RetBaseType}>>()); + + ValidateResult(inArray1, inArray2, outArray, method); + } + + private void ValidateResult({Op1BaseType}[] left, {Op2BaseType}[] right, {RetBaseType}[] result, [CallerMemberName] string method = "") + { + bool succeeded = true; + + for (int i = 0; i < result.Length; i++) + { + ulong[] answerTable = binaryEmbRoundingAnswerTable[("{RetBaseType}", "{Method}", "{RoundingMode}")]; + + if (BitConverter.{CastingMethod}(result[i]) != answerTable[i]) + { + succeeded = false; + Console.WriteLine("Avx512 {Method} Embedded rounding failed on {RetBaseType} with {RoundingMode}:"); + foreach (var item in result) + { + Console.Write(item + ", "); + } + Console.WriteLine(); + Assert.Fail(""); + } + } + + if (!succeeded) + { + TestLibrary.TestFramework.LogInformation($"{nameof({Isa})}.{nameof({Isa}.{Method})}<{RetBaseType}>({Op1VectorType}<{Op1BaseType}>, {Op2VectorType}<{Op2BaseType}>): {method} failed:"); + TestLibrary.TestFramework.LogInformation($" left: ({string.Join(", ", left)})"); + 
TestLibrary.TestFramework.LogInformation($" right: ({string.Join(", ", right)})"); + TestLibrary.TestFramework.LogInformation($" result: ({string.Join(", ", result)})"); + TestLibrary.TestFramework.LogInformation(string.Empty); + + Succeeded = false; + } + } + + private static Dictionary<(string, string, string), ulong[]> binaryEmbRoundingAnswerTable = new Dictionary<(string, string, string), ulong[]> + { + {("Double", "Add", "ToNegativeInfinity"), new ulong[] {0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000}}, + {("Double", "Add", "ToPositiveInfinity"), new ulong[] {0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001, 0x3fe0000000000001}}, + {("Double", "Add", "ToZero"), new ulong[] {0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000, 0x3fe0000000000000}}, + }; + } +}