@@ -67,6 +67,9 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
67
67
static std::optional<APInt> extractConstantBits (const Constant *C) {
68
68
unsigned NumBits = C->getType ()->getPrimitiveSizeInBits ();
69
69
70
+ if (auto *CUndef = dyn_cast<UndefValue>(C))
71
+ return APInt::getZero (NumBits);
72
+
70
73
if (auto *CInt = dyn_cast<ConstantInt>(C))
71
74
return CInt->getValue ();
72
75
@@ -80,6 +83,18 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
80
83
return APInt::getSplat (NumBits, *Bits);
81
84
}
82
85
}
86
+
87
+ APInt Bits = APInt::getZero (NumBits);
88
+ for (unsigned I = 0 , E = CV->getNumOperands (); I != E; ++I) {
89
+ Constant *Elt = CV->getOperand (I);
90
+ std::optional<APInt> SubBits = extractConstantBits (Elt);
91
+ if (!SubBits)
92
+ return std::nullopt ;
93
+ assert (NumBits == (E * SubBits->getBitWidth ()) &&
94
+ " Illegal vector element size" );
95
+ Bits.insertBits (*SubBits, I * SubBits->getBitWidth ());
96
+ }
97
+ return Bits;
83
98
}
84
99
85
100
if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
@@ -223,6 +238,33 @@ static Constant *rebuildSplatableConstant(const Constant *C,
223
238
return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
224
239
}
225
240
241
+ static Constant *rebuildZeroUpperConstant (const Constant *C,
242
+ unsigned ScalarBitWidth) {
243
+ Type *Ty = C->getType ();
244
+ Type *SclTy = Ty->getScalarType ();
245
+ unsigned NumBits = Ty->getPrimitiveSizeInBits ();
246
+ unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
247
+ LLVMContext &Ctx = C->getContext ();
248
+
249
+ if (NumBits > ScalarBitWidth) {
250
+ // Determine if the upper bits are all zero.
251
+ if (std::optional<APInt> Bits = extractConstantBits (C)) {
252
+ if (Bits->countLeadingZeros () >= (NumBits - ScalarBitWidth)) {
253
+ // If the original constant was made of smaller elements, try to retain
254
+ // those types.
255
+ if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0 )
256
+ return rebuildConstant (Ctx, SclTy, *Bits, NumSclBits);
257
+
258
+ // Fallback to raw integer bits.
259
+ APInt RawBits = Bits->zextOrTrunc (ScalarBitWidth);
260
+ return ConstantInt::get (Ctx, RawBits);
261
+ }
262
+ }
263
+ }
264
+
265
+ return nullptr ;
266
+ }
267
+
226
268
bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
227
269
MachineBasicBlock &MBB,
228
270
MachineInstr &MI) {
@@ -263,6 +305,34 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
263
305
return false ;
264
306
};
265
307
308
+ auto ConvertToZeroUpper = [&](unsigned OpUpper64, unsigned OpUpper32) {
309
+ unsigned OperandNo = 1 ;
310
+ assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
311
+ " Unexpected number of operands!" );
312
+
313
+ if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
314
+ // Attempt to detect a suitable splat from increasing splat widths.
315
+ std::pair<unsigned , unsigned > ZeroUppers[] = {
316
+ {32 , OpUpper32},
317
+ {64 , OpUpper64},
318
+ };
319
+ for (auto [BitWidth, OpUpper] : ZeroUppers) {
320
+ if (OpUpper) {
321
+ // Construct a suitable splat constant and adjust the MI to
322
+ // use the new constant pool entry.
323
+ if (Constant *NewCst = rebuildZeroUpperConstant (C, BitWidth)) {
324
+ unsigned NewCPI =
325
+ CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
326
+ MI.setDesc (TII->get (OpUpper));
327
+ MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
328
+ return true ;
329
+ }
330
+ }
331
+ }
332
+ }
333
+ return false ;
334
+ };
335
+
266
336
// Attempt to convert full width vector loads into broadcast loads.
267
337
switch (Opc) {
268
338
/* FP Loads */
@@ -271,12 +341,13 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
271
341
case X86::MOVUPDrm:
272
342
case X86::MOVUPSrm:
273
343
// TODO: SSE3 MOVDDUP Handling
274
- return false ;
344
+ return ConvertToZeroUpper (X86::MOVSDrm, X86::MOVSSrm) ;
275
345
case X86::VMOVAPDrm:
276
346
case X86::VMOVAPSrm:
277
347
case X86::VMOVUPDrm:
278
348
case X86::VMOVUPSrm:
279
- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
349
+ return ConvertToZeroUpper (X86::VMOVSDrm, X86::VMOVSSrm) ||
350
+ ConvertToBroadcast (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
280
351
1 );
281
352
case X86::VMOVAPDYrm:
282
353
case X86::VMOVAPSYrm:
@@ -288,7 +359,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
288
359
case X86::VMOVAPSZ128rm:
289
360
case X86::VMOVUPDZ128rm:
290
361
case X86::VMOVUPSZ128rm:
291
- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPZ128rm,
362
+ return ConvertToZeroUpper (X86::VMOVSDZrm, X86::VMOVSSZrm) ||
363
+ ConvertToBroadcast (0 , 0 , X86::VMOVDDUPZ128rm,
292
364
X86::VBROADCASTSSZ128rm, 0 , 0 , 1 );
293
365
case X86::VMOVAPDZ256rm:
294
366
case X86::VMOVAPSZ256rm:
@@ -305,13 +377,17 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
305
377
X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 ,
306
378
1 );
307
379
/* Integer Loads */
380
+ case X86::MOVDQArm:
381
+ case X86::MOVDQUrm:
382
+ return ConvertToZeroUpper (X86::MOVQI2PQIrm, X86::MOVDI2PDIrm);
308
383
case X86::VMOVDQArm:
309
384
case X86::VMOVDQUrm:
310
- return ConvertToBroadcast (
311
- 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
312
- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
313
- HasAVX2 ? X86::VPBROADCASTWrm : 0 , HasAVX2 ? X86::VPBROADCASTBrm : 0 ,
314
- 1 );
385
+ return ConvertToZeroUpper (X86::VMOVQI2PQIrm, X86::VMOVDI2PDIrm) ||
386
+ ConvertToBroadcast (
387
+ 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
388
+ HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
389
+ HasAVX2 ? X86::VPBROADCASTWrm : 0 ,
390
+ HasAVX2 ? X86::VPBROADCASTBrm : 0 , 1 );
315
391
case X86::VMOVDQAYrm:
316
392
case X86::VMOVDQUYrm:
317
393
return ConvertToBroadcast (
@@ -324,7 +400,8 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
324
400
case X86::VMOVDQA64Z128rm:
325
401
case X86::VMOVDQU32Z128rm:
326
402
case X86::VMOVDQU64Z128rm:
327
- return ConvertToBroadcast (0 , 0 , X86::VPBROADCASTQZ128rm,
403
+ return ConvertToZeroUpper (X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm) ||
404
+ ConvertToBroadcast (0 , 0 , X86::VPBROADCASTQZ128rm,
328
405
X86::VPBROADCASTDZ128rm,
329
406
HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
330
407
HasBWI ? X86::VPBROADCASTBZ128rm : 0 , 1 );
0 commit comments