@@ -951,11 +951,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
951
951
952
952
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
953
953
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
954
- static SDValue combineShiftToAVG (SDValue Op, SelectionDAG &DAG,
954
+ static SDValue combineShiftToAVG (SDValue Op,
955
+ TargetLowering::TargetLoweringOpt &TLO,
955
956
const TargetLowering &TLI,
956
957
const APInt &DemandedBits,
957
- const APInt &DemandedElts,
958
- unsigned Depth) {
958
+ const APInt &DemandedElts, unsigned Depth) {
959
959
assert ((Op.getOpcode () == ISD::SRL || Op.getOpcode () == ISD::SRA) &&
960
960
" SRL or SRA node is required here!" );
961
961
// Is the right shift using an immediate value of 1?
@@ -1006,6 +1006,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1006
1006
// If the shift is unsigned (srl):
1007
1007
// - Needs >= 1 zero bit for both operands.
1008
1008
// - Needs 1 demanded bit zero and >= 2 sign bits.
1009
+ SelectionDAG &DAG = TLO.DAG ;
1009
1010
unsigned ShiftOpc = Op.getOpcode ();
1010
1011
bool IsSigned = false ;
1011
1012
unsigned KnownBits;
@@ -1061,10 +1062,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1061
1062
EVT NVT = EVT::getIntegerVT (*DAG.getContext (), llvm::bit_ceil (MinWidth));
1062
1063
if (VT.isVector ())
1063
1064
NVT = EVT::getVectorVT (*DAG.getContext (), NVT, VT.getVectorElementCount ());
1064
- if (!TLI.isOperationLegalOrCustom (AVGOpc, NVT)) {
1065
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, NVT)) {
1065
1066
// If we could not transform, and (both) adds are nuw/nsw, we can use the
1066
1067
// larger type size to do the transform.
1067
- if (!TLI.isOperationLegalOrCustom (AVGOpc, VT))
1068
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, VT))
1068
1069
return SDValue ();
1069
1070
if (DAG.willNotOverflowAdd (IsSigned, Add.getOperand (0 ),
1070
1071
Add.getOperand (1 )) &&
@@ -2002,7 +2003,7 @@ bool TargetLowering::SimplifyDemandedBits(
2002
2003
}
2003
2004
2004
2005
// Try to match AVG patterns (after shift simplification).
2005
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
2006
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2006
2007
DemandedElts, Depth + 1 ))
2007
2008
return TLO.CombineTo (Op, AVG);
2008
2009
@@ -2113,7 +2114,7 @@ bool TargetLowering::SimplifyDemandedBits(
2113
2114
}
2114
2115
2115
2116
// Try to match AVG patterns (after shift simplification).
2116
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
2117
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2117
2118
DemandedElts, Depth + 1 ))
2118
2119
return TLO.CombineTo (Op, AVG);
2119
2120
@@ -9225,6 +9226,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9225
9226
DAG.getNode (ISD::SUB, dl, VT, RHS, LHS));
9226
9227
}
9227
9228
9229
+ SDValue TargetLowering::expandAVG (SDNode *N, SelectionDAG &DAG) const {
9230
+ SDLoc dl (N);
9231
+ EVT VT = N->getValueType (0 );
9232
+ SDValue LHS = N->getOperand (0 );
9233
+ SDValue RHS = N->getOperand (1 );
9234
+
9235
+ unsigned Opc = N->getOpcode ();
9236
+ bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9237
+ bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9238
+ unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9239
+ assert ((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9240
+ Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9241
+ " Unknown AVG node" );
9242
+
9243
+ // If the operands are already extended, we can add+shift.
9244
+ bool IsExt =
9245
+ (IsSigned && DAG.ComputeNumSignBits (LHS) >= 2 &&
9246
+ DAG.ComputeNumSignBits (RHS) >= 2 ) ||
9247
+ (!IsSigned && DAG.computeKnownBits (LHS).countMinLeadingZeros () >= 1 &&
9248
+ DAG.computeKnownBits (RHS).countMinLeadingZeros () >= 1 );
9249
+ if (IsExt) {
9250
+ SDValue Sum = DAG.getNode (ISD::ADD, dl, VT, LHS, RHS);
9251
+ if (!IsFloor)
9252
+ Sum = DAG.getNode (ISD::ADD, dl, VT, Sum, DAG.getConstant (1 , dl, VT));
9253
+ return DAG.getNode (ShiftOpc, dl, VT, Sum,
9254
+ DAG.getShiftAmountConstant (1 , VT, dl));
9255
+ }
9256
+
9257
+ // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9258
+ // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9259
+ // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9260
+ // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9261
+ unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9262
+ unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9263
+ LHS = DAG.getFreeze (LHS);
9264
+ RHS = DAG.getFreeze (RHS);
9265
+ SDValue Sign = DAG.getNode (SignOpc, dl, VT, LHS, RHS);
9266
+ SDValue Xor = DAG.getNode (ISD::XOR, dl, VT, LHS, RHS);
9267
+ SDValue Shift =
9268
+ DAG.getNode (ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant (1 , VT, dl));
9269
+ return DAG.getNode (SumOpc, dl, VT, Sign, Shift);
9270
+ }
9271
+
9228
9272
SDValue TargetLowering::expandBSWAP (SDNode *N, SelectionDAG &DAG) const {
9229
9273
SDLoc dl (N);
9230
9274
EVT VT = N->getValueType (0 );
0 commit comments