Skip to content
Merged
107 changes: 77 additions & 30 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class VectorCombine {
bool foldInsExtFNeg(Instruction &I);
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
bool foldBitOpOfBitcasts(Instruction &I);
bool foldBitOpOfCastops(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
Expand Down Expand Up @@ -808,48 +808,87 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
return true;
}

bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
Value *LHSSrc, *RHSSrc;
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
m_BitCast(m_Value(RHSSrc)))))
/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
/// Supports: bitcast, trunc, sext, zext
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
// Check if this is a bitwise logic operation
auto *BinOp = dyn_cast<BinaryOperator>(&I);
if (!BinOp || !BinOp->isBitwiseLogicOp())
return false;

// Get the cast instructions
auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
if (!LHSCast || !RHSCast) {
LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
return false;
}

// Both casts must use the same cast opcode
Instruction::CastOps CastOpcode = LHSCast->getOpcode();
if (CastOpcode != RHSCast->getOpcode())
return false;

// Only handle supported cast operations
switch (CastOpcode) {
case Instruction::BitCast:
case Instruction::Trunc:
case Instruction::SExt:
case Instruction::ZExt:
break;
default:
return false;
}

Value *LHSSrc = LHSCast->getOperand(0);
Value *RHSSrc = RHSCast->getOperand(0);

// Source types must match
if (LHSSrc->getType() != RHSSrc->getType())
return false;
if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
return false;

// Only handle vector types
// Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;

// Same total bit width
assert(SrcVecTy->getPrimitiveSizeInBits() ==
DstVecTy->getPrimitiveSizeInBits() &&
"Bitcast should preserve total bit width");
if (!SrcVecTy->getScalarType()->isIntegerTy() ||
!DstVecTy->getScalarType()->isIntegerTy())
return false;

// Cost Check :
// OldCost = bitlogic + 2*bitcasts
// NewCost = bitlogic + bitcast
auto *BinOp = cast<BinaryOperator>(&I);
// OldCost = bitlogic + 2*casts
// NewCost = bitlogic + cast

// Calculate specific costs for each cast with instruction context
InstructionCost LHSCastCost =
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None, CostKind, LHSCast);
InstructionCost RHSCastCost =
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None, CostKind, RHSCast);

InstructionCost OldCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
TTI::CastContextHint::None) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
TTI::CastContextHint::None);
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy, CostKind) +
LHSCastCost + RHSCastCost;

// For the new cost there is no instruction to pass as context: the folded
// cast has not been created yet
InstructionCost GenericCastCost = TTI.getCastInstrCost(
CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);

InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy, CostKind) +
GenericCastCost;

LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
// A cast with additional users has to be kept after the fold, so charge its
// (instruction-specific) cost to NewCost as well
if (!LHSCast->hasOneUse())
NewCost += LHSCastCost;
if (!RHSCast->hasOneUse())
NewCost += RHSCastCost;

LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
<< " NewCost=" << NewCost << "\n");

if (NewCost > OldCost)
return false;
Expand All @@ -862,8 +901,16 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {

Worklist.pushValue(NewOp);

// Bitcast the result back
Value *Result = Builder.CreateBitCast(NewOp, I.getType());
// Create the cast operation directly to ensure we get a new instruction
Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());

// Preserve cast instruction flags
NewCast->copyIRFlags(LHSCast);
NewCast->andIRFlags(RHSCast);

// Insert the new instruction
Value *Result = Builder.Insert(NewCast);

replaceValue(I, *Result);
return true;
}
Expand Down Expand Up @@ -3773,7 +3820,7 @@ bool VectorCombine::run() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
MadeChange |= foldBitOpOfBitcasts(I);
MadeChange |= foldBitOpOfCastops(I);
break;
default:
MadeChange |= shrinkType(I);
Expand Down
Loading