@@ -66,19 +66,6 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
66
66
dilation_h () == 1 && dilation_w () == 1 ,
67
67
" MIOpen convolution does not support dilation for groups > 1." );
68
68
}
69
-
70
- MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
71
- conv_desc_,
72
- mode_,
73
- pad_t (),
74
- pad_l (),
75
- stride_h (),
76
- stride_w (),
77
- dilation_h (),
78
- dilation_w ()));
79
-
80
- MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
81
- conv_desc_, group_));
82
69
}
83
70
84
71
~MIOPENConvOpBase () {
@@ -91,6 +78,8 @@ class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
91
78
}
92
79
93
80
protected:
81
+ vector<int64_t > mio_input_dims_;
82
+ vector<int64_t > mio_weight_dims_;
94
83
MIOPENWrapper miopen_wrapper_;
95
84
miopenTensorDescriptor_t bottom_desc_;
96
85
miopenTensorDescriptor_t bias_desc_;
@@ -257,35 +246,59 @@ bool MIOPENConvOp::DoRunWithType() {
257
246
" If you set group, the number of output channels should be divisible "
258
247
" by group." );
259
248
260
- MIOPEN_ENFORCE ( miopenSet4dTensorDescriptor (
261
- bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W) );
249
+ bool input_changed = (X. dims () != mio_input_dims_);
250
+ bool weight_changed = (Weight. dims () != mio_weight_dims_ );
262
251
263
- MIOPEN_ENFORCE ( miopenSet4dTensorDescriptor (
264
- weight_desc_,
265
- miopenTypeWrapper<T_W>::type,
266
- M,
267
- C / group_,
268
- kernel_h (),
269
- kernel_w ()));
252
+ if (input_changed || weight_changed) {
253
+ VLOG ( 1 ) << " Changing MIOpen descriptor configurations. " ;
254
+ if (input_changed) {
255
+ mio_input_dims_ = X. dims ();
256
+ MIOPEN_ENFORCE ( miopenSet4dTensorDescriptor (
257
+ bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
258
+ }
270
259
271
- MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
272
- conv_desc_,
273
- bottom_desc_,
274
- weight_desc_,
275
- &N_out,
276
- &C_out,
277
- &H_out,
278
- &W_out));
260
+ if (weight_changed) {
261
+ mio_weight_dims_ = Weight.dims ();
262
+ MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
263
+ conv_desc_,
264
+ mode_,
265
+ pad_t (),
266
+ pad_l (),
267
+ stride_h (),
268
+ stride_w (),
269
+ dilation_h (),
270
+ dilation_w ()));
279
271
280
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
281
- top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out ));
272
+ MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
273
+ conv_desc_, group_ ));
282
274
283
- if (InputSize () == 3 ) {
275
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
276
+ weight_desc_,
277
+ miopenTypeWrapper<T_W>::type,
278
+ M,
279
+ C / group_,
280
+ kernel_h (),
281
+ kernel_w ()));
282
+ }
283
+
284
+ MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
285
+ conv_desc_,
286
+ bottom_desc_,
287
+ weight_desc_,
288
+ &N_out,
289
+ &C_out,
290
+ &H_out,
291
+ &W_out));
292
+
293
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
294
+ top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
295
+
296
+ if (InputSize () == 3 ) {
284
297
MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
285
298
bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
286
- }
299
+ }
287
300
288
- while (!bestAlgoFound_) {
301
+ while (!bestAlgoFound_) {
289
302
miopenConvAlgoPerf_t perf;
290
303
291
304
MIOPEN_ENFORCE (miopenConvolutionForwardGetWorkSpaceSize (
@@ -318,8 +331,8 @@ bool MIOPENConvOp::DoRunWithType() {
318
331
});
319
332
bestAlgoFound_ = true ;
320
333
fwdAlgo_ = perf.fwd_algo ;
334
+ }
321
335
}
322
-
323
336
miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
324
337
MIOPEN_ENFORCE (miopenConvolutionForward (
325
338
state->miopen_handle (),
@@ -424,36 +437,59 @@ bool MIOPENConvGradientOp::DoRunWithType() {
424
437
" by group." );
425
438
426
439
bool doBwdDataComputation = (OutputSize () == 3 || (no_bias_ && (OutputSize () == 2 )));
440
+ bool input_changed = (X.dims () != mio_input_dims_);
441
+ bool weight_changed = (Weight.dims () != mio_weight_dims_);
427
442
428
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
429
- bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
430
-
431
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
432
- weight_desc_,
433
- miopenTypeWrapper<T_X>::type,
434
- M,
435
- C / group_,
436
- kernel_h (),
437
- kernel_w ()));
443
+ if (input_changed || weight_changed) {
444
+ VLOG (1 ) << " Changing MIOpen descriptor configurations." ;
445
+ if (input_changed) {
446
+ mio_input_dims_ = X.dims ();
447
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
448
+ bottom_desc_, miopenTypeWrapper<T_X>::type, N, C, H, W));
449
+ }
438
450
439
- MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
440
- conv_desc_,
441
- bottom_desc_,
442
- weight_desc_,
443
- &N_out,
444
- &C_out,
445
- &H_out,
446
- &W_out));
451
+ if (weight_changed) {
452
+ mio_weight_dims_ = Weight.dims ();
453
+ MIOPEN_ENFORCE (miopenInitConvolutionDescriptor (
454
+ conv_desc_,
455
+ mode_,
456
+ pad_t (),
457
+ pad_l (),
458
+ stride_h (),
459
+ stride_w (),
460
+ dilation_h (),
461
+ dilation_w ()));
447
462
448
- MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
449
- top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out ));
463
+ MIOPEN_ENFORCE (miopenSetConvolutionGroupCount (
464
+ conv_desc_, group_ ));
450
465
451
- if (!no_bias_) {
452
466
MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
453
- bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
454
- }
467
+ weight_desc_,
468
+ miopenTypeWrapper<T_X>::type,
469
+ M,
470
+ C / group_,
471
+ kernel_h (),
472
+ kernel_w ()));
473
+ }
474
+
475
+ MIOPEN_ENFORCE (miopenGetConvolutionForwardOutputDim (
476
+ conv_desc_,
477
+ bottom_desc_,
478
+ weight_desc_,
479
+ &N_out,
480
+ &C_out,
481
+ &H_out,
482
+ &W_out));
483
+
484
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
485
+ top_desc_, miopenTypeWrapper<T_X>::type, N_out, C_out, H_out, W_out));
455
486
456
- while ((!bestDataAlgoFound_) && doBwdDataComputation) {
487
+ if (!no_bias_) {
488
+ MIOPEN_ENFORCE (miopenSet4dTensorDescriptor (
489
+ bias_desc_, miopenTypeWrapper<T_B>::type, 1 , M, 1 , 1 ));
490
+ }
491
+
492
+ while ((!bestDataAlgoFound_) && doBwdDataComputation) {
457
493
miopenConvAlgoPerf_t perf;
458
494
459
495
MIOPEN_ENFORCE (miopenConvolutionBackwardDataGetWorkSpaceSize (
@@ -487,43 +523,43 @@ bool MIOPENConvGradientOp::DoRunWithType() {
487
523
488
524
bestDataAlgoFound_ = true ;
489
525
bwdDataAlgo_ = perf.bwd_data_algo ;
490
- }
526
+ }
491
527
492
- while (!bestWeightAlgoFound_) {
493
- miopenConvAlgoPerf_t perf;
528
+ while (!bestWeightAlgoFound_) {
529
+ miopenConvAlgoPerf_t perf;
494
530
495
- MIOPEN_ENFORCE (miopenConvolutionBackwardWeightsGetWorkSpaceSize (
496
- miopen_wrapper_.inline_miopen_handle (),
497
- top_desc_,
498
- bottom_desc_,
499
- conv_desc_,
500
- weight_desc_,
501
- &bwdWeightWsSize_));
502
- if ((bwdWeightWsSize_ > 0 ) && (bwdWeightWs_ == nullptr )) {
503
- HIP_CHECK (hipMalloc (&bwdWeightWs_, bwdWeightWsSize_));
504
- }
531
+ MIOPEN_ENFORCE (miopenConvolutionBackwardWeightsGetWorkSpaceSize (
532
+ miopen_wrapper_.inline_miopen_handle (),
533
+ top_desc_,
534
+ bottom_desc_,
535
+ conv_desc_,
536
+ weight_desc_,
537
+ &bwdWeightWsSize_));
538
+ if ((bwdWeightWsSize_ > 0 ) && (bwdWeightWs_ == nullptr )) {
539
+ HIP_CHECK (hipMalloc (&bwdWeightWs_, bwdWeightWsSize_));
540
+ }
505
541
506
- miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
507
- MIOPEN_ENFORCE (miopenFindConvolutionBackwardWeightsAlgorithm (
508
- state->miopen_handle (),
509
- top_desc_,
510
- dY.template data <T_DY>(),
511
- bottom_desc_,
512
- X.template data <T_X>(),
513
- conv_desc_,
514
- weight_desc_,
515
- dW->template mutable_data <T_DW>(),
516
- requestAlgoCount_,
517
- &returnedAlgoCount_,
518
- &perf,
519
- bwdWeightWs_,
520
- bwdWeightWsSize_,
521
- false ));
522
- });
523
- bestWeightAlgoFound_ = true ;
524
- bwdWeiAlgo_ = perf.bwd_weights_algo ;
542
+ miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
543
+ MIOPEN_ENFORCE (miopenFindConvolutionBackwardWeightsAlgorithm (
544
+ state->miopen_handle (),
545
+ top_desc_,
546
+ dY.template data <T_DY>(),
547
+ bottom_desc_,
548
+ X.template data <T_X>(),
549
+ conv_desc_,
550
+ weight_desc_,
551
+ dW->template mutable_data <T_DW>(),
552
+ requestAlgoCount_,
553
+ &returnedAlgoCount_,
554
+ &perf,
555
+ bwdWeightWs_,
556
+ bwdWeightWsSize_,
557
+ false ));
558
+ });
559
+ bestWeightAlgoFound_ = true ;
560
+ bwdWeiAlgo_ = perf.bwd_weights_algo ;
561
+ }
525
562
}
526
-
527
563
if (doBwdDataComputation) {
528
564
miopen_wrapper_.with_miopen_state (miopen_state_, [&](MIOPENState* state) {
529
565
MIOPEN_ENFORCE (miopenConvolutionBackwardData (
0 commit comments