20
20
#include " llvm/IR/DerivedTypes.h"
21
21
#include " llvm/IR/IRBuilder.h"
22
22
#include " llvm/IR/InstVisitor.h"
23
+ #include " llvm/IR/ReplaceConstant.h"
23
24
#include " llvm/Support/Casting.h"
24
25
#include " llvm/Transforms/Utils/Local.h"
25
26
#include < cassert>
@@ -69,8 +70,8 @@ class DXILFlattenArraysVisitor
69
70
bool visitExtractElementInst (ExtractElementInst &EEI) { return false ; }
70
71
bool visitShuffleVectorInst (ShuffleVectorInst &SVI) { return false ; }
71
72
bool visitPHINode (PHINode &PHI) { return false ; }
72
- bool visitLoadInst (LoadInst &LI) { return false ; }
73
- bool visitStoreInst (StoreInst &SI) { return false ; }
73
+ bool visitLoadInst (LoadInst &LI);
74
+ bool visitStoreInst (StoreInst &SI);
74
75
bool visitCallInst (CallInst &ICI) { return false ; }
75
76
bool visitFreezeInst (FreezeInst &FI) { return false ; }
76
77
static bool isMultiDimensionalArray (Type *T);
@@ -94,7 +95,6 @@ class DXILFlattenArraysVisitor
94
95
SmallVector<Value *> Indices = SmallVector<Value *>(),
95
96
SmallVector<uint64_t > Dims = SmallVector<uint64_t >(),
96
97
bool AllIndicesAreConstInt = true );
97
- ConstantInt *computeFlatIndex (GetElementPtrInst &GEP);
98
98
bool visitGetElementPtrInstInGEPChain (GetElementPtrInst &GEP);
99
99
bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
100
100
GetElementPtrInst &GEP);
@@ -164,6 +164,38 @@ Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
164
164
return FlatIndex;
165
165
}
166
166
167
+ bool DXILFlattenArraysVisitor::visitLoadInst (LoadInst &LI) {
168
+ unsigned NumOperands = LI.getNumOperands ();
169
+ for (unsigned I = 0 ; I < NumOperands; ++I) {
170
+ Value *CurrOpperand = LI.getOperand (I);
171
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
172
+ if (CE && CE->getOpcode () == Instruction::GetElementPtr) {
173
+ convertUsersOfConstantsToInstructions (CE,
174
+ /* RestrictToFunc=*/ nullptr ,
175
+ /* RemoveDeadConstants=*/ false ,
176
+ /* IncludeSelf=*/ true );
177
+ return false ;
178
+ }
179
+ }
180
+ return false ;
181
+ }
182
+
183
+ bool DXILFlattenArraysVisitor::visitStoreInst (StoreInst &SI) {
184
+ unsigned NumOperands = SI.getNumOperands ();
185
+ for (unsigned I = 0 ; I < NumOperands; ++I) {
186
+ Value *CurrOpperand = SI.getOperand (I);
187
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
188
+ if (CE && CE->getOpcode () == Instruction::GetElementPtr) {
189
+ convertUsersOfConstantsToInstructions (CE,
190
+ /* RestrictToFunc=*/ nullptr ,
191
+ /* RemoveDeadConstants=*/ false ,
192
+ /* IncludeSelf=*/ true );
193
+ return false ;
194
+ }
195
+ }
196
+ return false ;
197
+ }
198
+
167
199
bool DXILFlattenArraysVisitor::visitAllocaInst (AllocaInst &AI) {
168
200
if (!isMultiDimensionalArray (AI.getAllocatedType ()))
169
201
return false ;
@@ -182,41 +214,6 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
182
214
return true ;
183
215
}
184
216
185
- ConstantInt *
186
- DXILFlattenArraysVisitor::computeFlatIndex (GetElementPtrInst &GEP) {
187
- unsigned IndexAmount = GEP.getNumIndices ();
188
- assert (IndexAmount >= 1 && " Need At least one Index" );
189
- if (IndexAmount == 1 )
190
- return dyn_cast<ConstantInt>(GEP.getOperand (GEP.getNumOperands () - 1 ));
191
-
192
- // Get the type of the base pointer.
193
- Type *BaseType = GEP.getSourceElementType ();
194
-
195
- // Determine the dimensions of the multi-dimensional array.
196
- SmallVector<int64_t > Dimensions;
197
- while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
198
- Dimensions.push_back (ArrType->getNumElements ());
199
- BaseType = ArrType->getElementType ();
200
- }
201
- unsigned FlatIndex = 0 ;
202
- unsigned Multiplier = 1 ;
203
- unsigned BitWidth = 32 ;
204
- for (const Use &Index : GEP.indices ()) {
205
- ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
206
- BitWidth = CurrentIndex->getBitWidth ();
207
- if (!CurrentIndex)
208
- return nullptr ;
209
- int64_t IndexValue = CurrentIndex->getSExtValue ();
210
- FlatIndex += IndexValue * Multiplier;
211
-
212
- if (!Dimensions.empty ()) {
213
- Multiplier *= Dimensions.back (); // Use the last dimension size
214
- Dimensions.pop_back (); // Remove the last dimension
215
- }
216
- }
217
- return ConstantInt::get (GEP.getContext (), APInt (BitWidth, FlatIndex));
218
- }
219
-
220
217
void DXILFlattenArraysVisitor::recursivelyCollectGEPs (
221
218
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
222
219
Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
@@ -240,12 +237,13 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
240
237
for (auto *User : CurrGEP.users ()) {
241
238
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
242
239
recursivelyCollectGEPs (*NestedGEP, FlattenedArrayType, PtrOperand,
243
- ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt);
240
+ ++GEPChainUseCount, Indices, Dims,
241
+ AllIndicesAreConstInt);
244
242
GepUses = true ;
245
243
}
246
244
}
247
245
// This case is just incase the gep chain doesn't end with a 1d array.
248
- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
246
+ if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
249
247
GEPChainMap.insert (
250
248
{&CurrGEP,
251
249
{std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
@@ -295,10 +293,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
295
293
296
294
unsigned GEPChainUseCount = 0 ;
297
295
recursivelyCollectGEPs (GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
298
-
296
+
299
297
// NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
300
298
// Here recursion is used to get the length of the GEP chain.
301
- // Handle zero uses here because there won't be an update via
299
+ // Handle zero uses here because there won't be an update via
302
300
// a child in the chain later.
303
301
if (GEPChainUseCount == 0 ) {
304
302
SmallVector<Value *> Indices ({GEP.getOperand (GEP.getNumOperands () - 1 )});
@@ -308,7 +306,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
308
306
std::move (Indices), std::move (Dims), AllIndicesAreConstInt};
309
307
return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
310
308
}
311
-
309
+
312
310
PotentiallyDeadInstrs.emplace_back (&GEP);
313
311
return false ;
314
312
}
@@ -426,7 +424,7 @@ static bool flattenArrays(Module &M) {
426
424
for (auto &[Old, New] : GlobalMap) {
427
425
Old->replaceAllUsesWith (New);
428
426
Old->eraseFromParent ();
429
- MadeChange | = true ;
427
+ MadeChange = true ;
430
428
}
431
429
return MadeChange;
432
430
}
0 commit comments