Skip to content

Commit 41cf9bf

Browse files
Roman Dzhabarov (rdzhabarov)
Roman Dzhabarov
authored and committed
Fix insert tensor.
1 parent b27faa9 commit 41cf9bf

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

include/glow/Base/Type.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ struct ShapeNHWC {
3838
return ShapeNHWC(shape[0], shape[1], shape[2], 1);
3939
}
4040

41+
/// Builds a ShapeNHWC from a two-element shape: shape[0] is stored as n and
/// shape[1] as h, while w and c are both set to 1.
/// Asserts that \p shape has exactly two elements.
static ShapeNHWC fromXY(llvm::ArrayRef<size_t> shape) {
42+
assert(shape.size() == 2 && "Invalid 2d shape");
43+
return ShapeNHWC(shape[0], shape[1], 1, 1);
44+
}
45+
46+
/// Returns a ShapeNHWC whose four fields (n, h, w, c) are all zero.
static ShapeNHWC empty() {
47+
return ShapeNHWC(0, 0, 0, 0);
48+
}
49+
4150
/// Constructs a shape by storing \p samples, \p height, \p width and
/// \p channels in the fields n, h, w and c respectively.
explicit ShapeNHWC(size_t samples, size_t height, size_t width,
4251
size_t channels)
4352
: n(samples), h(height), w(width), c(channels) {}

lib/Backends/OpenCL/OpenCL.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ void OCLBackend::doForwardPass(bool isTrain) {
203203
}
204204

205205
if (auto *SM = dyn_cast<SoftMaxInst>(I)) {
206-
// Implement Softmax by parallelizing the batsh dimension. Each sample in
206+
// Implement Softmax by parallelizing the batch dimension. Each sample in
207207
// the batch is processed by a different parallel 'thread'.
208208
cl_kernel kernel = createKernel(program_, kernelName);
209209

@@ -236,16 +236,31 @@ void OCLBackend::doForwardPass(bool isTrain) {
236236
setKernelArg(kernel, arg + 1, tensors_[I->getOperand(arg).first]);
237237
}
238238

239-
auto odim = ShapeNHWC(CI->getDest()->getType()->dims());
240-
auto idim = ShapeNHWC(CI->getSrc()->getType()->dims());
241-
auto o = CI->getOffsets();
242-
ShapeNHWC offset(o[0], o[1], o[2], o[3]);
239+
// Currently support tensors of 2 and 4 dimensions.
240+
// TODO: Handle other dimensions.
241+
const size_t numDimensions = CI->getDest()->getType()->dims().size();
242+
ShapeNHWC odim = ShapeNHWC::empty();
243+
ShapeNHWC idim = ShapeNHWC::empty();
244+
ShapeNHWC offset = ShapeNHWC::empty();
245+
246+
if (numDimensions == 4) {
247+
odim = ShapeNHWC(CI->getDest()->getType()->dims());
248+
idim = ShapeNHWC(CI->getSrc()->getType()->dims());
249+
offset = ShapeNHWC(CI->getOffsets());
250+
} else if (numDimensions == 2) {
251+
odim = ShapeNHWC::fromXY(CI->getDest()->getType()->dims());
252+
idim = ShapeNHWC::fromXY(CI->getSrc()->getType()->dims());
253+
offset = ShapeNHWC::fromXY(CI->getOffsets());
254+
} else {
255+
assert(false && "Unsupported tensor dimension");
256+
}
243257

244258
setKernelArg(kernel, 3, odim);
245259
setKernelArg(kernel, 4, idim);
246260
setKernelArg(kernel, 5, offset);
247261
enqueueKernel(commands_, kernel, deviceId_, {idim.n});
248262
kernels.push_back(kernel);
263+
249264
continue;
250265
}
251266

0 commit comments

Comments
 (0)