5
5
from __future__ import print_function
6
6
from __future__ import unicode_literals
7
7
8
- from caffe2 .python import brew , workspace
8
+ from caffe2 .python import brew
9
9
from caffe2 .python .model_helper import ModelHelper
10
10
from caffe2 .proto import caffe2_pb2
11
11
import logging
@@ -17,7 +17,7 @@ class CNNModelHelper(ModelHelper):
17
17
"""
18
18
19
19
def __init__ (self , order = "NCHW" , name = None ,
20
- use_gpu_engine = True , gpu_engine_exhaustive_search = False ,
20
+ use_cudnn = True , cudnn_exhaustive_search = False ,
21
21
ws_nbytes_limit = None , init_params = True ,
22
22
skip_sparse_optim = False ,
23
23
param_model = None ):
@@ -31,8 +31,8 @@ def __init__(self, order="NCHW", name=None,
31
31
32
32
cnn_arg_scope = {
33
33
'order' : order ,
34
- 'use_gpu_engine ' : use_gpu_engine ,
35
- 'gpu_engine_exhaustive_search ' : gpu_engine_exhaustive_search ,
34
+ 'use_cudnn ' : use_cudnn ,
35
+ 'cudnn_exhaustive_search ' : cudnn_exhaustive_search ,
36
36
}
37
37
if ws_nbytes_limit :
38
38
cnn_arg_scope ['ws_nbytes_limit' ] = ws_nbytes_limit
@@ -45,8 +45,8 @@ def __init__(self, order="NCHW", name=None,
45
45
)
46
46
47
47
self .order = order
48
- self .use_gpu_engine = use_gpu_engine
49
- self .gpu_engine_exhaustive_search = gpu_engine_exhaustive_search
48
+ self .use_cudnn = use_cudnn
49
+ self .cudnn_exhaustive_search = cudnn_exhaustive_search
50
50
self .ws_nbytes_limit = ws_nbytes_limit
51
51
if self .order != "NHWC" and self .order != "NCHW" :
52
52
raise ValueError (
@@ -79,9 +79,9 @@ def ConvNd(self, *args, **kwargs):
79
79
return brew .conv_nd (
80
80
self ,
81
81
* args ,
82
- use_gpu_engine = self .use_gpu_engine ,
82
+ use_cudnn = self .use_cudnn ,
83
83
order = self .order ,
84
- gpu_engine_exhaustive_search = self .gpu_engine_exhaustive_search ,
84
+ cudnn_exhaustive_search = self .cudnn_exhaustive_search ,
85
85
ws_nbytes_limit = self .ws_nbytes_limit ,
86
86
** kwargs
87
87
)
@@ -90,9 +90,9 @@ def Conv(self, *args, **kwargs):
90
90
return brew .conv (
91
91
self ,
92
92
* args ,
93
- use_gpu_engine = self .use_gpu_engine ,
93
+ use_cudnn = self .use_cudnn ,
94
94
order = self .order ,
95
- gpu_engine_exhaustive_search = self .gpu_engine_exhaustive_search ,
95
+ cudnn_exhaustive_search = self .cudnn_exhaustive_search ,
96
96
ws_nbytes_limit = self .ws_nbytes_limit ,
97
97
** kwargs
98
98
)
@@ -101,9 +101,9 @@ def ConvTranspose(self, *args, **kwargs):
101
101
return brew .conv_transpose (
102
102
self ,
103
103
* args ,
104
- use_gpu_engine = self .use_gpu_engine ,
104
+ use_cudnn = self .use_cudnn ,
105
105
order = self .order ,
106
- gpu_engine_exhaustive_search = self .gpu_engine_exhaustive_search ,
106
+ cudnn_exhaustive_search = self .cudnn_exhaustive_search ,
107
107
ws_nbytes_limit = self .ws_nbytes_limit ,
108
108
** kwargs
109
109
)
@@ -112,9 +112,9 @@ def GroupConv(self, *args, **kwargs):
112
112
return brew .group_conv (
113
113
self ,
114
114
* args ,
115
- use_gpu_engine = self .use_gpu_engine ,
115
+ use_cudnn = self .use_cudnn ,
116
116
order = self .order ,
117
- gpu_engine_exhaustive_search = self .gpu_engine_exhaustive_search ,
117
+ cudnn_exhaustive_search = self .cudnn_exhaustive_search ,
118
118
ws_nbytes_limit = self .ws_nbytes_limit ,
119
119
** kwargs
120
120
)
@@ -123,9 +123,9 @@ def GroupConv_Deprecated(self, *args, **kwargs):
123
123
return brew .group_conv_deprecated (
124
124
self ,
125
125
* args ,
126
- use_gpu_engine = self .use_gpu_engine ,
126
+ use_cudnn = self .use_cudnn ,
127
127
order = self .order ,
128
- gpu_engine_exhaustive_search = self .gpu_engine_exhaustive_search ,
128
+ cudnn_exhaustive_search = self .cudnn_exhaustive_search ,
129
129
ws_nbytes_limit = self .ws_nbytes_limit ,
130
130
** kwargs
131
131
)
@@ -147,16 +147,16 @@ def FC_Sparse(self, *args, **kwargs):
147
147
148
148
def Dropout (self , * args , ** kwargs ):
149
149
return brew .dropout (
150
- self , * args , order = self .order , use_gpu_engine = self .use_gpu_engine , ** kwargs
150
+ self , * args , order = self .order , use_cudnn = self .use_cudnn , ** kwargs
151
151
)
152
152
153
153
def LRN (self , * args , ** kwargs ):
154
154
return brew .lrn (
155
- self , * args , order = self .order , use_gpu_engine = self .use_gpu_engine , ** kwargs
155
+ self , * args , order = self .order , use_cudnn = self .use_cudnn , ** kwargs
156
156
)
157
157
158
158
def Softmax (self , * args , ** kwargs ):
159
- return brew .softmax (self , * args , use_gpu_engine = self .use_gpu_engine , ** kwargs )
159
+ return brew .softmax (self , * args , use_cudnn = self .use_cudnn , ** kwargs )
160
160
161
161
def SpatialBN (self , * args , ** kwargs ):
162
162
return brew .spatial_bn (self , * args , order = self .order , ** kwargs )
@@ -169,7 +169,7 @@ def InstanceNorm(self, *args, **kwargs):
169
169
170
170
def Relu (self , * args , ** kwargs ):
171
171
return brew .relu (
172
- self , * args , order = self .order , use_gpu_engine = self .use_gpu_engine , ** kwargs
172
+ self , * args , order = self .order , use_cudnn = self .use_cudnn , ** kwargs
173
173
)
174
174
175
175
def PRelu (self , * args , ** kwargs ):
@@ -187,7 +187,7 @@ def Sum(self, *args, **kwargs):
187
187
return brew .sum (self , * args , ** kwargs )
188
188
189
189
def Transpose (self , * args , ** kwargs ):
190
- return brew .transpose (self , * args , use_gpu_engine = self .use_gpu_engine , ** kwargs )
190
+ return brew .transpose (self , * args , use_cudnn = self .use_cudnn , ** kwargs )
191
191
192
192
def Iter (self , * args , ** kwargs ):
193
193
return brew .iter (self , * args , ** kwargs )
@@ -197,15 +197,15 @@ def Accuracy(self, *args, **kwargs):
197
197
198
198
def MaxPool (self , * args , ** kwargs ):
199
199
return brew .max_pool (
200
- self , * args , use_gpu_engine = self .use_gpu_engine , order = self .order , ** kwargs
200
+ self , * args , use_cudnn = self .use_cudnn , order = self .order , ** kwargs
201
201
)
202
202
203
203
def MaxPoolWithIndex (self , * args , ** kwargs ):
204
204
return brew .max_pool_with_index (self , * args , order = self .order , ** kwargs )
205
205
206
206
def AveragePool (self , * args , ** kwargs ):
207
207
return brew .average_pool (
208
- self , * args , use_gpu_engine = self .use_gpu_engine , order = self .order , ** kwargs
208
+ self , * args , use_cudnn = self .use_cudnn , order = self .order , ** kwargs
209
209
)
210
210
211
211
@property
@@ -235,11 +235,6 @@ def CPU(self):
235
235
@property
236
236
def GPU (self , gpu_id = 0 ):
237
237
device_option = caffe2_pb2 .DeviceOption ()
238
- if workspace .has_hip_support :
239
- device_option .device_type = caffe2_pb2 .HIP
240
- device_option .hip_gpu_id = gpu_id
241
- else :
242
- device_option .device_type = caffe2_pb2 .CUDA
243
- device_option .cuda_gpu_id = gpu_id
244
-
238
+ device_option .device_type = caffe2_pb2 .CUDA
239
+ device_option .cuda_gpu_id = gpu_id
245
240
return device_option
0 commit comments