@@ -203,7 +203,7 @@ void OCLBackend::doForwardPass(bool isTrain) {
    }

    if (auto *SM = dyn_cast<SoftMaxInst>(I)) {
-     // Implement Softmax by parallelizing the batsh dimension. Each sample in
+     // Implement Softmax by parallelizing the batch dimension. Each sample in
      // the batch is processed by a different parallel 'thread'.
      cl_kernel kernel = createKernel(program_, kernelName);
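For context on what that "one thread per sample" strategy means on the device side, below is a minimal OpenCL C sketch of a batch-parallel softmax kernel. The kernel name, argument order, and the flat sliceSize layout are illustrative assumptions, not Glow's actual kernel source; it only shows the shape of the parallelization, with get_global_id(0) selecting the sample.

    // Hypothetical sketch: one work-item per batch sample. `sliceSize` is
    // the number of elements in one sample; buffers are flat row-major.
    __kernel void softmaxBatchSketch(__global float *dest,
                                     __global const float *src,
                                     unsigned sliceSize) {
      size_t sample = get_global_id(0);
      __global const float *in = src + sample * sliceSize;
      __global float *out = dest + sample * sliceSize;

      // Find the per-sample max so exp() stays numerically stable.
      float maxVal = in[0];
      for (unsigned j = 1; j < sliceSize; j++) {
        maxVal = fmax(maxVal, in[j]);
      }

      // Exponentiate and accumulate the normalizer.
      float sum = 0.0f;
      for (unsigned j = 0; j < sliceSize; j++) {
        out[j] = exp(in[j] - maxVal);
        sum += out[j];
      }

      // Normalize in place.
      for (unsigned j = 0; j < sliceSize; j++) {
        out[j] /= sum;
      }
    }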
@@ -236,16 +236,31 @@ void OCLBackend::doForwardPass(bool isTrain) {
        setKernelArg(kernel, arg + 1, tensors_[I->getOperand(arg).first]);
      }

-     auto odim = ShapeNHWC(CI->getDest()->getType()->dims());
-     auto idim = ShapeNHWC(CI->getSrc()->getType()->dims());
-     auto o = CI->getOffsets();
-     ShapeNHWC offset(o[0], o[1], o[2], o[3]);
+     // Currently we only support tensors of 2 and 4 dimensions.
+     // TODO: Handle other dimensions.
+     const size_t numDimensions = CI->getDest()->getType()->dims().size();
+     ShapeNHWC odim = ShapeNHWC::empty();
+     ShapeNHWC idim = ShapeNHWC::empty();
+     ShapeNHWC offset = ShapeNHWC::empty();
+
+     if (numDimensions == 4) {
+       odim = ShapeNHWC(CI->getDest()->getType()->dims());
+       idim = ShapeNHWC(CI->getSrc()->getType()->dims());
+       offset = ShapeNHWC(CI->getOffsets());
+     } else if (numDimensions == 2) {
+       odim = ShapeNHWC::fromXY(CI->getDest()->getType()->dims());
+       idim = ShapeNHWC::fromXY(CI->getSrc()->getType()->dims());
+       offset = ShapeNHWC::fromXY(CI->getOffsets());
+     } else {
+       assert(false && "Unsupported tensor dimension");
+     }

      setKernelArg(kernel, 3, odim);
      setKernelArg(kernel, 4, idim);
      setKernelArg(kernel, 5, offset);
      enqueueKernel(commands_, kernel, deviceId_, {idim.n});
      kernels.push_back(kernel);
+
      continue;
    }
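The 2D fallback above leans on two ShapeNHWC helpers. Here is a rough C++ sketch of the semantics the diff appears to assume; the field names match how idim.n is used in the launch, but the {x, y, 1, 1} padding convention in fromXY is an inference, not copied from Glow's headers.

    #include <cassert>
    #include <cstddef>

    // Plausible sketch of the shape helpers used in the diff. Glow's real
    // ShapeNHWC takes llvm::ArrayRef<size_t>; a pointer/length pair stands
    // in for it here to keep the sketch self-contained.
    struct ShapeNHWC {
      size_t n; // batch
      size_t h; // rows
      size_t w; // columns
      size_t c; // channels

      // All-zero placeholder that one of the branches above overwrites.
      static ShapeNHWC empty() { return {0, 0, 0, 0}; }

      // Lift a 2D {x, y} shape into NHWC by padding the trailing dims
      // with 1, so a {batch, features} tensor becomes
      // {batch, features, 1, 1} and .n still counts samples.
      static ShapeNHWC fromXY(const size_t *dims, size_t numDims) {
        assert(numDims == 2 && "fromXY expects a 2D shape");
        return {dims[0], dims[1], 1, 1};
      }
    };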
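Finally, a note on the launch: enqueueKernel is handed {idim.n}, i.e. a one-dimensional global range with one work-item per batch sample, which is exactly why the 2D path must keep the sample count in the n field. A hedged sketch of what such a helper might wrap follows; the helper name and parameters come from the diff, but this body is an assumption, and real code would propagate errors rather than assert.

    #include <CL/cl.h>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Hypothetical body for an enqueueKernel-style helper: an N-D NDRange
    // launch whose global size is the vector passed by the caller,
    // e.g. {idim.n} for one work-item per sample.
    static void enqueueKernelSketch(cl_command_queue commands,
                                    cl_kernel kernel, cl_device_id device,
                                    const std::vector<size_t> &global) {
      (void)device; // A fuller version might clamp sizes to device limits.
      cl_int err = clEnqueueNDRangeKernel(
          commands, kernel, (cl_uint)global.size(), /*offsets=*/nullptr,
          global.data(), /*local=*/nullptr, 0, nullptr, nullptr);
      assert(err == CL_SUCCESS && "Failed to enqueue the kernel");
    }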