Skip to content

Commit 18bb175

Browse files
committed
[AArch64] Add costs for LD3/LD4 shuffles.
Similar to #87934, this adds costs to the shuffles in a canonical LD3/LD4 pattern, which are represented in LLVM as deinterleaving-shuffle(load). This likely has less effect at the moment than the ST3/ST4 costs as instcombine will perform certain transforms without considering the cost.
1 parent 8c0341d commit 18bb175

File tree

6 files changed

+168
-146
lines changed

6 files changed

+168
-146
lines changed

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
13761376

13771377
return TargetTTI->getShuffleCost(
13781378
IsUnary ? TTI::SK_PermuteSingleSrc : TTI::SK_PermuteTwoSrc, VecTy,
1379-
AdjustMask, CostKind, 0, nullptr, {}, Shuffle);
1379+
AdjustMask, CostKind, 0, nullptr, Operands, Shuffle);
13801380
}
13811381

13821382
// Narrowing shuffle - perform shuffle at original wider width and
@@ -1385,7 +1385,7 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
13851385

13861386
InstructionCost ShuffleCost = TargetTTI->getShuffleCost(
13871387
IsUnary ? TTI::SK_PermuteSingleSrc : TTI::SK_PermuteTwoSrc,
1388-
VecSrcTy, AdjustMask, CostKind, 0, nullptr, {}, Shuffle);
1388+
VecSrcTy, AdjustMask, CostKind, 0, nullptr, Operands, Shuffle);
13891389

13901390
SmallVector<int, 16> ExtractMask(Mask.size());
13911391
std::iota(ExtractMask.begin(), ExtractMask.end(), 0);

llvm/include/llvm/IR/Instructions.h

+10
Original file line numberDiff line numberDiff line change
@@ -2631,6 +2631,16 @@ class ShuffleVectorInst : public Instruction {
26312631
return isInterleaveMask(Mask, Factor, NumInputElts, StartIndexes);
26322632
}
26332633

2634+
/// Check if the mask is a DE-interleave mask of the given factor
2635+
/// \p Factor like:
2636+
/// <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
2637+
static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
2638+
unsigned &Index);
2639+
static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor) {
2640+
unsigned Unused;
2641+
return isDeInterleaveMaskOfFactor(Mask, Factor, Unused);
2642+
}
2643+
26342644
/// Checks if the shuffle is a bit rotation of the first operand across
26352645
/// multiple subelements, e.g:
26362646
///

llvm/lib/CodeGen/InterleavedAccessPass.cpp

+5-27
Original file line numberDiff line numberDiff line change
@@ -200,28 +200,6 @@ FunctionPass *llvm::createInterleavedAccessPass() {
200200
return new InterleavedAccess();
201201
}
202202

203-
/// Check if the mask is a DE-interleave mask of the given factor
204-
/// \p Factor like:
205-
/// <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
206-
static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
207-
unsigned &Index) {
208-
// Check all potential start indices from 0 to (Factor - 1).
209-
for (Index = 0; Index < Factor; Index++) {
210-
unsigned i = 0;
211-
212-
// Check that elements are in ascending order by Factor. Ignore undef
213-
// elements.
214-
for (; i < Mask.size(); i++)
215-
if (Mask[i] >= 0 && static_cast<unsigned>(Mask[i]) != Index + i * Factor)
216-
break;
217-
218-
if (i == Mask.size())
219-
return true;
220-
}
221-
222-
return false;
223-
}
224-
225203
/// Check if the mask is a DE-interleave mask for an interleaved load.
226204
///
227205
/// E.g. DE-interleave masks (Factor = 2) could be:
@@ -238,7 +216,7 @@ static bool isDeInterleaveMask(ArrayRef<int> Mask, unsigned &Factor,
238216
// Make sure we don't produce a load wider than the input load.
239217
if (Mask.size() * Factor > NumLoadElements)
240218
return false;
241-
if (isDeInterleaveMaskOfFactor(Mask, Factor, Index))
219+
if (ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, Factor, Index))
242220
return true;
243221
}
244222

@@ -333,8 +311,8 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
333311
for (auto *Shuffle : Shuffles) {
334312
if (Shuffle->getType() != VecTy)
335313
return false;
336-
if (!isDeInterleaveMaskOfFactor(Shuffle->getShuffleMask(), Factor,
337-
Index))
314+
if (!ShuffleVectorInst::isDeInterleaveMaskOfFactor(
315+
Shuffle->getShuffleMask(), Factor, Index))
338316
return false;
339317

340318
assert(Shuffle->getShuffleMask().size() <= NumLoadElements);
@@ -343,8 +321,8 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
343321
for (auto *Shuffle : BinOpShuffles) {
344322
if (Shuffle->getType() != VecTy)
345323
return false;
346-
if (!isDeInterleaveMaskOfFactor(Shuffle->getShuffleMask(), Factor,
347-
Index))
324+
if (!ShuffleVectorInst::isDeInterleaveMaskOfFactor(
325+
Shuffle->getShuffleMask(), Factor, Index))
348326
return false;
349327

350328
assert(Shuffle->getShuffleMask().size() <= NumLoadElements);

llvm/lib/IR/Instructions.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -2978,6 +2978,31 @@ bool ShuffleVectorInst::isInterleaveMask(
29782978
return true;
29792979
}
29802980

2981+
/// Check if the mask is a DE-interleave mask of the given factor
2982+
/// \p Factor like:
2983+
/// <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
2984+
bool ShuffleVectorInst::isDeInterleaveMaskOfFactor(ArrayRef<int> Mask,
2985+
unsigned Factor,
2986+
unsigned &Index) {
2987+
// Check all potential start indices from 0 to (Factor - 1).
2988+
for (unsigned Idx = 0; Idx < Factor; Idx++) {
2989+
unsigned I = 0;
2990+
2991+
// Check that elements are in ascending order by Factor. Ignore undef
2992+
// elements.
2993+
for (; I < Mask.size(); I++)
2994+
if (Mask[I] >= 0 && static_cast<unsigned>(Mask[I]) != Idx + I * Factor)
2995+
break;
2996+
2997+
if (I == Mask.size()) {
2998+
Index = Idx;
2999+
return true;
3000+
}
3001+
}
3002+
3003+
return false;
3004+
}
3005+
29813006
/// Try to lower a vector shuffle as a bit rotation.
29823007
///
29833008
/// Look for a repeated rotation pattern in each sub group.

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -3827,9 +3827,18 @@ InstructionCost AArch64TTIImpl::getShuffleCost(
38273827
Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
38283828
Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
38293829

3830+
// Check for LD3/LD4 instructions, which are represented in llvm IR as
3831+
// deinterleaving-shuffle(load). The shuffle cost could potentially be free,
3832+
// but we model it with a cost of LT.first so that LD3/LD4 have a higher
3833+
// cost than just the load.
3834+
if (Args.size() >= 1 && isa<LoadInst>(Args[0]) &&
3835+
(ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 3) ||
3836+
ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 4)))
3837+
return std::max<InstructionCost>(1, LT.first / 4);
3838+
38303839
// Check for ST3/ST4 instructions, which are represented in llvm IR as
38313840
// store(interleaving-shuffle). The shuffle cost could potentially be free,
3832-
// but we model it with a cost of LT.first so that LD3/LD3 have a higher
3841+
// but we model it with a cost of LT.first so that ST3/ST4 have a higher
38333842
// cost than just the store.
38343843
if (CxtI && CxtI->hasOneUse() && isa<StoreInst>(*CxtI->user_begin()) &&
38353844
(ShuffleVectorInst::isInterleaveMask(

0 commit comments

Comments
 (0)