@@ -951,11 +951,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
951
951
952
952
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
953
953
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
954
- static SDValue combineShiftToAVG (SDValue Op, SelectionDAG &DAG,
954
+ static SDValue combineShiftToAVG (SDValue Op,
955
+ TargetLowering::TargetLoweringOpt &TLO,
955
956
const TargetLowering &TLI,
956
957
const APInt &DemandedBits,
957
- const APInt &DemandedElts,
958
- unsigned Depth) {
958
+ const APInt &DemandedElts, unsigned Depth) {
959
959
assert ((Op.getOpcode () == ISD::SRL || Op.getOpcode () == ISD::SRA) &&
960
960
" SRL or SRA node is required here!" );
961
961
// Is the right shift using an immediate value of 1?
@@ -1006,6 +1006,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1006
1006
// If the shift is unsigned (srl):
1007
1007
// - Needs >= 1 zero bit for both operands.
1008
1008
// - Needs 1 demanded bit zero and >= 2 sign bits.
1009
+ SelectionDAG &DAG = TLO.DAG ;
1009
1010
unsigned ShiftOpc = Op.getOpcode ();
1010
1011
bool IsSigned = false ;
1011
1012
unsigned KnownBits;
@@ -1061,10 +1062,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1061
1062
EVT NVT = EVT::getIntegerVT (*DAG.getContext (), llvm::bit_ceil (MinWidth));
1062
1063
if (VT.isVector ())
1063
1064
NVT = EVT::getVectorVT (*DAG.getContext (), NVT, VT.getVectorElementCount ());
1064
- if (!TLI.isOperationLegalOrCustom (AVGOpc, NVT)) {
1065
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, NVT)) {
1065
1066
// If we could not transform, and (both) adds are nuw/nsw, we can use the
1066
1067
// larger type size to do the transform.
1067
- if (!TLI.isOperationLegalOrCustom (AVGOpc, VT))
1068
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, VT))
1068
1069
return SDValue ();
1069
1070
if (DAG.willNotOverflowAdd (IsSigned, Add.getOperand (0 ),
1070
1071
Add.getOperand (1 )) &&
@@ -2015,7 +2016,7 @@ bool TargetLowering::SimplifyDemandedBits(
2015
2016
}
2016
2017
2017
2018
// Try to match AVG patterns (after shift simplification).
2018
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
2019
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2019
2020
DemandedElts, Depth + 1 ))
2020
2021
return TLO.CombineTo (Op, AVG);
2021
2022
@@ -2127,7 +2128,7 @@ bool TargetLowering::SimplifyDemandedBits(
2127
2128
}
2128
2129
2129
2130
// Try to match AVG patterns (after shift simplification).
2130
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
2131
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2131
2132
DemandedElts, Depth + 1 ))
2132
2133
return TLO.CombineTo (Op, AVG);
2133
2134
@@ -9245,6 +9246,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9245
9246
DAG.getNode (ISD::SUB, dl, VT, RHS, LHS));
9246
9247
}
9247
9248
9249
+ SDValue TargetLowering::expandAVG (SDNode *N, SelectionDAG &DAG) const {
9250
+ SDLoc dl (N);
9251
+ EVT VT = N->getValueType (0 );
9252
+ SDValue LHS = N->getOperand (0 );
9253
+ SDValue RHS = N->getOperand (1 );
9254
+
9255
+ unsigned Opc = N->getOpcode ();
9256
+ bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9257
+ bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9258
+ unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9259
+ assert ((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9260
+ Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9261
+ " Unknown AVG node" );
9262
+
9263
+ // If the operands are already extended, we can add+shift.
9264
+ bool IsExt =
9265
+ (IsSigned && DAG.ComputeNumSignBits (LHS) >= 2 &&
9266
+ DAG.ComputeNumSignBits (RHS) >= 2 ) ||
9267
+ (!IsSigned && DAG.computeKnownBits (LHS).countMinLeadingZeros () >= 1 &&
9268
+ DAG.computeKnownBits (RHS).countMinLeadingZeros () >= 1 );
9269
+ if (IsExt) {
9270
+ SDValue Sum = DAG.getNode (ISD::ADD, dl, VT, LHS, RHS);
9271
+ if (!IsFloor)
9272
+ Sum = DAG.getNode (ISD::ADD, dl, VT, Sum, DAG.getConstant (1 , dl, VT));
9273
+ return DAG.getNode (ShiftOpc, dl, VT, Sum,
9274
+ DAG.getShiftAmountConstant (1 , VT, dl));
9275
+ }
9276
+
9277
+ // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9278
+ // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9279
+ // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9280
+ // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9281
+ unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9282
+ unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9283
+ LHS = DAG.getFreeze (LHS);
9284
+ RHS = DAG.getFreeze (RHS);
9285
+ SDValue Sign = DAG.getNode (SignOpc, dl, VT, LHS, RHS);
9286
+ SDValue Xor = DAG.getNode (ISD::XOR, dl, VT, LHS, RHS);
9287
+ SDValue Shift =
9288
+ DAG.getNode (ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant (1 , VT, dl));
9289
+ return DAG.getNode (SumOpc, dl, VT, Sign, Shift);
9290
+ }
9291
+
9248
9292
SDValue TargetLowering::expandBSWAP (SDNode *N, SelectionDAG &DAG) const {
9249
9293
SDLoc dl (N);
9250
9294
EVT VT = N->getValueType (0 );
0 commit comments