#define DEBUG_TYPE "jit-allocations"

#include "AllocationsInfo.h"
+#include "glow/Backends/CompiledFunction.h"
#include "glow/CodeGen/MemoryAllocator.h"
#include "glow/Graph/Context.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/IR/IRUtils.h"
#include "glow/IR/Instrs.h"
#include "glow/Support/Debug.h"
+#include "glow/Support/Memory.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -32,9 +34,7 @@ using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

-void AllocationsInfo::allocateWeightVars(const IRFunction *F,
-                                         const Context &ctx,
-                                         bool absoluteAddr) {
+void AllocationsInfo::allocateWeightVars(const IRFunction *F) {
  // Use two different allocators, because constant weights and mutable weights
  // may use different memory blocks.
  MemoryAllocator constantWeightVarsAllocator("ConstantWeights", 0);
@@ -43,48 +43,29 @@ void AllocationsInfo::allocateWeightVars(const IRFunction *F,
  // Compute the new offsets for all the weights, do not reuse their current
  // addresses. Process all constant WeightVars first.
  for (auto &v : F->getGraph()->getParent()->getConstants()) {
-    assert(isa<WeightVar>(F->getWeightForNode(v)));
+    assert(isa<WeightVar>(F->getWeightForNode(v)) && "Expected WeightVar");
    auto *w = cast<WeightVar>(F->getWeightForNode(v));
    auto numBytes = w->getSizeInBytes();
    size_t addr = constantWeightVarsAllocator.allocate(numBytes, w);
-    if (!absoluteAddr) {
-      allocatedAddressed_[w] = addr;
-    } else {
-      // Reuse the address used by the payload.
-      allocatedAddressed_[w] =
-          v->getPayload().getUnsafePtr() - static_cast<char *>(nullptr);
-    }
+    allocatedAddress_[w] = addr;
  }

-  if (absoluteAddr) {
-    // Allocate addresses for the Placeholders that have payloads defined at
-    // compile-time.
-    // TODO: Remove this branch once Context becomes a parameter of the
-    // CompiledFunction::execute method.
-    for (auto PH : ctx.pairs()) {
-      assert(isa<WeightVar>(F->getWeightForNode(PH.first)));
-      auto *w = cast<WeightVar>(F->getWeightForNode(PH.first));
-      // Reuse the address used by the payload.
-      allocatedAddressed_[w] =
-          PH.second->getUnsafePtr() - static_cast<char *>(nullptr);
-    }
-  } else {
-    // Allocate based on size as reported by the formal type of Placeholders
-    for (auto &v : F->getGraph()->getParent()->getPlaceholders()) {
-      assert(isa<WeightVar>(F->getWeightForNode(v)));
-      auto *w = cast<WeightVar>(F->getWeightForNode(v));
-      auto numBytes = w->getSizeInBytes();
-      size_t addr = mutableWeightVarsAllocator.allocate(numBytes, w);
-      allocatedAddressed_[w] = addr;
-    }
+  // Compute the offsets and total memory requirements for Placeholders.
+  for (auto &v : F->getGraph()->getParent()->getPlaceholders()) {
+    // Get the WeightVar for each Placeholder to calculate offsets.
+    assert(isa<WeightVar>(F->getWeightForNode(v)) && "Expected WeightVar");
+    auto *w = cast<WeightVar>(F->getWeightForNode(v));
+    auto numBytes = w->getSizeInBytes();
+    size_t addr = mutableWeightVarsAllocator.allocate(numBytes, w);
+    allocatedAddress_[w] = addr;
  }

  // Remember that max required memory size for each kind of weights.
  constantWeightVarsMemSize_ = constantWeightVarsAllocator.getMaxMemoryUsage();
  mutableWeightVarsMemSize_ = mutableWeightVarsAllocator.getMaxMemoryUsage();

  DEBUG_GLOW(for (auto &A
-                  : allocatedAddressed_) {
+                  : allocatedAddress_) {
    if (isa<AllocActivationInst>(A.first) || isa<TensorViewInst>(A.first))
      continue;
    assert(valueNumbers_.count(A.first) && "Unknown weight");
@@ -94,13 +75,47 @@ void AllocationsInfo::allocateWeightVars(const IRFunction *F,
                    : "mutable weight";
    llvm::errs() << "Allocated " << kind << " " << A.first->getName()
                 << " size: " << A.first->getSizeInBytes()
-                 << " address range: [" << allocatedAddressed_[A.first]
-                 << ", "
-                 << allocatedAddressed_[A.first] + A.first->getSizeInBytes()
+                 << " address range: [" << allocatedAddress_[A.first] << ", "
+                 << allocatedAddress_[A.first] + A.first->getSizeInBytes()
                 << "]\n";
  });
}
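+// Copy every constant WeightVar's payload into a single contiguous buffer,
+// at the offset assigned to it by allocateWeightVars().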
+void AllocationsInfo::collectConstants(const IRFunction *F) {
+
+  // At compile time condense constants to a single block of memory.
+  // This allows the graph to go away after compile time.
+  baseConstantWeightVarsStore_ =
+      (uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment);
+  for (auto &v : F->getGraph()->getParent()->getConstants()) {
+    assert(isa<WeightVar>(F->getWeightForNode(v)));
+    auto *w = cast<WeightVar>(F->getWeightForNode(v));
+    auto payload = v->getPayload().getUnsafePtr();
+    auto numBytes = w->getSizeInBytes();
+    auto addr = allocatedAddress_[w];
+    // Copy weight to offset.
+    memcpy(baseConstantWeightVarsStore_ + addr, payload, numBytes);
+  }
+}
+
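+// Build a runtime::RuntimeBundle from the results above: the memory sizes
+// computed during allocation, the condensed constant buffer, and a symbol
+// table mapping each Placeholder's name to its offset and size.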
+runtime::RuntimeBundle
+AllocationsInfo::generateRuntimeBundle(const IRFunction *F) {
+  runtime::RuntimeBundle info(constantWeightVarsMemSize_,
+                              mutableWeightVarsMemSize_, activationsMemSize_);
+  std::unordered_map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
+  info.constants = baseConstantWeightVarsStore_;
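+  // Record each Placeholder's offset into the mutable weights block and its
+  // size in the symbol table, keyed by the Placeholder's name.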
+  for (auto &v : F->getGraph()->getParent()->getPlaceholders()) {
+    assert(isa<WeightVar>(F->getWeightForNode(v)) && "Expected WeightVar");
+    auto *w = cast<WeightVar>(F->getWeightForNode(v));
+    runtime::RuntimeSymbolInfo symbol;
+    symbol.offset = allocatedAddress_[w];
+    symbol.size = w->getSizeInBytes();
+    symbolTable.emplace(std::string(v->getName()), symbol);
+  }
+  info.symbolTable = std::move(symbolTable);
+  return info;
+}
+
void AllocationsInfo::allocateActivations(const IRFunction *F) {
  // Use a memory allocator with no upper bound on how much memory we can
  // allocate.
@@ -131,15 +146,14 @@ void AllocationsInfo::allocateActivations(const IRFunction *F) {

  // Register specific addresses within the heap to activations.
  for (auto &A : activationAddr) {
-    allocatedAddressed_[A.first] = A.second;
+    allocatedAddress_[A.first] = A.second;
  }
  DEBUG_GLOW(for (auto &A
-                  : allocatedAddressed_) {
+                  : allocatedAddress_) {
    llvm::errs() << "Allocated activation " << A.first->getName()
                 << " size: " << A.first->getSizeInBytes()
-                 << " address range: [" << allocatedAddressed_[A.first]
-                 << ", "
-                 << allocatedAddressed_[A.first] + A.first->getSizeInBytes()
+                 << " address range: [" << allocatedAddress_[A.first] << ", "
+                 << allocatedAddress_[A.first] + A.first->getSizeInBytes()
                 << "]\n";
  });
}
@@ -174,18 +188,18 @@ void AllocationsInfo::allocateTensorViews(const IRFunction *F) {
  for (const auto &I : F->getInstrs()) {
    if (const auto *TVI = dyn_cast<TensorViewInst>(&I)) {
      auto *viewOrigin = getOrigin(TVI);
-      assert(allocatedAddressed_.count(viewOrigin) &&
+      assert(allocatedAddress_.count(viewOrigin) &&
             "Did not find original WeightVar or AllocActivation for a "
             "TensorView.");
-      size_t originAddr = allocatedAddressed_[viewOrigin];
+      size_t originAddr = allocatedAddress_[viewOrigin];

      // Calculate the offset into the underlying alloc activation.
      size_t offset = calculateTensorViewOffset(TVI);

      // Calculate the correct address using this offset into the alloc
      // activation and map from the original TVI to it.
-      assert(!allocatedAddressed_.count(TVI) && "Allocation already made!");
-      allocatedAddressed_[TVI] = originAddr + offset;
+      assert(!allocatedAddress_.count(TVI) && "Allocation already made!");
+      allocatedAddress_[TVI] = originAddr + offset;
      continue;
    }
  }
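
A minimal sketch (not part of the commit) of how a backend might drive these routines after this change; the call order is an assumption inferred from the data dependencies visible above, and F stands for the lowered IRFunction:

  AllocationsInfo allocationsInfo;
  allocationsInfo.allocateWeightVars(F);   // assign offsets, compute weight memory sizes
  allocationsInfo.allocateActivations(F);  // assign offsets to activations
  allocationsInfo.allocateTensorViews(F);  // derive TensorView addresses from their origins
  allocationsInfo.collectConstants(F);     // copy constant payloads into one aligned buffer
  runtime::RuntimeBundle bundle = allocationsInfo.generateRuntimeBundle(F);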