5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===---------------------------------------------------------------------===//
8
-
9
8
// /
10
9
// / \file This file contains a pass to flatten arrays for the DirectX Backend.
11
- //
10
+ // /
12
11
// ===----------------------------------------------------------------------===//
13
12
14
13
#include " DXILFlattenArrays.h"
26
25
#include < cassert>
27
26
#include < cstddef>
28
27
#include < cstdint>
28
+ #include < utility>
29
29
30
30
#define DEBUG_TYPE " dxil-flatten-arrays"
31
31
32
32
using namespace llvm ;
33
+ namespace {
33
34
34
35
class DXILFlattenArraysLegacy : public ModulePass {
35
36
@@ -75,19 +76,18 @@ class DXILFlattenArraysVisitor
75
76
bool visitCallInst (CallInst &ICI) { return false ; }
76
77
bool visitFreezeInst (FreezeInst &FI) { return false ; }
77
78
static bool isMultiDimensionalArray (Type *T);
78
- static unsigned getTotalElements (Type *ArrayTy);
79
- static Type *getBaseElementType (Type *ArrayTy);
79
+ static std::pair<unsigned , Type *> getElementCountAndType (Type *ArrayTy);
80
80
81
81
private:
82
- SmallVector<WeakTrackingVH, 32 > PotentiallyDeadInstrs;
82
+ SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
83
83
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
84
84
bool finish ();
85
- ConstantInt *constFlattenIndices (ArrayRef<Value *> Indices,
86
- ArrayRef<uint64_t > Dims,
87
- IRBuilder<> &Builder);
88
- Value *instructionFlattenIndices (ArrayRef<Value *> Indices,
89
- ArrayRef<uint64_t > Dims,
90
- IRBuilder<> &Builder);
85
+ ConstantInt *genConstFlattenIndices (ArrayRef<Value *> Indices,
86
+ ArrayRef<uint64_t > Dims,
87
+ IRBuilder<> &Builder);
88
+ Value *genInstructionFlattenIndices (ArrayRef<Value *> Indices,
89
+ ArrayRef<uint64_t > Dims,
90
+ IRBuilder<> &Builder);
91
91
void
92
92
recursivelyCollectGEPs (GetElementPtrInst &CurrGEP,
93
93
ArrayType *FlattenedArrayType, Value *PtrOperand,
@@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor
99
99
bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
100
100
GetElementPtrInst &GEP);
101
101
};
102
+ } // namespace
102
103
103
104
bool DXILFlattenArraysVisitor::finish () {
104
105
RecursivelyDeleteTriviallyDeadInstructionsPermissive (PotentiallyDeadInstrs);
@@ -111,25 +112,18 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
111
112
return false ;
112
113
}
113
114
114
- unsigned DXILFlattenArraysVisitor::getTotalElements (Type *ArrayTy) {
115
+ std::pair<unsigned , Type *>
116
+ DXILFlattenArraysVisitor::getElementCountAndType (Type *ArrayTy) {
115
117
unsigned TotalElements = 1 ;
116
118
Type *CurrArrayTy = ArrayTy;
117
119
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118
120
TotalElements *= InnerArrayTy->getNumElements ();
119
121
CurrArrayTy = InnerArrayTy->getElementType ();
120
122
}
121
- return TotalElements;
122
- }
123
-
124
- Type *DXILFlattenArraysVisitor::getBaseElementType (Type *ArrayTy) {
125
- Type *CurrArrayTy = ArrayTy;
126
- while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
127
- CurrArrayTy = InnerArrayTy->getElementType ();
128
- }
129
- return CurrArrayTy;
123
+ return std::make_pair (TotalElements, CurrArrayTy);
130
124
}
131
125
132
- ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices (
126
+ ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices (
133
127
ArrayRef<Value *> Indices, ArrayRef<uint64_t > Dims, IRBuilder<> &Builder) {
134
128
assert (Indices.size () == Dims.size () &&
135
129
" Indicies and dimmensions should be the same" );
@@ -146,7 +140,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
146
140
return Builder.getInt32 (FlatIndex);
147
141
}
148
142
149
- Value *DXILFlattenArraysVisitor::instructionFlattenIndices (
143
+ Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices (
150
144
ArrayRef<Value *> Indices, ArrayRef<uint64_t > Dims, IRBuilder<> &Builder) {
151
145
if (Indices.size () == 1 )
152
146
return Indices[0 ];
@@ -202,10 +196,9 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
202
196
203
197
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
204
198
IRBuilder<> Builder (&AI);
205
- unsigned TotalElements = getTotalElements (ArrType);
199
+ auto [ TotalElements, BaseType] = getElementCountAndType (ArrType);
206
200
207
- ArrayType *FattenedArrayType =
208
- ArrayType::get (getBaseElementType (ArrType), TotalElements);
201
+ ArrayType *FattenedArrayType = ArrayType::get (BaseType, TotalElements);
209
202
AllocaInst *FlatAlloca =
210
203
Builder.CreateAlloca (FattenedArrayType, nullptr , AI.getName () + " .flat" );
211
204
FlatAlloca->setAlignment (AI.getAlign ());
@@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
261
254
IRBuilder<> Builder (&GEP);
262
255
Value *FlatIndex;
263
256
if (GEPInfo.AllIndicesAreConstInt )
264
- FlatIndex = constFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
257
+ FlatIndex = genConstFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
265
258
else
266
259
FlatIndex =
267
- instructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
260
+ genInstructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
268
261
269
262
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType ;
270
263
Value *FlatGEP =
@@ -285,9 +278,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
285
278
286
279
ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType ());
287
280
IRBuilder<> Builder (&GEP);
288
- unsigned TotalElements = getTotalElements (ArrType);
289
- ArrayType *FlattenedArrayType =
290
- ArrayType::get (getBaseElementType (ArrType), TotalElements);
281
+ auto [TotalElements, BaseType] = getElementCountAndType (ArrType);
282
+ ArrayType *FlattenedArrayType = ArrayType::get (BaseType, TotalElements);
291
283
292
284
Value *PtrOperand = GEP.getPointerOperand ();
293
285
@@ -313,7 +305,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
313
305
314
306
bool DXILFlattenArraysVisitor::visit (Function &F) {
315
307
bool MadeChange = false ;
316
- // //for (BasicBlock &BB : make_early_inc_range(F)) {
317
308
ReversePostOrderTraversal<Function *> RPOT (&F);
318
309
for (BasicBlock *BB : make_early_inc_range (RPOT)) {
319
310
for (Instruction &I : make_early_inc_range (*BB)) {
@@ -345,8 +336,7 @@ static void collectElements(Constant *Init,
345
336
collectElements (DataArrayConstant->getElementAsConstant (I), Elements);
346
337
}
347
338
} else {
348
- assert (
349
- false &&
339
+ llvm_unreachable (
350
340
" Expected a ConstantArray or ConstantDataArray for array initializer!" );
351
341
}
352
342
}
@@ -382,10 +372,9 @@ flattenGlobalArrays(Module &M,
382
372
continue ;
383
373
384
374
ArrayType *ArrType = cast<ArrayType>(OrigType);
385
- unsigned TotalElements =
386
- DXILFlattenArraysVisitor::getTotalElements (ArrType);
387
- ArrayType *FattenedArrayType = ArrayType::get (
388
- DXILFlattenArraysVisitor::getBaseElementType (ArrType), TotalElements);
375
+ auto [TotalElements, BaseType] =
376
+ DXILFlattenArraysVisitor::getElementCountAndType (ArrType);
377
+ ArrayType *FattenedArrayType = ArrayType::get (BaseType, TotalElements);
389
378
390
379
// Create a new global variable with the updated type
391
380
// Note: Initializer is set via transformInitializer
0 commit comments