Skip to content

Commit 7ec67b9

Browse files
committed
address pr feedback
1 parent c9f0d39 commit 7ec67b9

File tree

3 files changed

+48
-67
lines changed

3 files changed

+48
-67
lines changed

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

+27 -38
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,9 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===---------------------------------------------------------------------===//
8-
98
///
109
/// \file This file contains a pass to flatten arrays for the DirectX Backend.
11-
//
10+
///
1211
//===----------------------------------------------------------------------===//
1312

1413
#include "DXILFlattenArrays.h"
@@ -26,10 +25,12 @@
2625
#include <cassert>
2726
#include <cstddef>
2827
#include <cstdint>
28+
#include <utility>
2929

3030
#define DEBUG_TYPE "dxil-flatten-arrays"
3131

3232
using namespace llvm;
33+
namespace {
3334

3435
class DXILFlattenArraysLegacy : public ModulePass {
3536

@@ -75,19 +76,18 @@ class DXILFlattenArraysVisitor
7576
bool visitCallInst(CallInst &ICI) { return false; }
7677
bool visitFreezeInst(FreezeInst &FI) { return false; }
7778
static bool isMultiDimensionalArray(Type *T);
78-
static unsigned getTotalElements(Type *ArrayTy);
79-
static Type *getBaseElementType(Type *ArrayTy);
79+
static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
8080

8181
private:
82-
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
82+
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
8383
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
8484
bool finish();
85-
ConstantInt *constFlattenIndices(ArrayRef<Value *> Indices,
86-
ArrayRef<uint64_t> Dims,
87-
IRBuilder<> &Builder);
88-
Value *instructionFlattenIndices(ArrayRef<Value *> Indices,
89-
ArrayRef<uint64_t> Dims,
90-
IRBuilder<> &Builder);
85+
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
86+
ArrayRef<uint64_t> Dims,
87+
IRBuilder<> &Builder);
88+
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
89+
ArrayRef<uint64_t> Dims,
90+
IRBuilder<> &Builder);
9191
void
9292
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
9393
ArrayType *FlattenedArrayType, Value *PtrOperand,
@@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor
9999
bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
100100
GetElementPtrInst &GEP);
101101
};
102+
} // namespace
102103

103104
bool DXILFlattenArraysVisitor::finish() {
104105
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
@@ -111,25 +112,18 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
111112
return false;
112113
}
113114

114-
unsigned DXILFlattenArraysVisitor::getTotalElements(Type *ArrayTy) {
115+
std::pair<unsigned, Type *>
116+
DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
115117
unsigned TotalElements = 1;
116118
Type *CurrArrayTy = ArrayTy;
117119
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118120
TotalElements *= InnerArrayTy->getNumElements();
119121
CurrArrayTy = InnerArrayTy->getElementType();
120122
}
121-
return TotalElements;
122-
}
123-
124-
Type *DXILFlattenArraysVisitor::getBaseElementType(Type *ArrayTy) {
125-
Type *CurrArrayTy = ArrayTy;
126-
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
127-
CurrArrayTy = InnerArrayTy->getElementType();
128-
}
129-
return CurrArrayTy;
123+
return std::make_pair(TotalElements, CurrArrayTy);
130124
}
131125

132-
ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
126+
ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
133127
ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
134128
assert(Indices.size() == Dims.size() &&
135129
"Indicies and dimmensions should be the same");
@@ -146,7 +140,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
146140
return Builder.getInt32(FlatIndex);
147141
}
148142

149-
Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
143+
Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
150144
ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
151145
if (Indices.size() == 1)
152146
return Indices[0];
@@ -202,10 +196,9 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
202196

203197
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
204198
IRBuilder<> Builder(&AI);
205-
unsigned TotalElements = getTotalElements(ArrType);
199+
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
206200

207-
ArrayType *FattenedArrayType =
208-
ArrayType::get(getBaseElementType(ArrType), TotalElements);
201+
ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
209202
AllocaInst *FlatAlloca =
210203
Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
211204
FlatAlloca->setAlignment(AI.getAlign());
@@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
261254
IRBuilder<> Builder(&GEP);
262255
Value *FlatIndex;
263256
if (GEPInfo.AllIndicesAreConstInt)
264-
FlatIndex = constFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
257+
FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
265258
else
266259
FlatIndex =
267-
instructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
260+
genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
268261

269262
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
270263
Value *FlatGEP =
@@ -285,9 +278,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
285278

286279
ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
287280
IRBuilder<> Builder(&GEP);
288-
unsigned TotalElements = getTotalElements(ArrType);
289-
ArrayType *FlattenedArrayType =
290-
ArrayType::get(getBaseElementType(ArrType), TotalElements);
281+
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
282+
ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
291283

292284
Value *PtrOperand = GEP.getPointerOperand();
293285

@@ -313,7 +305,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
313305

314306
bool DXILFlattenArraysVisitor::visit(Function &F) {
315307
bool MadeChange = false;
316-
////for (BasicBlock &BB : make_early_inc_range(F)) {
317308
ReversePostOrderTraversal<Function *> RPOT(&F);
318309
for (BasicBlock *BB : make_early_inc_range(RPOT)) {
319310
for (Instruction &I : make_early_inc_range(*BB)) {
@@ -345,8 +336,7 @@ static void collectElements(Constant *Init,
345336
collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
346337
}
347338
} else {
348-
assert(
349-
false &&
339+
llvm_unreachable(
350340
"Expected a ConstantArray or ConstantDataArray for array initializer!");
351341
}
352342
}
@@ -382,10 +372,9 @@ flattenGlobalArrays(Module &M,
382372
continue;
383373

384374
ArrayType *ArrType = cast<ArrayType>(OrigType);
385-
unsigned TotalElements =
386-
DXILFlattenArraysVisitor::getTotalElements(ArrType);
387-
ArrayType *FattenedArrayType = ArrayType::get(
388-
DXILFlattenArraysVisitor::getBaseElementType(ArrType), TotalElements);
375+
auto [TotalElements, BaseType] =
376+
DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
377+
ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
389378

390379
// Create a new global variable with the updated type
391380
// Note: Initializer is set via transformInitializer

llvm/lib/Target/DirectX/DXILFlattenArrays.h

-2
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,7 @@
99
#ifndef LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
1010
#define LLVM_TARGET_DIRECTX_DXILFLATTENARRAYS_H
1111

12-
#include "DXILResource.h"
1312
#include "llvm/IR/PassManager.h"
14-
#include "llvm/Pass.h"
1513

1614
namespace llvm {
1715

llvm/test/CodeGen/DirectX/flatten-array.ll

+21 -27
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,38 @@
1+
12
; RUN: opt -S -dxil-flatten-arrays %s | FileCheck %s
23

34
; CHECK-LABEL: alloca_2d_test
45
define void @alloca_2d_test () {
5-
; CHECK: alloca [9 x i32], align 4
6-
; CHECK-NOT: alloca [3 x [3 x i32]], align 4
7-
%1 = alloca [3 x [3 x i32]], align 4
8-
ret void
6+
; CHECK-NEXT: alloca [9 x i32], align 4
7+
; CHECK-NEXT: ret void
8+
;
9+
%1 = alloca [3 x [3 x i32]], align 4
10+
ret void
911
}
1012

1113
; CHECK-LABEL: alloca_3d_test
1214
define void @alloca_3d_test () {
13-
; CHECK: alloca [8 x i32], align 4
14-
; CHECK-NOT: alloca [2 x[2 x [2 x i32]]], align 4
15-
%1 = alloca [2 x[2 x [2 x i32]]], align 4
16-
ret void
15+
; CHECK-NEXT: alloca [8 x i32], align 4
16+
; CHECK-NEXT: ret void
17+
;
18+
%1 = alloca [2 x[2 x [2 x i32]]], align 4
19+
ret void
1720
}
1821

1922
; CHECK-LABEL: alloca_4d_test
2023
define void @alloca_4d_test () {
21-
; CHECK: alloca [16 x i32], align 4
22-
; CHECK-NOT: alloca [ 2x[2 x[2 x [2 x i32]]]], align 4
23-
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
24-
ret void
24+
; CHECK-NEXT: alloca [16 x i32], align 4
25+
; CHECK-NEXT: ret void
26+
;
27+
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
28+
ret void
2529
}
2630

2731
; CHECK-LABEL: gep_2d_test
2832
define void @gep_2d_test () {
2933
; CHECK: [[a:%.*]] = alloca [9 x i32], align 4
3034
; CHECK-COUNT-9: getelementptr inbounds [9 x i32], ptr [[a]], i32 {{[0-8]}}
31-
; CHECK-NOT: getelementptr inbounds [3 x [3 x i32]], ptr %1, i32 0, i32 0, i32 {{[0-2]}}
32-
; CHECK-NOT: getelementptr inbounds [3 x i32], [3 x i32]* {{.*}}, i32 0, i32 {{[0-2]}}
35+
; CHECK-NEXT: ret void
3336
%1 = alloca [3 x [3 x i32]], align 4
3437
%g2d0 = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %1, i32 0, i32 0
3538
%g1d_1 = getelementptr inbounds [3 x i32], [3 x i32]* %g2d0, i32 0, i32 0
@@ -51,9 +54,7 @@ define void @gep_2d_test () {
5154
define void @gep_3d_test () {
5255
; CHECK: [[a:%.*]] = alloca [8 x i32], align 4
5356
; CHECK-COUNT-8: getelementptr inbounds [8 x i32], ptr [[a]], i32 {{[0-7]}}
54-
; CHECK-NOT: getelementptr inbounds [2 x[2 x [2 x i32]]], ptr %1, i32 0, i32 0, i32 {{[0-1]}}
55-
; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
56-
; CHECK-NOT: getelementptr inbounds [2 x i32], [2 x i32]* {{.*}}, i32 0, i32 {{[0-1]}}
57+
; CHECK-NEXT: ret void
5758
%1 = alloca [2 x[2 x [2 x i32]]], align 4
5859
%g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %1, i32 0, i32 0
5960
%g2d0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g3d0, i32 0, i32 0
@@ -76,10 +77,7 @@ define void @gep_3d_test () {
7677
define void @gep_4d_test () {
7778
; CHECK: [[a:%.*]] = alloca [16 x i32], align 4
7879
; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[a]], i32 {{[0-9]|1[0-5]}}
79-
; CHECK-NOT: getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], ptr %1, i32 0, i32 0, i32 {{[0-1]}}
80-
; CHECK-NOT: getelementptr inbounds [2 x[2 x [2 x i32]]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
81-
; CHECK-NOT: getelementptr inbounds [2 x [2 x i32]], ptr {{.*}}, i32 0, i32 0, i32 {{[0-1]}}
82-
; CHECK-NOT: getelementptr inbounds [2 x i32], [2 x i32]* {{.*}}, i32 0, i32 {{[0-1]}}
80+
; CHECK-NEXT: ret void
8381
%1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4
8482
%g4d0 = getelementptr inbounds [2x[2 x[2 x [2 x i32]]]], [2x[2 x[2 x [2 x i32]]]]* %1, i32 0, i32 0
8583
%g3d0 = getelementptr inbounds [2 x[2 x [2 x i32]]], [2 x[2 x [2 x i32]]]* %g4d0, i32 0, i32 0
@@ -127,9 +125,7 @@ define void @gep_4d_test () {
127125
define void @global_gep_load() {
128126
; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 6
129127
; CHECK: load i32, ptr [[GEP_PTR]], align 4
130-
; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
131-
; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
132-
; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
128+
; CHECK-NEXT: ret void
133129
%1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 0
134130
%2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 1
135131
%3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 2
@@ -183,9 +179,7 @@ define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
183179
define void @global_gep_store() {
184180
; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @b.1dim, i32 13
185181
; CHECK: store i32 1, ptr [[GEP_PTR]], align 4
186-
; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
187-
; CHECK-NOT: getelementptr inbounds [3 x [4 x i32]]{{.*}}
188-
; CHECK-NOT: getelementptr inbounds [4 x i32]{{.*}}
182+
; CHECK-NEXT: ret void
189183
%1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @b, i32 0, i32 1
190184
%2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 0
191185
%3 = getelementptr inbounds [4 x i32], [4 x i32]* %2, i32 0, i32 1

0 commit comments

Comments
 (0)