Skip to content

Commit 4d39083

Browse files
committed
Arm64 SVE: Support scalable constant vectors and masks
Adds support to GenTreeVecCon and GenTreeMskCon for constants with unknown sizes. Instead of having a blob of data, the constant is represented as being one of either: a repeated value, an sequence with start and step values, or a value in the first lane and the rest zeroed. To handle this the base type is also required. As this new structure is slightly bigger than a simd16, the simd_t typedef is pushed up to simd32 sized. For vector constants, a vector is scalable because if it is of TYP_SIMD. For mask constants, the type is always TYP_MASK. However on Arm64, masks are only used by SVE. Therefore to tell if a mask is scalable then JitUseScalableVectorT is checked. The IsAllBitsSet() on mask constants is updated to include a base type. A mask that is all set for TYP_LONG will not be all set for TYP_BYTE, and instead will be 100010001000... Given two scalable constants it may not be possible to add them together to produce a third scalable constant. Instead they will remain as two vectors in the IR. To show this implementation is workable, scalable support is added for: Sve.CreateTrueMask*() Sve.CreateFalseMask*() Vector.Create() Vector.CreateScalar() Vector.CreateScalarUnsafe() Vector.CreateSequence() Fixes #125057
1 parent 8703788 commit 4d39083

16 files changed

Lines changed: 1367 additions & 258 deletions

src/coreclr/jit/codegenarm64.cpp

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2325,14 +2325,15 @@ void CodeGen::genSetRegToConst(regNumber targetReg, var_types targetType, GenTre
23252325
GenTreeVecCon* vecCon = tree->AsVecCon();
23262326

23272327
emitter* emit = GetEmitter();
2328-
emitAttr attr = emitTypeSize(targetType);
23292328

23302329
switch (tree->TypeGet())
23312330
{
23322331
case TYP_SIMD8:
23332332
case TYP_SIMD12:
23342333
case TYP_SIMD16:
23352334
{
2335+
emitAttr attr = emitTypeSize(targetType);
2336+
23362337
// We ignore any differences between SIMD12 and SIMD16 here if we can broadcast the value
23372338
// via mvni/movi.
23382339
const bool is8 = tree->TypeIs(TYP_SIMD8);
@@ -2385,6 +2386,104 @@ void CodeGen::genSetRegToConst(regNumber targetReg, var_types targetType, GenTre
23852386
break;
23862387
}
23872388

2389+
case TYP_SIMD:
2390+
{
2391+
simdscalable_t simdVal = vecCon->gtSimdScalableVal;
2392+
insOpts opt = emitter::optGetSveInsOpt(emitTypeSize(simdVal.gtSimdScalableBaseType));
2393+
emitAttr emitSize = (opt == INS_OPTS_SCALABLE_D) ? EA_8BYTE : EA_4BYTE;
2394+
2395+
auto loadConstantHelper = [&](uint64_t constValue) -> regNumber {
2396+
// Get a temp integer register to compute long address.
2397+
regNumber addrReg = internalRegisters.GetSingle(tree);
2398+
2399+
// Store the index to memory
2400+
UNATIVE_OFFSET cnum =
2401+
emit->emitDataConst(&constValue, sizeof(constValue), sizeof(constValue), TYP_LONG);
2402+
CORINFO_FIELD_HANDLE hnd = m_compiler->eeFindJitDataOffs(cnum);
2403+
2404+
// Load the constant
2405+
emit->emitIns_R_C(INS_ldr, emitSize, addrReg, addrReg, hnd, 0);
2406+
2407+
return addrReg;
2408+
};
2409+
2410+
switch (vecCon->gtSimdScalableVal.gtSimdScalableKind)
2411+
{
2412+
case SimdScalableRepeated:
2413+
if (emitter::isValidSimm<8>(simdVal.gtSimdScalableIndex) ||
2414+
emitter::isValidSimm_MultipleOf<8, 256>(simdVal.gtSimdScalableIndex))
2415+
{
2416+
emit->emitInsSve_R_I(INS_sve_dup, EA_SCALABLE, targetReg, simdVal.gtSimdScalableIndex,
2417+
opt);
2418+
}
2419+
else
2420+
{
2421+
regNumber indexReg = loadConstantHelper(simdVal.gtSimdScalableIndex);
2422+
emit->emitInsSve_R_R(INS_sve_dup, emitSize, targetReg, indexReg, opt);
2423+
}
2424+
break;
2425+
2426+
case SimdScalableSequence:
2427+
if (emitter::isValidSimm<5>(simdVal.gtSimdScalableIndex) &&
2428+
emitter::isValidSimm<5>(simdVal.gtSimdScalableStep))
2429+
{
2430+
emit->emitInsSve_R_I_I(INS_sve_index, EA_SCALABLE, targetReg,
2431+
simdVal.gtSimdScalableIndex, simdVal.gtSimdScalableStep, opt);
2432+
}
2433+
else if (emitter::isValidSimm<5>(simdVal.gtSimdScalableIndex))
2434+
{
2435+
regNumber stepReg = loadConstantHelper(simdVal.gtSimdScalableStep);
2436+
emit->emitInsSve_R_R_I(INS_sve_index, emitSize, targetReg, stepReg,
2437+
simdVal.gtSimdScalableIndex, opt, INS_SCALABLE_OPTS_IMM_FIRST);
2438+
}
2439+
else if (emitter::isValidSimm<5>(simdVal.gtSimdScalableStep))
2440+
{
2441+
regNumber indexReg = loadConstantHelper(simdVal.gtSimdScalableIndex);
2442+
emit->emitInsSve_R_R_I(INS_sve_index, emitSize, targetReg, indexReg,
2443+
simdVal.gtSimdScalableStep, opt);
2444+
}
2445+
else
2446+
{
2447+
regNumber indexReg = loadConstantHelper(simdVal.gtSimdScalableIndex);
2448+
regNumber stepReg = loadConstantHelper(simdVal.gtSimdScalableStep);
2449+
emit->emitInsSve_R_R_R(INS_sve_index, emitSize, targetReg, indexReg, stepReg, opt);
2450+
}
2451+
break;
2452+
2453+
case SimdScalableScalar:
2454+
{
2455+
// Clear the entire target register
2456+
emit->emitInsSve_R_I(INS_sve_dup, EA_SCALABLE, targetReg, 0, opt);
2457+
2458+
regNumber indexReg = loadConstantHelper(simdVal.gtSimdScalableIndex);
2459+
2460+
// Use NEON instructions to load the constant (to avoid using predicates)
2461+
2462+
if (varTypeIsIntegral(simdVal.gtSimdScalableBaseType) &&
2463+
emitter::emitIns_valid_imm_for_mov(simdVal.gtSimdScalableIndex, emitSize))
2464+
{
2465+
emit->emitIns_R_I(INS_mov, EA_16BYTE, targetReg, simdVal.gtSimdScalableIndex);
2466+
}
2467+
else if (varTypeIsFloating(simdVal.gtSimdScalableBaseType) &&
2468+
emitter::emitIns_valid_imm_for_fmov(simdVal.gtSimdScalableIndexF64[0]))
2469+
{
2470+
emit->emitIns_R_F(INS_fmov, EA_16BYTE, targetReg, simdVal.gtSimdScalableIndexF64[0]);
2471+
}
2472+
else
2473+
{
2474+
regNumber indexReg = loadConstantHelper(simdVal.gtSimdScalableIndex);
2475+
emit->emitIns_R_R(INS_ins, emitSize, targetReg, indexReg, INS_OPTS_16B);
2476+
}
2477+
break;
2478+
}
2479+
2480+
default:
2481+
unreached();
2482+
break;
2483+
}
2484+
break;
2485+
}
2486+
23882487
default:
23892488
{
23902489
unreached();
@@ -2399,14 +2498,26 @@ void CodeGen::genSetRegToConst(regNumber targetReg, var_types targetType, GenTre
23992498
GenTreeMskCon* mask = tree->AsMskCon();
24002499
emitter* emit = GetEmitter();
24012500

2402-
// Try every type until a match is found
2403-
24042501
if (mask->IsZero())
24052502
{
24062503
emit->emitInsSve_R(INS_sve_pfalse, EA_SCALABLE, targetReg, INS_OPTS_SCALABLE_B);
24072504
break;
24082505
}
24092506

2507+
#if defined(DEBUG)
2508+
if (JitConfig.JitUseScalableVectorT() == 1)
2509+
{
2510+
assert(mask->gtSimdScalableMaskVal.gtSimdMaskScalableIndex == 1);
2511+
2512+
insOpts opt =
2513+
emitter::optGetSveInsOpt(emitTypeSize(mask->gtSimdScalableMaskVal.gtSimdMaskScalableBaseType));
2514+
emit->emitIns_R_PATTERN(INS_sve_ptrue, EA_SCALABLE, targetReg, opt, SVE_PATTERN_ALL);
2515+
break;
2516+
}
2517+
#endif // DEBUG
2518+
2519+
// Fixed length vectors. Try every type until a match is found
2520+
24102521
insOpts opt = INS_OPTS_SCALABLE_B;
24112522
SveMaskPattern pat = EvaluateSimdMaskToPattern<simd16_t>(TYP_BYTE, mask->gtSimdMaskVal);
24122523

src/coreclr/jit/compiler.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3236,10 +3236,21 @@ class Compiler
32363236
#if defined(FEATURE_SIMD)
32373237
GenTreeVecCon* gtNewVconNode(var_types type);
32383238
GenTreeVecCon* gtNewVconNode(var_types type, void* data);
3239+
#if defined(TARGET_ARM64)
3240+
GenTreeVecCon* gtNewSimdVconNode(var_types type, var_types baseType, SimdScalableKind kind, uint64_t index, uint64_t step = 0);
3241+
3242+
inline GenTreeVecCon* gtNewSimdVconNode(var_types type, simdscalable_t* con)
3243+
{
3244+
return gtNewSimdVconNode(type, con->gtSimdScalableBaseType, con->gtSimdScalableKind, con->gtSimdScalableIndex, con->gtSimdScalableStep);
3245+
}
3246+
#endif // TARGET_ARM64
32393247
#endif // FEATURE_SIMD
32403248

32413249
#if defined(FEATURE_MASKED_HW_INTRINSICS)
32423250
GenTreeMskCon* gtNewMskConNode(var_types type);
3251+
#if defined(TARGET_ARM64)
3252+
GenTreeMskCon* gtNewMskConNode(var_types type, var_types baseType, bool index);
3253+
#endif // TARGET_ARM64
32433254
#endif // FEATURE_MASKED_HW_INTRINSICS
32443255

32453256
GenTree* gtNewAllBitsSetConNode(var_types type);
@@ -3348,7 +3359,7 @@ class Compiler
33483359
var_types type, GenTree* op1, var_types simdBaseType, unsigned simdSize);
33493360

33503361
#if defined(TARGET_ARM64)
3351-
GenTree* gtNewSimdAllTrueMaskNode(var_types simdBaseType);
3362+
GenTree* gtNewSimdTrueMaskNode(var_types simdBaseType);
33523363
GenTree* gtNewSimdFalseMaskByteNode();
33533364
#endif
33543365

@@ -3916,7 +3927,7 @@ class Compiler
39163927

39173928
#if defined(FEATURE_HW_INTRINSICS)
39183929
GenTree* gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree);
3919-
GenTreeMskCon* gtFoldExprConvertVecCnsToMask(GenTreeHWIntrinsic* tree, GenTreeVecCon* vecCon);
3930+
GenTree* gtFoldExprConvertVecCnsToMask(GenTreeHWIntrinsic* tree, GenTreeVecCon* vecCon);
39203931
#endif // FEATURE_HW_INTRINSICS
39213932

39223933
// Options to control behavior of gtTryRemoveBoxUpstreamEffects

src/coreclr/jit/compiler.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,14 @@ inline bool genExactlyOneBit(T value)
102102
inline regMaskTP genFindLowestBit(regMaskTP value)
103103
{
104104
#ifdef HAS_MORE_THAN_64_REGISTERS
105-
// If we ever need to use this method for predicate
106-
// registers, then handle it.
107-
assert(value.getHigh() == RBM_NONE);
108-
#endif
105+
if (value.getLow() != RBM_NONE)
106+
{
107+
return regMaskTP(genFindLowestBit(value.getLow()));
108+
}
109+
return regMaskTP(RBM_NONE, genFindLowestBit(value.getHigh()));
110+
#else
109111
return regMaskTP(genFindLowestBit(value.getLow()));
112+
#endif
110113
}
111114

112115
/*****************************************************************************
@@ -117,11 +120,18 @@ inline regMaskTP genFindLowestBit(regMaskTP value)
117120
inline bool genMaxOneBit(regMaskTP value)
118121
{
119122
#ifdef HAS_MORE_THAN_64_REGISTERS
120-
// If we ever need to use this method for predicate
121-
// registers, then handle it.
122-
assert(value.getHigh() == RBM_NONE);
123-
#endif
123+
if (value.getLow() == RBM_NONE)
124+
{
125+
return genMaxOneBit(value.getHigh());
126+
}
127+
if (value.getHigh() == RBM_NONE)
128+
{
129+
return genMaxOneBit(value.getLow());
130+
}
131+
return false;
132+
#else
124133
return genMaxOneBit(value.getLow());
134+
#endif
125135
}
126136

127137
/*****************************************************************************

src/coreclr/jit/emitarm64.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -804,22 +804,6 @@ static bool isValidUimm_MultipleOf(ssize_t value)
804804
return isValidUimm<bits>(value / mod) && (value % mod == 0);
805805
}
806806

807-
// Returns true if 'value' is a legal signed immediate with 'bits' number of bits.
808-
template <const size_t bits>
809-
static bool isValidSimm(ssize_t value)
810-
{
811-
constexpr ssize_t max = 1 << (bits - 1);
812-
return (-max <= value) && (value < max);
813-
}
814-
815-
// Returns true if 'value' is a legal signed multiple of 'mod' immediate with 'bits' number of bits.
816-
template <const size_t bits, const ssize_t mod>
817-
static bool isValidSimm_MultipleOf(ssize_t value)
818-
{
819-
static_assert(mod != 0);
820-
return isValidSimm<bits>(value / mod) && (value % mod == 0);
821-
}
822-
823807
// Returns true if 'imm' is a valid broadcast immediate for some SVE DUP variants
824808
static bool isValidBroadcastImm(ssize_t imm, emitAttr laneSize)
825809
{
@@ -1085,6 +1069,22 @@ static bool canEncodeByteShiftedImm(INT64 imm, emitAttr size, bool allow_MSL, em
10851069
// true if 'immDbl' can be encoded using a 'float immediate', also returns the encoding if wbFPI is non-null
10861070
static bool canEncodeFloatImm8(double immDbl, emitter::floatImm8* wbFPI = nullptr);
10871071

1072+
// Returns true if 'value' is a legal signed immediate with 'bits' number of bits.
1073+
template <const size_t bits>
1074+
static bool isValidSimm(ssize_t value)
1075+
{
1076+
constexpr ssize_t max = 1 << (bits - 1);
1077+
return (-max <= value) && (value < max);
1078+
}
1079+
1080+
// Returns true if 'value' is a legal signed multiple of 'mod' immediate with 'bits' number of bits.
1081+
template <const size_t bits, const ssize_t mod>
1082+
static bool isValidSimm_MultipleOf(ssize_t value)
1083+
{
1084+
static_assert(mod != 0);
1085+
return isValidSimm<bits>(value / mod) && (value % mod == 0);
1086+
}
1087+
10881088
// Returns the number of bits used by the given 'size'.
10891089
inline static unsigned getBitWidth(emitAttr size)
10901090
{

0 commit comments

Comments
 (0)