Skip to content

[DAGCombine] Fix multi-use miscompile in load combine #81492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
static std::optional<SDByteProvider>
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
std::optional<uint64_t> VectorIndex,
SmallPtrSetImpl<SDNode *> &ExtractElements,
unsigned StartingIndex = 0) {

// Typical i64 by i8 pattern requires recursion up to 8 calls depth
Expand All @@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,

switch (Op.getOpcode()) {
case ISD::OR: {
auto LHS =
calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
VectorIndex, ExtractElements);
if (!LHS)
return std::nullopt;
auto RHS =
calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
VectorIndex, ExtractElements);
if (!RHS)
return std::nullopt;

Expand All @@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
return Index < ByteShift
? SDByteProvider::getConstantZero()
: calculateByteProvider(Op->getOperand(0), Index - ByteShift,
Depth + 1, VectorIndex, Index);
Depth + 1, VectorIndex, ExtractElements,
Index);
}
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
Expand All @@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
SDByteProvider::getConstantZero())
: std::nullopt;
return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
StartingIndex);
ExtractElements, StartingIndex);
}
case ISD::BSWAP:
return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
Depth + 1, VectorIndex, StartingIndex);
Depth + 1, VectorIndex, ExtractElements,
StartingIndex);
case ISD::EXTRACT_VECTOR_ELT: {
auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
if (!OffsetOp)
Expand All @@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
return std::nullopt;

ExtractElements.insert(Op.getNode());
return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
VectorIndex, StartingIndex);
VectorIndex, ExtractElements, StartingIndex);
}
case ISD::LOAD: {
auto L = cast<LoadSDNode>(Op.getNode());
Expand Down Expand Up @@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
SDValue Chain;

SmallPtrSet<LoadSDNode *, 8> Loads;
SmallPtrSet<SDNode *, 8> ExtractElements;
std::optional<SDByteProvider> FirstByteProvider;
int64_t FirstOffset = INT64_MAX;

Expand All @@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
unsigned ZeroExtendedBytes = 0;
for (int i = ByteWidth - 1; i >= 0; --i) {
auto P =
calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
calculateByteProvider(SDValue(N, 0), i, 0,
/*VectorIndex*/ std::nullopt, ExtractElements,

/*StartingIndex*/ i);
if (!P)
return SDValue();
Expand Down Expand Up @@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
if (!Allowed || !Fast)
return SDValue();

// calculateByteProvider() allows multi-use for vector loads. Ensure that
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

byte -> Byte

// all uses are in vector element extracts that are part of the pattern.
for (LoadSDNode *L : Loads)
if (L->getMemoryVT().isVector())
for (auto It = L->use_begin(); It != L->use_end(); ++It)
if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
return SDValue();

Copy link
Collaborator

@topperc topperc Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we go ahead and call DAG.makeEquivalentMemoryOrdering instead of ReplaceAllUsesOfValueWith below for maximum safety?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gah, I've been looking everywhere for a function that does this, and couldn't find it. After looking at a few other uses of that function, it seems like we actually have quite a few folds that are happy to introduce extra loads to avoid other instructions, so it seems like strictly avoiding multi-use loads in this fold may not be desirable. I've opened #81586 to only use makeEquivalentMemoryOrdering() instead.

SDValue NewLoad =
DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
Chain, FirstLoad->getBasePtr(),
Expand Down
8 changes: 5 additions & 3 deletions llvm/test/CodeGen/AArch64/load-combine.ll
Original file line number Diff line number Diff line change
Expand Up @@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldr s0, [x0]
; CHECK-NEXT: ldrh w9, [x0]
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: umov w8, v0.h[2]
; CHECK-NEXT: orr w8, w9, w8, lsl #16
; CHECK-NEXT: umov w8, v0.h[1]
; CHECK-NEXT: umov w9, v0.h[0]
; CHECK-NEXT: umov w10, v0.h[2]
; CHECK-NEXT: bfi w9, w8, #8, #8
; CHECK-NEXT: orr w8, w9, w10, lsl #16
; CHECK-NEXT: str w8, [x1]
; CHECK-NEXT: ret
%ld = load <4 x i8>, ptr %in, align 4
Expand Down
10 changes: 6 additions & 4 deletions llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
; GCN-LABEL: load_3xi16_combine:
; GCN: ; %bb.0:
; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GCN-NEXT: global_load_dword v2, v[0:1], off
; GCN-NEXT: global_load_ushort v3, v[0:1], off offset:4
; GCN-NEXT: global_load_dword v3, v[0:1], off
; GCN-NEXT: global_load_ushort v2, v[0:1], off offset:4
; GCN-NEXT: s_mov_b32 s4, 0xffff
; GCN-NEXT: s_waitcnt vmcnt(1)
; GCN-NEXT: v_mov_b32_e32 v0, v2
; GCN-NEXT: v_and_b32_e32 v0, 0xffff0000, v3
; GCN-NEXT: v_and_or_b32 v0, v3, s4, v0
; GCN-NEXT: s_waitcnt vmcnt(0)
; GCN-NEXT: v_mov_b32_e32 v1, v3
; GCN-NEXT: v_mov_b32_e32 v1, v2
; GCN-NEXT: s_setpc_b64 s[30:31]
%gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
%gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2
Expand Down
25 changes: 16 additions & 9 deletions llvm/test/CodeGen/X86/load-combine.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
ret i32 %tmp8
}

; FIXME: This is a miscompile.
define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
; CHECK-LABEL: pr80911_vector_load_multiuse:
; CHECK: # %bb.0:
; CHECK-NEXT: pushl %edi
; CHECK-NEXT: pushl %esi
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %ecx
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %edx
; CHECK-NEXT: movl (%edx), %esi
; CHECK-NEXT: movzwl (%edx), %eax
; CHECK-NEXT: movl $0, (%ecx)
; CHECK-NEXT: movl %esi, (%edx)
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %esi
; CHECK-NEXT: movzbl (%esi), %ecx
; CHECK-NEXT: movzbl 1(%esi), %eax
; CHECK-NEXT: movzwl 2(%esi), %edi
; CHECK-NEXT: movl $0, (%edx)
; CHECK-NEXT: movw %di, 2(%esi)
; CHECK-NEXT: movb %al, 1(%esi)
; CHECK-NEXT: movb %cl, (%esi)
; CHECK-NEXT: shll $8, %eax
; CHECK-NEXT: orl %ecx, %eax
; CHECK-NEXT: popl %esi
; CHECK-NEXT: popl %edi
; CHECK-NEXT: retl
;
; CHECK64-LABEL: pr80911_vector_load_multiuse:
; CHECK64: # %bb.0:
; CHECK64-NEXT: movzwl (%rdi), %eax
; CHECK64-NEXT: movaps (%rdi), %xmm0
; CHECK64-NEXT: movl $0, (%rsi)
; CHECK64-NEXT: movl (%rdi), %ecx
; CHECK64-NEXT: movl %ecx, (%rdi)
; CHECK64-NEXT: movss %xmm0, (%rdi)
; CHECK64-NEXT: movaps %xmm0, -{{[0-9]+}}(%rsp)
; CHECK64-NEXT: movzwl -{{[0-9]+}}(%rsp), %eax
; CHECK64-NEXT: retq
%load = load <4 x i8>, ptr %ptr, align 16
store i32 0, ptr %clobber
Expand Down