-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[DAGCombine] Fix multi-use miscompile in load combine #81492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>; | |
static std::optional<SDByteProvider> | ||
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, | ||
std::optional<uint64_t> VectorIndex, | ||
SmallPtrSetImpl<SDNode *> &ExtractElements, | ||
unsigned StartingIndex = 0) { | ||
|
||
// Typical i64 by i8 pattern requires recursion up to 8 calls depth | ||
|
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, | |
|
||
switch (Op.getOpcode()) { | ||
case ISD::OR: { | ||
auto LHS = | ||
calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex); | ||
auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1, | ||
VectorIndex, ExtractElements); | ||
if (!LHS) | ||
return std::nullopt; | ||
auto RHS = | ||
calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex); | ||
auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1, | ||
VectorIndex, ExtractElements); | ||
if (!RHS) | ||
return std::nullopt; | ||
|
||
|
@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, | |
return Index < ByteShift | ||
? SDByteProvider::getConstantZero() | ||
: calculateByteProvider(Op->getOperand(0), Index - ByteShift, | ||
Depth + 1, VectorIndex, Index); | ||
Depth + 1, VectorIndex, ExtractElements, | ||
Index); | ||
} | ||
case ISD::ANY_EXTEND: | ||
case ISD::SIGN_EXTEND: | ||
|
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, | |
SDByteProvider::getConstantZero()) | ||
: std::nullopt; | ||
return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex, | ||
StartingIndex); | ||
ExtractElements, StartingIndex); | ||
} | ||
case ISD::BSWAP: | ||
return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1, | ||
Depth + 1, VectorIndex, StartingIndex); | ||
Depth + 1, VectorIndex, ExtractElements, | ||
StartingIndex); | ||
case ISD::EXTRACT_VECTOR_ELT: { | ||
auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1)); | ||
if (!OffsetOp) | ||
|
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, | |
if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex) | ||
return std::nullopt; | ||
|
||
ExtractElements.insert(Op.getNode()); | ||
return calculateByteProvider(Op->getOperand(0), Index, Depth + 1, | ||
VectorIndex, StartingIndex); | ||
VectorIndex, ExtractElements, StartingIndex); | ||
} | ||
case ISD::LOAD: { | ||
auto L = cast<LoadSDNode>(Op.getNode()); | ||
|
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { | |
SDValue Chain; | ||
|
||
SmallPtrSet<LoadSDNode *, 8> Loads; | ||
SmallPtrSet<SDNode *, 8> ExtractElements; | ||
std::optional<SDByteProvider> FirstByteProvider; | ||
int64_t FirstOffset = INT64_MAX; | ||
|
||
|
@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { | |
unsigned ZeroExtendedBytes = 0; | ||
for (int i = ByteWidth - 1; i >= 0; --i) { | ||
auto P = | ||
calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt, | ||
calculateByteProvider(SDValue(N, 0), i, 0, | ||
/*VectorIndex*/ std::nullopt, ExtractElements, | ||
|
||
/*StartingIndex*/ i); | ||
if (!P) | ||
return SDValue(); | ||
|
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { | |
if (!Allowed || !Fast) | ||
return SDValue(); | ||
|
||
// calculatebyteProvider() allows multi-use for vector loads. Ensure that | ||
// all uses are in vector element extracts that are part of the pattern. | ||
for (LoadSDNode *L : Loads) | ||
if (L->getMemoryVT().isVector()) | ||
for (auto It = L->use_begin(); It != L->use_end(); ++It) | ||
if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It)) | ||
return SDValue(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we go ahead and call DAG.makeEquivalentMemoryOrdering instead of ReplaceAllUsesOfValueWith below for maximum safety? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gah, I've been looking everywhere for a function that does this, and couldn't find it. After looking at a few other uses of that function, it seems like we actually have quite a few folds that are happy to introduce extra loads to avoid other instructions, so it seems like strictly avoiding multi-use loads in this fold may not be desirable. I've opened #81586 to only use makeEquivalentMemoryOrdering() instead. |
||
SDValue NewLoad = | ||
DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT, | ||
Chain, FirstLoad->getBasePtr(), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
byte -> Byte