Skip to content

[LV] Add support for partial reductions without a binary op #133922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1299,9 +1299,21 @@ class TargetTransformInfo {
/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces two vectors to another of 4 times fewer and 4
/// takes an accumulator of type \p AccumType and a second vector operand to
/// be accumulated, whose element count is specified by \p VF. The type of
/// reduction is specified by \p Opcode. The second operand passed to the
/// intrinsic could be the result of an extend, such as sext or zext. In
/// this case \p BinOp is nullopt, \p InputTypeA represents the type being
/// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
/// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
/// Alternatively, the second operand could be the result of a binary
/// operation performed on two extends, i.e.
/// mul(zext i8 %a -> i32, zext i8 %b -> i32).
/// In this case \p BinOp may specify the opcode of the binary operation,
/// \p InputTypeA and \p InputTypeB the types being extended, and
/// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
/// operation that uses a partial reduction is a dot product, which reduces
/// two vectors in binary mul operation to another of 4 times fewer and 4
/// times larger elements.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Expand Down
24 changes: 12 additions & 12 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(

// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
OpAExtend == TTI::PR_None)
return Invalid;

if (InputTypeA != InputTypeB)
// We only support multiply binary operations for now, and for muls we
// require the types being extended to be the same.
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
// only if the i8mm or sve/streaming features are available.
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
OpBExtend == TTI::PR_None ||
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
!ST->isSVEorStreamingSVEAvailable())))
return Invalid;
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation for the interface should probably also be updated, documenting that Opcode and the second the second type are optional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, absolutely right! I've tried to amend the documentation for the interface. Please take a look and see if it makes sense.

"Unexpected values for OpBExtend or InputTypeB");

EVT InputEVT = EVT::getEVT(InputTypeA);
EVT AccumEVT = EVT::getEVT(AccumType);
Expand Down Expand Up @@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
} else
return Invalid;

// AArch64 supports lowering mixed extensions to a usdot but only if the
// i8mm or sve/streaming features are available.
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
!ST->isSVEorStreamingSVEAvailable()))
return Invalid;

if (!BinOp || *BinOp != Instruction::Mul)
return Invalid;

return Cost;
}

Expand Down
85 changes: 57 additions & 28 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// something that isn't another partial reduction. This is because the
// extends are intended to be lowered along with the reduction itself.

// Build up a set of partial reduction bin ops for efficient use checking.
SmallSet<User *, 4> PartialReductionBinOps;
// Build up a set of partial reduction ops for efficient use checking.
SmallSet<User *, 4> PartialReductionOps;
for (const auto &[PartialRdx, _] : PartialReductionChains)
PartialReductionBinOps.insert(PartialRdx.BinOp);
PartialReductionOps.insert(PartialRdx.ExtendUser);

auto ExtendIsOnlyUsedByPartialReductions =
[&PartialReductionBinOps](Instruction *Extend) {
[&PartialReductionOps](Instruction *Extend) {
return all_of(Extend->users(), [&](const User *U) {
return PartialReductionBinOps.contains(U);
return PartialReductionOps.contains(U);
});
};

Expand All @@ -8782,15 +8782,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
for (auto Pair : PartialReductionChains) {
PartialReductionChain Chain = Pair.first;
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
}
}

bool VPRecipeBuilder::getScaledReductions(
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {

if (!CM.TheLoop->contains(RdxExitInstr))
return false;

Expand Down Expand Up @@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
if (PhiOp != PHI)
return false;

auto *BinOp = dyn_cast<BinaryOperator>(Op);
if (!BinOp || !BinOp->hasOneUse())
return false;

using namespace llvm::PatternMatch;
// Use the side-effect of match to replace BinOp only if the pattern is
// matched, we don't care at this point whether it actually matched.
match(BinOp, m_Neg(m_BinOp(BinOp)));

Value *A, *B;
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
return false;
// If the update is a binary operator, check both of its operands to see if
// they are extends. Otherwise, see if the update comes directly from an
// extend.
Instruction *Exts[2] = {nullptr};
BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
std::optional<unsigned> BinOpc;
Type *ExtOpTypes[2] = {nullptr};

auto CollectExtInfo = [&Exts,
&ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
unsigned I = 0;
for (Value *OpI : Ops) {
Value *ExtOp;
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
return false;
Exts[I] = cast<Instruction>(OpI);
ExtOpTypes[I] = ExtOp->getType();
I++;
}
return true;
};

if (ExtendUser) {
if (!ExtendUser->hasOneUse())
return false;

// Use the side-effect of match to replace BinOp only if the pattern is
// matched, we don't care at this point whether it actually matched.
match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));

Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
SmallVector<Value *> Ops(ExtendUser->operands());
if (!CollectExtInfo(Ops))
return false;

BinOpc = std::make_optional(ExtendUser->getOpcode());
} else if (match(Update, m_Add(m_Value(), m_Value()))) {
// We already know the operands for Update are Op and PhiOp.
SmallVector<Value *> Ops({Op});
if (!CollectExtInfo(Ops))
return false;

ExtendUser = Update;
BinOpc = std::nullopt;
} else
return false;

TTI::PartialReductionExtendKind OpAExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
TTI::PartialReductionExtendKind OpBExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtB);

PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
: TargetTransformInfo::PR_None;
PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);

unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
A->getType()->getPrimitiveSizeInBits());
ExtOpTypes[0]->getPrimitiveSizeInBits());

if (LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
VF, OpAExtend, OpBExtend,
std::make_optional(BinOp->getOpcode()));
Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
return Cost.isValid();
},
Range)) {
Expand Down
15 changes: 8 additions & 7 deletions llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,23 @@ struct HistogramInfo;
struct VFRange;

/// A chain of instructions that form a partial reduction.
/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
/// accumulator).
/// Designed to match either:
/// reduction_bin_op (extend (A), accumulator), or
/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
struct PartialReductionChain {
PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
Instruction *ExtendB, Instruction *BinOp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment above needs updating

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
}
Instruction *ExtendB, Instruction *ExtendUser)
: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
ExtendUser(ExtendUser) {}
/// The top-level binary operation that forms the reduction to a scalar
/// after the loop body.
Instruction *Reduction;
/// The extension of each of the inner binary operation's operands.
Instruction *ExtendA;
Instruction *ExtendB;

/// The binary operation using the extends that is then reduced.
Instruction *BinOp;
/// The user of the extend that is then reduced.
Instruction *ExtendUser;
};

/// Helper class to create VPRecipies from IR instructions.
Expand Down
64 changes: 40 additions & 24 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode = std::nullopt;
VPValue *BinOp = getOperand(0);
// If the input operand is an extend then use the opcode for this recipe.
std::optional<unsigned> Opcode;
VPValue *Op = getOperand(0);
VPRecipeBase *OpR = Op->getDefiningRecipe();

// If the partial reduction is predicated, a select will be operand 0 rather
// than the binary op
// than the extend user.
using namespace llvm::VPlanPatternMatch;
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
BinOp = BinOp->getDefiningRecipe()->getOperand(1);

// If BinOp is a negation, use the side effect of match to assign the actual
// binary operation to BinOp
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();

if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
Opcode = std::make_optional(WidenR->getOpcode());

VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();

auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
: BinOpR->getOperand(0));
auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
: BinOpR->getOperand(1));
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
Op = OpR->getOperand(1);
OpR = Op->getDefiningRecipe();
}

auto GetExtendKind = [](VPRecipeBase *R) {
// The extend could come from outside the plan.
Expand All @@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
return TargetTransformInfo::PR_None;
};

Type *InputTypeA, *InputTypeB;
TTI::PartialReductionExtendKind ExtAType, ExtBType;

// The input may come straight from a zext or sext.
if (isa<VPWidenCastRecipe>(OpR)) {
Opcode = std::nullopt;
InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
InputTypeB = nullptr;
ExtAType = GetExtendKind(OpR);
ExtBType = TargetTransformInfo::PR_None;
} else {
// If BinOp is a negation, use the side effect of match to assign the actual
// binary operation to BinOp
match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
OpR = Op->getDefiningRecipe();
Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());

VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();

InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
: OpR->getOperand(0));
InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
: OpR->getOperand(1));
ExtAType = GetExtendKind(ExtAR);
ExtBType = GetExtendKind(ExtBR);
}

auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
PhiType, VF, GetExtendKind(ExtAR),
GetExtendKind(ExtBR), Opcode);
PhiType, VF, ExtAType, ExtBType,
Opcode);
}

void VPPartialReductionRecipe::execute(VPTransformState &State) {
Expand Down
Loading
Loading