@@ -947,11 +947,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
947
947
948
948
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
949
949
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
950
- static SDValue combineShiftToAVG (SDValue Op, SelectionDAG &DAG,
950
+ static SDValue combineShiftToAVG (SDValue Op,
951
+ TargetLowering::TargetLoweringOpt &TLO,
951
952
const TargetLowering &TLI,
952
953
const APInt &DemandedBits,
953
- const APInt &DemandedElts,
954
- unsigned Depth) {
954
+ const APInt &DemandedElts, unsigned Depth) {
955
955
assert ((Op.getOpcode () == ISD::SRL || Op.getOpcode () == ISD::SRA) &&
956
956
" SRL or SRA node is required here!" );
957
957
// Is the right shift using an immediate value of 1?
@@ -1002,6 +1002,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1002
1002
// If the shift is unsigned (srl):
1003
1003
// - Needs >= 1 zero bit for both operands.
1004
1004
// - Needs 1 demanded bit zero and >= 2 sign bits.
1005
+ SelectionDAG &DAG = TLO.DAG ;
1005
1006
unsigned ShiftOpc = Op.getOpcode ();
1006
1007
bool IsSigned = false ;
1007
1008
unsigned KnownBits;
@@ -1057,10 +1058,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1057
1058
EVT NVT = EVT::getIntegerVT (*DAG.getContext (), llvm::bit_ceil (MinWidth));
1058
1059
if (VT.isVector ())
1059
1060
NVT = EVT::getVectorVT (*DAG.getContext (), NVT, VT.getVectorElementCount ());
1060
- if (!TLI.isOperationLegalOrCustom (AVGOpc, NVT)) {
1061
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, NVT)) {
1061
1062
// If we could not transform, and (both) adds are nuw/nsw, we can use the
1062
1063
// larger type size to do the transform.
1063
- if (!TLI.isOperationLegalOrCustom (AVGOpc, VT))
1064
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, VT))
1064
1065
return SDValue ();
1065
1066
if (DAG.willNotOverflowAdd (IsSigned, Add.getOperand (0 ),
1066
1067
Add.getOperand (1 )) &&
@@ -1908,11 +1909,6 @@ bool TargetLowering::SimplifyDemandedBits(
1908
1909
SDValue Op1 = Op.getOperand (1 );
1909
1910
EVT ShiftVT = Op1.getValueType ();
1910
1911
1911
- // Try to match AVG patterns.
1912
- if (SDValue AVG = combineShiftToAVG (Op, TLO.DAG , *this , DemandedBits,
1913
- DemandedElts, Depth + 1 ))
1914
- return TLO.CombineTo (Op, AVG);
1915
-
1916
1912
KnownBits KnownSA = TLO.DAG .computeKnownBits (Op1, DemandedElts, Depth + 1 );
1917
1913
if (KnownSA.isConstant () && KnownSA.getConstant ().ult (BitWidth)) {
1918
1914
unsigned ShAmt = KnownSA.getConstant ().getZExtValue ();
@@ -1994,6 +1990,12 @@ bool TargetLowering::SimplifyDemandedBits(
1994
1990
// shift amounts.
1995
1991
Known = TLO.DAG .computeKnownBits (Op, DemandedElts, Depth);
1996
1992
}
1993
+
1994
+ // Try to match AVG patterns (after shift simplification).
1995
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
1996
+ DemandedElts, Depth + 1 ))
1997
+ return TLO.CombineTo (Op, AVG);
1998
+
1997
1999
break ;
1998
2000
}
1999
2001
case ISD::SRA: {
@@ -2015,11 +2017,6 @@ bool TargetLowering::SimplifyDemandedBits(
2015
2017
if (DemandedBits.isOne ())
2016
2018
return TLO.CombineTo (Op, TLO.DAG .getNode (ISD::SRL, dl, VT, Op0, Op1));
2017
2019
2018
- // Try to match AVG patterns.
2019
- if (SDValue AVG = combineShiftToAVG (Op, TLO.DAG , *this , DemandedBits,
2020
- DemandedElts, Depth + 1 ))
2021
- return TLO.CombineTo (Op, AVG);
2022
-
2023
2020
KnownBits KnownSA = TLO.DAG .computeKnownBits (Op1, DemandedElts, Depth + 1 );
2024
2021
if (KnownSA.isConstant () && KnownSA.getConstant ().ult (BitWidth)) {
2025
2022
unsigned ShAmt = KnownSA.getConstant ().getZExtValue ();
@@ -2106,6 +2103,12 @@ bool TargetLowering::SimplifyDemandedBits(
2106
2103
}
2107
2104
}
2108
2105
}
2106
+
2107
+ // Try to match AVG patterns (after shift simplification).
2108
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2109
+ DemandedElts, Depth + 1 ))
2110
+ return TLO.CombineTo (Op, AVG);
2111
+
2109
2112
break ;
2110
2113
}
2111
2114
case ISD::FSHL:
@@ -9203,6 +9206,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9203
9206
DAG.getNode (ISD::SUB, dl, VT, RHS, LHS));
9204
9207
}
9205
9208
9209
+ SDValue TargetLowering::expandAVG (SDNode *N, SelectionDAG &DAG) const {
9210
+ SDLoc dl (N);
9211
+ EVT VT = N->getValueType (0 );
9212
+ SDValue LHS = N->getOperand (0 );
9213
+ SDValue RHS = N->getOperand (1 );
9214
+
9215
+ unsigned Opc = N->getOpcode ();
9216
+ bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9217
+ bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9218
+ unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9219
+ assert ((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9220
+ Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9221
+ " Unknown AVG node" );
9222
+
9223
+ // If the operands are already extended, we can add+shift.
9224
+ bool IsExt =
9225
+ (IsSigned && DAG.ComputeNumSignBits (LHS) >= 2 &&
9226
+ DAG.ComputeNumSignBits (RHS) >= 2 ) ||
9227
+ (!IsSigned && DAG.computeKnownBits (LHS).countMinLeadingZeros () >= 1 &&
9228
+ DAG.computeKnownBits (RHS).countMinLeadingZeros () >= 1 );
9229
+ if (IsExt) {
9230
+ SDValue Sum = DAG.getNode (ISD::ADD, dl, VT, LHS, RHS);
9231
+ if (!IsFloor)
9232
+ Sum = DAG.getNode (ISD::ADD, dl, VT, Sum, DAG.getConstant (1 , dl, VT));
9233
+ return DAG.getNode (ShiftOpc, dl, VT, Sum,
9234
+ DAG.getShiftAmountConstant (1 , VT, dl));
9235
+ }
9236
+
9237
+ // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9238
+ // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9239
+ // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9240
+ // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9241
+ unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9242
+ unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9243
+ LHS = DAG.getFreeze (LHS);
9244
+ RHS = DAG.getFreeze (RHS);
9245
+ SDValue Sign = DAG.getNode (SignOpc, dl, VT, LHS, RHS);
9246
+ SDValue Xor = DAG.getNode (ISD::XOR, dl, VT, LHS, RHS);
9247
+ SDValue Shift =
9248
+ DAG.getNode (ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant (1 , VT, dl));
9249
+ return DAG.getNode (SumOpc, dl, VT, Sign, Shift);
9250
+ }
9251
+
9206
9252
SDValue TargetLowering::expandBSWAP (SDNode *N, SelectionDAG &DAG) const {
9207
9253
SDLoc dl (N);
9208
9254
EVT VT = N->getValueType (0 );
0 commit comments