Skip to content

Commit 1f7bb65

Browse files
ashishfarmer authored and iotamudelta committed
[Caffe2] MIOpen dims change check (#229)
* Added checks to re-initialize miopen when dims change
* remove double init of conv_desc_
* Update weights dim for check
1 parent 236e7d2 commit 1f7bb65

File tree

1 file changed

+129
-93
lines changed

1 file changed

+129
-93
lines changed

caffe2/operators/hip/conv_op_miopen.cc

Lines changed: 129 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -66,19 +66,6 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
6666
dilation_h() == 1 && dilation_w() == 1,
6767
"MIOpen convolution does not support dilation for groups > 1.");
6868
}
69-
70-
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
71-
conv_desc_,
72-
mode_,
73-
pad_t(),
74-
pad_l(),
75-
stride_h(),
76-
stride_w(),
77-
dilation_h(),
78-
dilation_w()));
79-
80-
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
81-
conv_desc_, group_));
8269
}
8370

8471
~MIOPENConvOpBase() {
@@ -91,6 +78,8 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
9178
}
9279

9380
protected:
81+
vector<int64_t> mio_input_dims_;
82+
vector<int64_t> mio_weight_dims_;
9483
MIOPENWrapper miopen_wrapper_;
9584
miopenTensorDescriptor_t bottom_desc_;
9685
miopenTensorDescriptor_t bias_desc_;
@@ -257,35 +246,59 @@ bool MIOPENConvOp::DoRunWithType() {
257246
"If you set group, the number of output channels should be divisible "
258247
"by group.");
259248

260-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
261-
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
249+
bool input_changed = (X.dims() != mio_input_dims_);
250+
bool weight_changed = (Weight.dims() != mio_weight_dims_);
262251

263-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
264-
weight_desc_,
265-
miopenTypeWrapper<T_W>::type,
266-
M,
267-
C / group_,
268-
kernel_h(),
269-
kernel_w()));
252+
if (input_changed || weight_changed) {
253+
VLOG(1) << "Changing MIOpen descriptor configurations.";
254+
if (input_changed) {
255+
mio_input_dims_ = X.dims();
256+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
257+
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
258+
}
270259

271-
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
272-
conv_desc_,
273-
bottom_desc_,
274-
weight_desc_,
275-
&N_out,
276-
&C_out,
277-
&H_out,
278-
&W_out));
260+
if (weight_changed) {
261+
mio_weight_dims_ = Weight.dims();
262+
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
263+
conv_desc_,
264+
mode_,
265+
pad_t(),
266+
pad_l(),
267+
stride_h(),
268+
stride_w(),
269+
dilation_h(),
270+
dilation_w()));
279271

280-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
281-
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
272+
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
273+
conv_desc_, group_));
282274

283-
if (InputSize() == 3) {
275+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
276+
weight_desc_,
277+
miopenTypeWrapper<T_W>::type,
278+
M,
279+
C / group_,
280+
kernel_h(),
281+
kernel_w()));
282+
}
283+
284+
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
285+
conv_desc_,
286+
bottom_desc_,
287+
weight_desc_,
288+
&N_out,
289+
&C_out,
290+
&H_out,
291+
&W_out));
292+
293+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
294+
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
295+
296+
if (InputSize() == 3) {
284297
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
285298
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
286-
}
299+
}
287300

288-
while (!bestAlgoFound_) {
301+
while (!bestAlgoFound_) {
289302
miopenConvAlgoPerf_t perf;
290303

291304
MIOPEN_ENFORCE(miopenConvolutionForwardGetWorkSpaceSize(
@@ -318,8 +331,8 @@ bool MIOPENConvOp::DoRunWithType() {
318331
});
319332
bestAlgoFound_ = true;
320333
fwdAlgo_ = perf.fwd_algo;
334+
}
321335
}
322-
323336
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
324337
MIOPEN_ENFORCE(miopenConvolutionForward(
325338
state->miopen_handle(),
@@ -424,36 +437,59 @@ bool MIOPENConvGradientOp::DoRunWithType() {
424437
"by group.");
425438

426439
bool doBwdDataComputation = (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2)));
440+
bool input_changed = (X.dims() != mio_input_dims_);
441+
bool weight_changed = (Weight.dims() != mio_weight_dims_);
427442

428-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
429-
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
430-
431-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
432-
weight_desc_,
433-
miopenTypeWrapper<T_X>::type,
434-
M,
435-
C / group_,
436-
kernel_h(),
437-
kernel_w()));
443+
if (input_changed || weight_changed) {
444+
VLOG(1) << "Changing MIOpen descriptor configurations.";
445+
if (input_changed) {
446+
mio_input_dims_ = X.dims();
447+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
448+
bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
449+
}
438450

439-
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
440-
conv_desc_,
441-
bottom_desc_,
442-
weight_desc_,
443-
&N_out,
444-
&C_out,
445-
&H_out,
446-
&W_out));
451+
if (weight_changed) {
452+
mio_weight_dims_ = Weight.dims();
453+
MIOPEN_ENFORCE(miopenInitConvolutionDescriptor(
454+
conv_desc_,
455+
mode_,
456+
pad_t(),
457+
pad_l(),
458+
stride_h(),
459+
stride_w(),
460+
dilation_h(),
461+
dilation_w()));
447462

448-
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
449-
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
463+
MIOPEN_ENFORCE(miopenSetConvolutionGroupCount(
464+
conv_desc_, group_));
450465

451-
if (!no_bias_) {
452466
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
453-
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
454-
}
467+
weight_desc_,
468+
miopenTypeWrapper<T_X>::type,
469+
M,
470+
C / group_,
471+
kernel_h(),
472+
kernel_w()));
473+
}
474+
475+
MIOPEN_ENFORCE(miopenGetConvolutionForwardOutputDim(
476+
conv_desc_,
477+
bottom_desc_,
478+
weight_desc_,
479+
&N_out,
480+
&C_out,
481+
&H_out,
482+
&W_out));
483+
484+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
485+
top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
455486

456-
while ((!bestDataAlgoFound_) && doBwdDataComputation) {
487+
if (!no_bias_) {
488+
MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
489+
bias_desc_, miopenTypeWrapper<T_B>::type, 1, M, 1, 1));
490+
}
491+
492+
while ((!bestDataAlgoFound_) && doBwdDataComputation) {
457493
miopenConvAlgoPerf_t perf;
458494

459495
MIOPEN_ENFORCE(miopenConvolutionBackwardDataGetWorkSpaceSize(
@@ -487,43 +523,43 @@ bool MIOPENConvGradientOp::DoRunWithType() {
487523

488524
bestDataAlgoFound_ = true;
489525
bwdDataAlgo_ = perf.bwd_data_algo;
490-
}
526+
}
491527

492-
while (!bestWeightAlgoFound_) {
493-
miopenConvAlgoPerf_t perf;
528+
while (!bestWeightAlgoFound_) {
529+
miopenConvAlgoPerf_t perf;
494530

495-
MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize(
496-
miopen_wrapper_.inline_miopen_handle(),
497-
top_desc_,
498-
bottom_desc_,
499-
conv_desc_,
500-
weight_desc_,
501-
&bwdWeightWsSize_));
502-
if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
503-
HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
504-
}
531+
MIOPEN_ENFORCE(miopenConvolutionBackwardWeightsGetWorkSpaceSize(
532+
miopen_wrapper_.inline_miopen_handle(),
533+
top_desc_,
534+
bottom_desc_,
535+
conv_desc_,
536+
weight_desc_,
537+
&bwdWeightWsSize_));
538+
if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
539+
HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
540+
}
505541

506-
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
507-
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
508-
state->miopen_handle(),
509-
top_desc_,
510-
dY.template data<T_DY>(),
511-
bottom_desc_,
512-
X.template data<T_X>(),
513-
conv_desc_,
514-
weight_desc_,
515-
dW->template mutable_data<T_DW>(),
516-
requestAlgoCount_,
517-
&returnedAlgoCount_,
518-
&perf,
519-
bwdWeightWs_,
520-
bwdWeightWsSize_,
521-
false));
522-
});
523-
bestWeightAlgoFound_ = true;
524-
bwdWeiAlgo_ = perf.bwd_weights_algo;
542+
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
543+
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
544+
state->miopen_handle(),
545+
top_desc_,
546+
dY.template data<T_DY>(),
547+
bottom_desc_,
548+
X.template data<T_X>(),
549+
conv_desc_,
550+
weight_desc_,
551+
dW->template mutable_data<T_DW>(),
552+
requestAlgoCount_,
553+
&returnedAlgoCount_,
554+
&perf,
555+
bwdWeightWs_,
556+
bwdWeightWsSize_,
557+
false));
558+
});
559+
bestWeightAlgoFound_ = true;
560+
bwdWeiAlgo_ = perf.bwd_weights_algo;
561+
}
525562
}
526-
527563
if (doBwdDataComputation) {
528564
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
529565
MIOPEN_ENFORCE(miopenConvolutionBackwardData(

0 commit comments

Comments
 (0)