@@ -947,11 +947,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
947
947
948
948
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
949
949
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
950
- static SDValue combineShiftToAVG (SDValue Op, SelectionDAG &DAG,
950
+ static SDValue combineShiftToAVG (SDValue Op,
951
+ TargetLowering::TargetLoweringOpt &TLO,
951
952
const TargetLowering &TLI,
952
953
const APInt &DemandedBits,
953
- const APInt &DemandedElts,
954
- unsigned Depth) {
954
+ const APInt &DemandedElts, unsigned Depth) {
955
955
assert ((Op.getOpcode () == ISD::SRL || Op.getOpcode () == ISD::SRA) &&
956
956
" SRL or SRA node is required here!" );
957
957
// Is the right shift using an immediate value of 1?
@@ -1002,6 +1002,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1002
1002
// If the shift is unsigned (srl):
1003
1003
// - Needs >= 1 zero bit for both operands.
1004
1004
// - Needs 1 demanded bit zero and >= 2 sign bits.
1005
+ SelectionDAG &DAG = TLO.DAG ;
1005
1006
unsigned ShiftOpc = Op.getOpcode ();
1006
1007
bool IsSigned = false ;
1007
1008
unsigned KnownBits;
@@ -1057,10 +1058,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1057
1058
EVT NVT = EVT::getIntegerVT (*DAG.getContext (), llvm::bit_ceil (MinWidth));
1058
1059
if (VT.isVector ())
1059
1060
NVT = EVT::getVectorVT (*DAG.getContext (), NVT, VT.getVectorElementCount ());
1060
- if (!TLI.isOperationLegalOrCustom (AVGOpc, NVT)) {
1061
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, NVT)) {
1061
1062
// If we could not transform, and (both) adds are nuw/nsw, we can use the
1062
1063
// larger type size to do the transform.
1063
- if (!TLI.isOperationLegalOrCustom (AVGOpc, VT))
1064
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, VT))
1064
1065
return SDValue ();
1065
1066
if (DAG.willNotOverflowAdd (IsSigned, Add.getOperand (0 ),
1066
1067
Add.getOperand (1 )) &&
@@ -1908,7 +1909,7 @@ bool TargetLowering::SimplifyDemandedBits(
1908
1909
EVT ShiftVT = Op1.getValueType ();
1909
1910
1910
1911
// Try to match AVG patterns.
1911
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
1912
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
1912
1913
DemandedElts, Depth + 1 ))
1913
1914
return TLO.CombineTo (Op, AVG);
1914
1915
@@ -2014,7 +2015,7 @@ bool TargetLowering::SimplifyDemandedBits(
2014
2015
return TLO.CombineTo (Op, TLO.DAG .getNode (ISD::SRL, dl, VT, Op0, Op1));
2015
2016
2016
2017
// Try to match AVG patterns.
2017
- if (SDValue AVG = combineShiftToAVG (Op, TLO. DAG , *this , DemandedBits,
2018
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2018
2019
DemandedElts, Depth + 1 ))
2019
2020
return TLO.CombineTo (Op, AVG);
2020
2021
@@ -9200,6 +9201,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9200
9201
DAG.getNode (ISD::SUB, dl, VT, RHS, LHS));
9201
9202
}
9202
9203
9204
+ SDValue TargetLowering::expandAVG (SDNode *N, SelectionDAG &DAG) const {
9205
+ SDLoc dl (N);
9206
+ EVT VT = N->getValueType (0 );
9207
+ SDValue LHS = N->getOperand (0 );
9208
+ SDValue RHS = N->getOperand (1 );
9209
+
9210
+ unsigned Opc = N->getOpcode ();
9211
+ bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9212
+ bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9213
+ unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9214
+ assert ((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9215
+ Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9216
+ " Unknown AVG node" );
9217
+
9218
+ // If the operands are already extended, we can add+shift.
9219
+ bool IsExt =
9220
+ (IsSigned && DAG.ComputeNumSignBits (LHS) >= 2 &&
9221
+ DAG.ComputeNumSignBits (RHS) >= 2 ) ||
9222
+ (!IsSigned && DAG.computeKnownBits (LHS).countMinLeadingZeros () >= 1 &&
9223
+ DAG.computeKnownBits (RHS).countMinLeadingZeros () >= 1 );
9224
+ if (IsExt) {
9225
+ SDValue Sum = DAG.getNode (ISD::ADD, dl, VT, LHS, RHS);
9226
+ if (!IsFloor)
9227
+ Sum = DAG.getNode (ISD::ADD, dl, VT, Sum, DAG.getConstant (1 , dl, VT));
9228
+ return DAG.getNode (ShiftOpc, dl, VT, Sum,
9229
+ DAG.getShiftAmountConstant (1 , VT, dl));
9230
+ }
9231
+
9232
+ // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9233
+ // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9234
+ // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9235
+ // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9236
+ unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9237
+ unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9238
+ LHS = DAG.getFreeze (LHS);
9239
+ RHS = DAG.getFreeze (RHS);
9240
+ SDValue Sign = DAG.getNode (SignOpc, dl, VT, LHS, RHS);
9241
+ SDValue Xor = DAG.getNode (ISD::XOR, dl, VT, LHS, RHS);
9242
+ SDValue Shift =
9243
+ DAG.getNode (ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant (1 , VT, dl));
9244
+ return DAG.getNode (SumOpc, dl, VT, Sign, Shift);
9245
+ }
9246
+
9203
9247
SDValue TargetLowering::expandBSWAP (SDNode *N, SelectionDAG &DAG) const {
9204
9248
SDLoc dl (N);
9205
9249
EVT VT = N->getValueType (0 );
0 commit comments