@@ -197,6 +197,24 @@ struct VectorLayout {
197
197
uint64_t SplitSize = 0 ;
198
198
};
199
199
200
+ static bool isStructOfMatchingFixedVectors (Type *Ty) {
201
+ if (!isa<StructType>(Ty))
202
+ return false ;
203
+ unsigned StructSize = Ty->getNumContainedTypes ();
204
+ if (StructSize < 1 )
205
+ return false ;
206
+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (0 ));
207
+ if (!VecTy)
208
+ return false ;
209
+ unsigned VecSize = VecTy->getNumElements ();
210
+ for (unsigned I = 1 ; I < StructSize; I++) {
211
+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (I));
212
+ if (!VecTy || VecSize != VecTy->getNumElements ())
213
+ return false ;
214
+ }
215
+ return true ;
216
+ }
217
+
200
218
// / Concatenate the given fragments to a single vector value of the type
201
219
// / described in @p VS.
202
220
static Value *concatenate (IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +294,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276
294
bool visitBitCastInst (BitCastInst &BCI);
277
295
bool visitInsertElementInst (InsertElementInst &IEI);
278
296
bool visitExtractElementInst (ExtractElementInst &EEI);
297
+ bool visitExtractValueInst (ExtractValueInst &EVI);
279
298
bool visitShuffleVectorInst (ShuffleVectorInst &SVI);
280
299
bool visitPHINode (PHINode &PHI);
281
300
bool visitLoadInst (LoadInst &LI);
@@ -667,14 +686,26 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
667
686
/// Return true if intrinsic \p ID is accepted for scalarization: it is either
/// trivially vectorizable, is frexp, or is a target intrinsic that TTI reports
/// as trivially scalarizable.
bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
  // TODO: Move frexp to isTriviallyVectorizable.
  // https://github.com/llvm/llvm-project/issues/112408
  const bool IsGenericallyScalarizable =
      isTriviallyVectorizable(ID) || ID == Intrinsic::frexp;
  if (IsGenericallyScalarizable)
    return true;

  // Everything else must be a target intrinsic the target opts in for.
  if (!Intrinsic::isTargetIntrinsic(ID))
    return false;
  return TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}
673
698
674
699
// / If a call to a vector typed intrinsic function, split into a scalar call per
675
700
// / element if possible for the intrinsic.
676
701
bool ScalarizerVisitor::splitCall (CallInst &CI) {
677
- std::optional<VectorSplit> VS = getVectorSplit (CI.getType ());
702
+ Type *CallType = CI.getType ();
703
+ bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors (CallType);
704
+ std::optional<VectorSplit> VS;
705
+ if (AreAllVectorsOfMatchingSize)
706
+ VS = getVectorSplit (CallType->getContainedType (0 ));
707
+ else
708
+ VS = getVectorSplit (CallType);
678
709
if (!VS)
679
710
return false ;
680
711
@@ -699,6 +730,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
699
730
if (isVectorIntrinsicWithOverloadTypeAtArg (ID, -1 ))
700
731
Tys.push_back (VS->SplitTy );
701
732
733
+ if (AreAllVectorsOfMatchingSize) {
734
+ for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
735
+ std::optional<VectorSplit> CurrVS =
736
+ getVectorSplit (cast<FixedVectorType>(CallType->getContainedType (I)));
737
+ // This case does not seem to happen, but it is possible for
738
+ // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit
739
+ // is not returned and we will bailout of handling this call.
740
+ // The secondary bailout case is if NumPacked does not match.
741
+ // This can happen if ScalarizeMinBits is not set to the default.
742
+ // This means with certain ScalarizeMinBits intrinsics like frexp
743
+ // will only scalarize when the struct elements have the same bitness.
744
+ if (!CurrVS || CurrVS->NumPacked != VS->NumPacked )
745
+ return false ;
746
+ if (isVectorIntrinsicWithStructReturnOverloadAtField (ID, I))
747
+ Tys.push_back (CurrVS->SplitTy );
748
+ }
749
+ }
702
750
// Assumes that any vector type has the same number of elements as the return
703
751
// vector type, which is true for all current intrinsics.
704
752
for (unsigned I = 0 ; I != NumArgs; ++I) {
@@ -1030,6 +1078,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
1030
1078
return true ;
1031
1079
}
1032
1080
1081
+ bool ScalarizerVisitor::visitExtractValueInst (ExtractValueInst &EVI) {
1082
+ Value *Op = EVI.getOperand (0 );
1083
+ Type *OpTy = Op->getType ();
1084
+ ValueVector Res;
1085
+ if (!isStructOfMatchingFixedVectors (OpTy))
1086
+ return false ;
1087
+ Type *VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
1088
+ std::optional<VectorSplit> VS = getVectorSplit (VecType);
1089
+ if (!VS)
1090
+ return false ;
1091
+ IRBuilder<> Builder (&EVI);
1092
+ Scatterer Op0 = scatter (&EVI, Op, *VS);
1093
+ assert (!EVI.getIndices ().empty () && " Make sure an index exists" );
1094
+ // Note for our use case we only care about the top level index.
1095
+ unsigned Index = EVI.getIndices ()[0 ];
1096
+ for (unsigned OpIdx = 0 ; OpIdx < Op0.size (); ++OpIdx) {
1097
+ Value *ResElem = Builder.CreateExtractValue (
1098
+ Op0[OpIdx], Index, EVI.getName () + " .elem" + Twine (Index));
1099
+ Res.push_back (ResElem);
1100
+ }
1101
+
1102
+ gather (&EVI, Res, *VS);
1103
+ return true ;
1104
+ }
1105
+
1033
1106
bool ScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
1034
1107
std::optional<VectorSplit> VS = getVectorSplit (EEI.getOperand (0 )->getType ());
1035
1108
if (!VS)
@@ -1209,6 +1282,35 @@ bool ScalarizerVisitor::finish() {
1209
1282
Res = concatenate (Builder, CV, VS, Op->getName ());
1210
1283
1211
1284
Res->takeName (Op);
1285
+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType ())) {
1286
+ BasicBlock *BB = Op->getParent ();
1287
+ IRBuilder<> Builder (Op);
1288
+ if (isa<PHINode>(Op))
1289
+ Builder.SetInsertPoint (BB, BB->getFirstInsertionPt ());
1290
+
1291
+ // Iterate over each element in the struct
1292
+ unsigned NumOfStructElements = Ty->getNumElements ();
1293
+ SmallVector<ValueVector, 4 > ElemCV (NumOfStructElements);
1294
+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1295
+ for (auto *CVelem : CV) {
1296
+ Value *Elem = Builder.CreateExtractValue (
1297
+ CVelem, I, Op->getName () + " .elem" + Twine (I));
1298
+ ElemCV[I].push_back (Elem);
1299
+ }
1300
+ }
1301
+ Res = PoisonValue::get (Ty);
1302
+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1303
+ Type *ElemTy = Ty->getElementType (I);
1304
+ assert (isa<FixedVectorType>(ElemTy) &&
1305
+ " Only Structs of all FixedVectorType supported" );
1306
+ VectorSplit VS = *getVectorSplit (ElemTy);
1307
+ assert (VS.NumFragments == CV.size ());
1308
+
1309
+ Value *ConcatenatedVector =
1310
+ concatenate (Builder, ElemCV[I], VS, Op->getName ());
1311
+ Res = Builder.CreateInsertValue (Res, ConcatenatedVector, I,
1312
+ Op->getName () + " .insert" );
1313
+ }
1212
1314
} else {
1213
1315
assert (CV.size () == 1 && Op->getType () == CV[0 ]->getType ());
1214
1316
Res = CV[0 ];
0 commit comments