@@ -947,11 +947,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
947
947
948
948
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
949
949
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
950
- static SDValue combineShiftToAVG (SDValue Op, SelectionDAG &DAG,
950
+ static SDValue combineShiftToAVG (SDValue Op,
951
+ TargetLowering::TargetLoweringOpt &TLO,
951
952
const TargetLowering &TLI,
952
953
const APInt &DemandedBits,
953
- const APInt &DemandedElts,
954
- unsigned Depth) {
954
+ const APInt &DemandedElts, unsigned Depth) {
955
955
assert ((Op.getOpcode () == ISD::SRL || Op.getOpcode () == ISD::SRA) &&
956
956
" SRL or SRA node is required here!" );
957
957
// Is the right shift using an immediate value of 1?
@@ -1002,6 +1002,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1002
1002
// If the shift is unsigned (srl):
1003
1003
// - Needs >= 1 zero bit for both operands.
1004
1004
// - Needs 1 demanded bit zero and >= 2 sign bits.
1005
+ SelectionDAG &DAG = TLO.DAG ;
1005
1006
unsigned ShiftOpc = Op.getOpcode ();
1006
1007
bool IsSigned = false ;
1007
1008
unsigned KnownBits;
@@ -1057,10 +1058,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
1057
1058
EVT NVT = EVT::getIntegerVT (*DAG.getContext (), llvm::bit_ceil (MinWidth));
1058
1059
if (VT.isVector ())
1059
1060
NVT = EVT::getVectorVT (*DAG.getContext (), NVT, VT.getVectorElementCount ());
1060
- if (!TLI.isOperationLegalOrCustom (AVGOpc, NVT)) {
1061
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, NVT)) {
1061
1062
// If we could not transform, and (both) adds are nuw/nsw, we can use the
1062
1063
// larger type size to do the transform.
1063
- if (!TLI.isOperationLegalOrCustom (AVGOpc, VT))
1064
+ if (TLO. LegalOperations () && !TLI.isOperationLegal (AVGOpc, VT))
1064
1065
return SDValue ();
1065
1066
if (DAG.willNotOverflowAdd (IsSigned, Add.getOperand (0 ),
1066
1067
Add.getOperand (1 )) &&
@@ -1907,11 +1908,6 @@ bool TargetLowering::SimplifyDemandedBits(
1907
1908
SDValue Op1 = Op.getOperand (1 );
1908
1909
EVT ShiftVT = Op1.getValueType ();
1909
1910
1910
- // Try to match AVG patterns.
1911
- if (SDValue AVG = combineShiftToAVG (Op, TLO.DAG , *this , DemandedBits,
1912
- DemandedElts, Depth + 1 ))
1913
- return TLO.CombineTo (Op, AVG);
1914
-
1915
1911
if (const APInt *SA =
1916
1912
TLO.DAG .getValidShiftAmountConstant (Op, DemandedElts)) {
1917
1913
unsigned ShAmt = SA->getZExtValue ();
@@ -1992,6 +1988,12 @@ bool TargetLowering::SimplifyDemandedBits(
1992
1988
// shift amounts.
1993
1989
Known = TLO.DAG .computeKnownBits (Op, DemandedElts, Depth);
1994
1990
}
1991
+
1992
+ // Try to match AVG patterns (after shift simplification).
1993
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
1994
+ DemandedElts, Depth + 1 ))
1995
+ return TLO.CombineTo (Op, AVG);
1996
+
1995
1997
break ;
1996
1998
}
1997
1999
case ISD::SRA: {
@@ -2013,11 +2015,6 @@ bool TargetLowering::SimplifyDemandedBits(
2013
2015
if (DemandedBits.isOne ())
2014
2016
return TLO.CombineTo (Op, TLO.DAG .getNode (ISD::SRL, dl, VT, Op0, Op1));
2015
2017
2016
- // Try to match AVG patterns.
2017
- if (SDValue AVG = combineShiftToAVG (Op, TLO.DAG , *this , DemandedBits,
2018
- DemandedElts, Depth + 1 ))
2019
- return TLO.CombineTo (Op, AVG);
2020
-
2021
2018
if (const APInt *SA =
2022
2019
TLO.DAG .getValidShiftAmountConstant (Op, DemandedElts)) {
2023
2020
unsigned ShAmt = SA->getZExtValue ();
@@ -2103,6 +2100,12 @@ bool TargetLowering::SimplifyDemandedBits(
2103
2100
}
2104
2101
}
2105
2102
}
2103
+
2104
+ // Try to match AVG patterns (after shift simplification).
2105
+ if (SDValue AVG = combineShiftToAVG (Op, TLO, *this , DemandedBits,
2106
+ DemandedElts, Depth + 1 ))
2107
+ return TLO.CombineTo (Op, AVG);
2108
+
2106
2109
break ;
2107
2110
}
2108
2111
case ISD::FSHL:
@@ -9200,6 +9203,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9200
9203
DAG.getNode (ISD::SUB, dl, VT, RHS, LHS));
9201
9204
}
9202
9205
9206
+ SDValue TargetLowering::expandAVG (SDNode *N, SelectionDAG &DAG) const {
9207
+ SDLoc dl (N);
9208
+ EVT VT = N->getValueType (0 );
9209
+ SDValue LHS = N->getOperand (0 );
9210
+ SDValue RHS = N->getOperand (1 );
9211
+
9212
+ unsigned Opc = N->getOpcode ();
9213
+ bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9214
+ bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9215
+ unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9216
+ assert ((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9217
+ Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9218
+ " Unknown AVG node" );
9219
+
9220
+ // If the operands are already extended, we can add+shift.
9221
+ bool IsExt =
9222
+ (IsSigned && DAG.ComputeNumSignBits (LHS) >= 2 &&
9223
+ DAG.ComputeNumSignBits (RHS) >= 2 ) ||
9224
+ (!IsSigned && DAG.computeKnownBits (LHS).countMinLeadingZeros () >= 1 &&
9225
+ DAG.computeKnownBits (RHS).countMinLeadingZeros () >= 1 );
9226
+ if (IsExt) {
9227
+ SDValue Sum = DAG.getNode (ISD::ADD, dl, VT, LHS, RHS);
9228
+ if (!IsFloor)
9229
+ Sum = DAG.getNode (ISD::ADD, dl, VT, Sum, DAG.getConstant (1 , dl, VT));
9230
+ return DAG.getNode (ShiftOpc, dl, VT, Sum,
9231
+ DAG.getShiftAmountConstant (1 , VT, dl));
9232
+ }
9233
+
9234
+ // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9235
+ // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9236
+ // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9237
+ // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9238
+ unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9239
+ unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9240
+ LHS = DAG.getFreeze (LHS);
9241
+ RHS = DAG.getFreeze (RHS);
9242
+ SDValue Sign = DAG.getNode (SignOpc, dl, VT, LHS, RHS);
9243
+ SDValue Xor = DAG.getNode (ISD::XOR, dl, VT, LHS, RHS);
9244
+ SDValue Shift =
9245
+ DAG.getNode (ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant (1 , VT, dl));
9246
+ return DAG.getNode (SumOpc, dl, VT, Sign, Shift);
9247
+ }
9248
+
9203
9249
SDValue TargetLowering::expandBSWAP (SDNode *N, SelectionDAG &DAG) const {
9204
9250
SDLoc dl (N);
9205
9251
EVT VT = N->getValueType (0 );
0 commit comments