diff --git a/caffe2/python/cnn.py b/caffe2/python/cnn.py index f927020e6ae8..05eaa1b0495e 100644 --- a/caffe2/python/cnn.py +++ b/caffe2/python/cnn.py @@ -5,7 +5,7 @@ from __future__ import print_function from __future__ import unicode_literals -from caffe2.python import brew +from caffe2.python import brew, workspace from caffe2.python.model_helper import ModelHelper from caffe2.proto import caffe2_pb2 import logging @@ -17,7 +17,7 @@ class CNNModelHelper(ModelHelper): """ def __init__(self, order="NCHW", name=None, - use_cudnn=True, cudnn_exhaustive_search=False, + use_gpu_engine=True, gpu_engine_exhaustive_search=False, ws_nbytes_limit=None, init_params=True, skip_sparse_optim=False, param_model=None): @@ -31,8 +31,8 @@ def __init__(self, order="NCHW", name=None, cnn_arg_scope = { 'order': order, - 'use_cudnn': use_cudnn, - 'cudnn_exhaustive_search': cudnn_exhaustive_search, + 'use_gpu_engine': use_gpu_engine, + 'gpu_engine_exhaustive_search': gpu_engine_exhaustive_search, } if ws_nbytes_limit: cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit @@ -45,8 +45,8 @@ def __init__(self, order="NCHW", name=None, ) self.order = order - self.use_cudnn = use_cudnn - self.cudnn_exhaustive_search = cudnn_exhaustive_search + self.use_gpu_engine = use_gpu_engine + self.gpu_engine_exhaustive_search = gpu_engine_exhaustive_search self.ws_nbytes_limit = ws_nbytes_limit if self.order != "NHWC" and self.order != "NCHW": raise ValueError( @@ -79,9 +79,9 @@ def ConvNd(self, *args, **kwargs): return brew.conv_nd( self, *args, - use_cudnn=self.use_cudnn, + use_gpu_engine=self.use_gpu_engine, order=self.order, - cudnn_exhaustive_search=self.cudnn_exhaustive_search, + gpu_engine_exhaustive_search=self.gpu_engine_exhaustive_search, ws_nbytes_limit=self.ws_nbytes_limit, **kwargs ) @@ -90,9 +90,9 @@ def Conv(self, *args, **kwargs): return brew.conv( self, *args, - use_cudnn=self.use_cudnn, + use_gpu_engine=self.use_gpu_engine, order=self.order, - 
cudnn_exhaustive_search=self.cudnn_exhaustive_search, + gpu_engine_exhaustive_search=self.gpu_engine_exhaustive_search, ws_nbytes_limit=self.ws_nbytes_limit, **kwargs ) @@ -101,9 +101,9 @@ def ConvTranspose(self, *args, **kwargs): return brew.conv_transpose( self, *args, - use_cudnn=self.use_cudnn, + use_gpu_engine=self.use_gpu_engine, order=self.order, - cudnn_exhaustive_search=self.cudnn_exhaustive_search, + gpu_engine_exhaustive_search=self.gpu_engine_exhaustive_search, ws_nbytes_limit=self.ws_nbytes_limit, **kwargs ) @@ -112,9 +112,9 @@ def GroupConv(self, *args, **kwargs): return brew.group_conv( self, *args, - use_cudnn=self.use_cudnn, + use_gpu_engine=self.use_gpu_engine, order=self.order, - cudnn_exhaustive_search=self.cudnn_exhaustive_search, + gpu_engine_exhaustive_search=self.gpu_engine_exhaustive_search, ws_nbytes_limit=self.ws_nbytes_limit, **kwargs ) @@ -123,9 +123,9 @@ def GroupConv_Deprecated(self, *args, **kwargs): return brew.group_conv_deprecated( self, *args, - use_cudnn=self.use_cudnn, + use_gpu_engine=self.use_gpu_engine, order=self.order, - cudnn_exhaustive_search=self.cudnn_exhaustive_search, + gpu_engine_exhaustive_search=self.gpu_engine_exhaustive_search, ws_nbytes_limit=self.ws_nbytes_limit, **kwargs ) @@ -147,16 +147,16 @@ def FC_Sparse(self, *args, **kwargs): def Dropout(self, *args, **kwargs): return brew.dropout( - self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs + self, *args, order=self.order, use_gpu_engine=self.use_gpu_engine, **kwargs ) def LRN(self, *args, **kwargs): return brew.lrn( - self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs + self, *args, order=self.order, use_gpu_engine=self.use_gpu_engine, **kwargs ) def Softmax(self, *args, **kwargs): - return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs) + return brew.softmax(self, *args, use_gpu_engine=self.use_gpu_engine, **kwargs) def SpatialBN(self, *args, **kwargs): return brew.spatial_bn(self, *args, order=self.order, 
**kwargs) @@ -169,7 +169,7 @@ def InstanceNorm(self, *args, **kwargs): def Relu(self, *args, **kwargs): return brew.relu( - self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs + self, *args, order=self.order, use_gpu_engine=self.use_gpu_engine, **kwargs ) def PRelu(self, *args, **kwargs): @@ -187,7 +187,7 @@ def Sum(self, *args, **kwargs): return brew.sum(self, *args, **kwargs) def Transpose(self, *args, **kwargs): - return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs) + return brew.transpose(self, *args, use_gpu_engine=self.use_gpu_engine, **kwargs) def Iter(self, *args, **kwargs): return brew.iter(self, *args, **kwargs) @@ -197,7 +197,7 @@ def Accuracy(self, *args, **kwargs): def MaxPool(self, *args, **kwargs): return brew.max_pool( - self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs + self, *args, use_gpu_engine=self.use_gpu_engine, order=self.order, **kwargs ) def MaxPoolWithIndex(self, *args, **kwargs): @@ -205,7 +205,7 @@ def MaxPoolWithIndex(self, *args, **kwargs): def AveragePool(self, *args, **kwargs): return brew.average_pool( - self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs + self, *args, use_gpu_engine=self.use_gpu_engine, order=self.order, **kwargs ) @property @@ -235,6 +235,11 @@ def CPU(self): @property def GPU(self, gpu_id=0): device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = gpu_id + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = gpu_id + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = gpu_id + return device_option diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 7594491e0929..b8ab91187fa4 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -117,7 +117,7 @@ def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=Tr if not opt1.device_type or not opt2.device_type: # At least 
one option is for CPU, check if both are for CPU. return not opt1.device_type and not opt2.device_type - return opt1.cuda_gpu_id == opt2.cuda_gpu_id + return (opt1.cuda_gpu_id == opt2.cuda_gpu_id) and (opt1.hip_gpu_id == opt2.hip_gpu_id) def InferBlobDevices(net): @@ -2021,16 +2021,16 @@ def DeduplicateGradientSlices(self, g, aggregator='sum'): raise ValueError('{} is not supported'.format(aggregator)) return GradientSlice(indices=unique, values=new_g) - def RunAllOnGPU(self, gpu_id=0, use_cudnn=False): + def RunAllOnGPU(self, gpu_id=0, use_gpu_engine=False): """A convenient function to run everything on the GPU.""" device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA if workspace.has_gpu_support else caffe2_pb2.HIP + device_option.device_type = caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA device_option.cuda_gpu_id = gpu_id device_option.hip_gpu_id = gpu_id self._net.device_option.CopyFrom(device_option) - if use_cudnn: + if use_gpu_engine: for op in self._net.op: - op.engine = "CUDNN" + op.engine = "MIOPEN" if workspace.has_hip_support else 'CUDNN' def RunAllOnMKL(self): """A convenient function to run everything using MKLDNN.""" @@ -2189,27 +2189,38 @@ def extend_ops(self, new_ops): def copy_func_between_devices(src, dst): CPU = caffe2_pb2.CPU - CUDA = caffe2_pb2.CUDA + if workspace.has_hip_support: + GPU = caffe2_pb2.HIP + else: + GPU = caffe2_pb2.CUDA if src.device_type == CPU and dst.device_type == CPU: return None - if src.device_type == CUDA and dst.device_type == CUDA: - if src.cuda_gpu_id == dst.cuda_gpu_id: - return None + if src.device_type == GPU and dst.device_type == GPU: + def fun(net, *args, **kw): + with DeviceScope(dst): + return net.Copy(*args, **kw) + + if workspace.has_hip_support: + if src.hip_gpu_id == dst.hip_gpu_id: + return None + else: + return fun else: - def fun(net, *args, **kw): - with DeviceScope(dst): - return net.Copy(*args, **kw) - return fun + if src.cuda_gpu_id == 
dst.cuda_gpu_id: + return None + else: + return fun + - if src.device_type == CUDA and dst.device_type == CPU: + if src.device_type == GPU and dst.device_type == CPU: def fun(net, *args, **kw): with DeviceScope(src): return net.CopyGPUToCPU(*args, **kw) return fun - if src.device_type == CPU and dst.device_type == CUDA: + if src.device_type == CPU and dst.device_type == GPU: def fun(net, *args, **kw): with DeviceScope(dst): return net.CopyCPUToGPU(*args, **kw) @@ -2224,7 +2235,12 @@ def device_equal(src, dst): comparison between empty device_options and {device_type:0, cuda_gpu_id:0} returns not equal in some cases. ''' - return src.device_type == dst.device_type and src.cuda_gpu_id == dst.cuda_gpu_id + if workspace.has_hip_support: + gpu_id_eq = src.hip_gpu_id == dst.hip_gpu_id + else: + gpu_id_eq = src.cuda_gpu_id == dst.cuda_gpu_id + + return src.device_type == dst.device_type and gpu_id_eq def update_placeholder_op_output(op, blob_to_device): @@ -2335,10 +2351,13 @@ def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None, def _gen_new_name(blob, device_option): CPU = caffe2_pb2.CPU CUDA = caffe2_pb2.CUDA + HIP = caffe2_pb2.HIP if device_option.device_type == CPU: suffix = '_cpu' elif device_option.device_type == CUDA: suffix = '_cuda_' + str(device_option.cuda_gpu_id) + elif device_option.device_type == HIP: + suffix = '_hip_' + str(device_option.hip_gpu_id) else: raise RuntimeError( "Unknown device type: {}". 
diff --git a/caffe2/python/core_test.py b/caffe2/python/core_test.py index b7099d20eae8..f47c79ac847a 100644 --- a/caffe2/python/core_test.py +++ b/caffe2/python/core_test.py @@ -82,18 +82,30 @@ def testDeviceScope(self): self.assertFalse(op.HasField('device_option')) # explicitly setting a device device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 - op = core.CreateOperator("Relu", "x", "y", device_option=device_option) - self.assertTrue(op.HasField('device_option')) - self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + op = core.CreateOperator("Relu", "x", "y", device_option=device_option) + self.assertTrue(op.HasField('device_option')) + self.assertEqual(op.device_option.device_type, caffe2_pb2.HIP) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 + op = core.CreateOperator("Relu", "x", "y", device_option=device_option) + self.assertTrue(op.HasField('device_option')) + self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) + self.assertEqual(op.device_option.cuda_gpu_id, 1) with core.DeviceScope(device_option): # from device scope op = core.CreateOperator("Relu", "x", "y") self.assertTrue(op.HasField('device_option')) - self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, caffe2_pb2.HIP) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) + self.assertEqual(op.device_option.cuda_gpu_id, 1) # from an overridden device option override_device = caffe2_pb2.DeviceOption() override_device.device_type = caffe2_pb2.CPU @@ 
-108,14 +120,22 @@ def testDeviceScope(self): def testNameAndDeviceScopeTogether(self): device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 with core.DeviceScope(device_option): with core.NameScope("foo"): op = core.CreateOperator("Relu", "x", "y") self.assertTrue(op.HasField('device_option')) - self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, caffe2_pb2.HIP) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) + self.assertEqual(op.device_option.cuda_gpu_id, 1) self.assertEqual(len(op.input), 1) self.assertEqual(op.input[0], "foo/x") self.assertEqual(len(op.output), 1) @@ -220,8 +240,12 @@ def testSetInputRecordWithoutBlobs(self): class TestCreateOperator(test_util.TestCase): def testCreate(self): device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 op = core.CreateOperator( "Ludicrous", "x", "y", name="ludicrous", control_input="z", device_option=device_option, @@ -236,8 +260,12 @@ def testCreate(self): self.assertEqual(len(op.control_input), 1) self.assertEqual(op.control_input[0], "z") self.assertTrue(op.HasField('device_option')) - self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 
caffe2_pb2.HIP) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.device_option.device_type, caffe2_pb2.CUDA) + self.assertEqual(op.device_option.cuda_gpu_id, 1) self.assertTrue(len(op.arg), 3) # can't guarantee ordering of kwargs, so generate a set of args @@ -540,12 +568,15 @@ def test_check_equal_default_value(self): opt2 = caffe2_pb2.DeviceOption() opt1.device_type = 0 self.assertTrue(core.device_option_equal(opt1, opt2)) - opt1.cuda_gpu_id = 5 + if workspace.has_hip_support: + opt1.hip_gpu_id = 5 + else: + opt1.cuda_gpu_id = 5 # opt1 still is on CPU, so the options should be equal self.assertTrue(core.device_option_equal(opt1, opt2)) opt2.device_type = 0 self.assertTrue(core.device_option_equal(opt1, opt2)) - opt1.device_type = 1 + opt1.device_type = 6 if workspace.has_hip_support else 1 self.assertFalse(core.device_option_equal(opt1, opt2)) @@ -609,14 +640,18 @@ def test_inject_copy(self): self.assertEqual(op.input[2], "fc_b") -@unittest.skipIf(not workspace.has_gpu_support, 'No GPU support') +@unittest.skipIf(not workspace.has_gpu_support and not workspace.has_hip_support, 'No GPU support') class TestInferDevice(test_util.TestCase): def setUp(self): device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 - self.cuda_option = device_option + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 + self.gpu_option = device_option self.cpu_option = caffe2_pb2.DeviceOption() def _test_op( @@ -628,7 +663,7 @@ def _test_op( inputs=None, outputs=None ): - op_option = self.cuda_option if not op_option else op_option + op_option = self.gpu_option if not op_option else op_option inputs = ["blob_1"] if not inputs else inputs outputs = ["blob_2"] if not outputs else outputs with core.DeviceScope(op_option): @@ -656,9 +691,9 @@ def _test_op( 
def test_infer_device(self): self._test_op( "FC", - self.cuda_option, - self.cuda_option, - op_option=self.cuda_option, + self.gpu_option, + self.gpu_option, + op_option=self.gpu_option, inputs=["data", "fc_w", "fc_b"], outputs=["fc_1"] ) @@ -666,17 +701,17 @@ def test_infer_device(self): def test_infer_device_split_by_lengths(self): self._test_op( "SplitByLengths", - [self.cuda_option, self.cpu_option], - self.cuda_option, - op_option=self.cuda_option, + [self.gpu_option, self.cpu_option], + self.gpu_option, + op_option=self.gpu_option, inputs=["data", "fc_w"], outputs=["fc_1"] ) def test_infer_device_cross_device(self): - self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option) - self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option) - self._test_op("CopyFromCPUInput", self.cpu_option, self.cuda_option) + self._test_op("CopyGPUToCPU", self.gpu_option, self.cpu_option) + self._test_op("CopyCPUToGPU", self.cpu_option, self.gpu_option) + self._test_op("CopyFromCPUInput", self.cpu_option, self.gpu_option) self._test_op( "CopyFromCPUInput", self.cpu_option, @@ -686,7 +721,7 @@ def test_infer_device_cross_device(self): def test_device_inference_function(self): # ConcatOp. - op_option = self.cuda_option + op_option = self.gpu_option with core.DeviceScope(op_option): op = core.CreateOperator( 'Concat', @@ -698,7 +733,7 @@ def test_device_inference_function(self): self.assertEqual(output_dev[1], self.cpu_option) #SplitOp. 
- op_option = self.cuda_option + op_option = self.gpu_option with core.DeviceScope(op_option): op = core.CreateOperator( 'Split', @@ -713,8 +748,12 @@ def test_inject_copy(self): net = core.Net("test") init_net = core.Net("init") device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 weight = init_net.XavierFill([], 'fc_w', shape=[10, 100]) bias = init_net.ConstantFill([], 'fc_b', shape=[10, ]) @@ -727,11 +766,18 @@ def test_inject_copy(self): ) op = new_net._net.op[-1] self.assertEqual(op.type, "FC") - self.assertEqual(op.input[0], "data_cuda_1") - self.assertEqual(op.input[1], "fc_w_cuda_1") - self.assertEqual(op.input[2], "fc_b_cuda_1") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.input[0], "data_hip_1") + self.assertEqual(op.input[1], "fc_w_hip_1") + self.assertEqual(op.input[2], "fc_b_hip_1") + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.input[0], "data_cuda_1") + self.assertEqual(op.input[1], "fc_w_cuda_1") + self.assertEqual(op.input[2], "fc_b_cuda_1") + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) self.assertEqual(new_net._net.op[-2].type, "CopyCPUToGPU") self.assertEqual(new_net._net.op[0].type, "CopyCPUToGPU") self.assertNotEqual(blob_to_device["fc_w"], device_option) @@ -740,8 +786,12 @@ def test_cross_nets(self): net = core.Net("test") init_net = core.Net("init") device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = 
caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 weight = init_net.XavierFill([], 'fc_w', shape=[10, 100]) bias = init_net.ConstantFill([], 'fc_b', shape=[10, ]) const = init_net.ConstantFill([], 'const', shape=[], value=1.) @@ -756,28 +806,53 @@ def test_cross_nets(self): ) op = nets[1]._net.op[0] self.assertEqual(op.type, "CopyCPUToGPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) - self.assertEqual(op.output[0], "fc_w_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + self.assertEqual(op.output[0], "fc_w_hip_1") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) + self.assertEqual(op.output[0], "fc_w_cuda_1") op = nets[1]._net.op[1] self.assertEqual(op.type, "CopyCPUToGPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) - self.assertEqual(op.output[0], "fc_b_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + self.assertEqual(op.output[0], "fc_b_hip_1") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) + self.assertEqual(op.output[0], "fc_b_cuda_1") op = nets[1]._net.op[2] self.assertEqual(op.type, "FC") self.assertEqual(op.input[0], "data") - self.assertEqual(op.input[1], "fc_w_cuda_1") - self.assertEqual(op.input[2], "fc_b_cuda_1") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.input[1], "fc_w_hip_1") + self.assertEqual(op.input[2], "fc_b_hip_1") 
+ self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.input[1], "fc_w_cuda_1") + self.assertEqual(op.input[2], "fc_b_cuda_1") + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) op = nets[1]._net.op[3] self.assertEqual(op.type, "Add") self.assertEqual(op.input[0], "fc1") - self.assertEqual(op.input[1], "const_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.input[1], "const_hip_1") + else: + self.assertEqual(op.input[1], "const_cuda_1") # check that moved blob is in input to the new net - for c in ["data", "fc_w", "fc_b", "const_cuda_1"]: - self.assertTrue(c in nets[1]._net.external_input) + if workspace.has_hip_support: + for c in ["data", "fc_w", "fc_b", "const_hip_1"]: + self.assertTrue(c in nets[1]._net.external_input) + else: + for c in ["data", "fc_w", "fc_b", "const_cuda_1"]: + self.assertTrue(c in nets[1]._net.external_input) """ For reference, net.Proto() should be like: name: "" @@ -877,8 +952,12 @@ def test_cross_nets_no_change(self): def test_inject_copy_multi_use(self): net = core.Net("test") device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 with core.DeviceScope(device_option): net.Relu("data", "relu1") @@ -886,23 +965,38 @@ def test_inject_copy_multi_use(self): with core.DeviceScope(device_option): net.Relu("data", "relu3") net.Relu("data", "relu4") - device_option.cuda_gpu_id = 0 + if workspace.has_hip_support: + device_option.hip_gpu_id = 0 + else: + device_option.cuda_gpu_id = 0 with core.DeviceScope(device_option): net.Relu("data", "relu5") - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.hip_gpu_id = 1 + else: + 
device_option.cuda_gpu_id = 1 with core.DeviceScope(device_option): net.Relu("data", "relu6") new_net, _ = core.InjectCrossDeviceCopies(net) op = new_net._net.op[0] self.assertEqual(op.type, "CopyCPUToGPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) - self.assertEqual(op.output[0], "data_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + self.assertEqual(op.output[0], "data_hip_1") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) + self.assertEqual(op.output[0], "data_cuda_1") op = new_net._net.op[1] self.assertEqual(op.type, "Relu") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) self.assertEqual(op.output[0], "relu1") op = new_net._net.op[2] self.assertEqual(op.type, "Relu") @@ -910,9 +1004,14 @@ def test_inject_copy_multi_use(self): self.assertEqual(op.output[0], "relu2") op = new_net._net.op[3] self.assertEqual(op.type, "Relu") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) - self.assertEqual(op.input[0], "data_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + self.assertEqual(op.input[0], "data_hip_1") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) + self.assertEqual(op.input[0], "data_cuda_1") self.assertEqual(op.output[0], "relu3") op = new_net._net.op[4] self.assertEqual(op.type, "Relu") @@ -920,20 +1019,35 @@ def 
test_inject_copy_multi_use(self): self.assertEqual(op.output[0], "relu4") op = new_net._net.op[5] self.assertEqual(op.type, "CopyCPUToGPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 0) - self.assertEqual(op.output[0], "data_cuda_0") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 0) + self.assertEqual(op.output[0], "data_hip_0") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 0) + self.assertEqual(op.output[0], "data_cuda_0") op = new_net._net.op[6] self.assertEqual(op.type, "Relu") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 0) - self.assertEqual(op.input[0], "data_cuda_0") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 0) + self.assertEqual(op.input[0], "data_hip_0") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 0) + self.assertEqual(op.input[0], "data_cuda_0") self.assertEqual(op.output[0], "relu5") op = new_net._net.op[7] self.assertEqual(op.type, "Relu") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 1) - self.assertEqual(op.input[0], "data_cuda_1") + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 1) + self.assertEqual(op.input[0], "data_hip_1") + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 1) + self.assertEqual(op.input[0], "data_cuda_1") self.assertEqual(op.output[0], "relu6") """ For reference, net.Proto() should be like: @@ -1025,8 +1139,12 @@ def test_inject_copy_placeholder_ops(self): cpu_device.append(caffe2_pb2.DeviceOption()) cpu_device[i].node_name = 
'node:' + str(i) gpu_device.append(caffe2_pb2.DeviceOption()) - gpu_device[i].device_type = caffe2_pb2.CUDA - gpu_device[i].cuda_gpu_id = 0 + if workspace.has_hip_support: + gpu_device[i].device_type = caffe2_pb2.HIP + gpu_device[i].hip_gpu_id = 0 + else: + gpu_device[i].device_type = caffe2_pb2.CUDA + gpu_device[i].cuda_gpu_id = 0 gpu_device[i].node_name = 'node:' + str(i) send_node = 'node:0' recv_node = 'node:1' @@ -1065,13 +1183,21 @@ def test_inject_copy_placeholder_ops(self): # Verify (init_net) op = init_net._net.op[2] self.assertEqual(op.type, "CopyGPUToCPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 0) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 0) + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 0) self.assertEqual(op.output[0], "fc_w_cpu") op = init_net._net.op[3] self.assertEqual(op.type, "CopyGPUToCPU") - self.assertEqual(op.device_option.device_type, 1) - self.assertEqual(op.device_option.cuda_gpu_id, 0) + if workspace.has_hip_support: + self.assertEqual(op.device_option.device_type, 6) + self.assertEqual(op.device_option.hip_gpu_id, 0) + else: + self.assertEqual(op.device_option.device_type, 1) + self.assertEqual(op.device_option.cuda_gpu_id, 0) self.assertEqual(op.output[0], "fc_b_cpu") op = init_net._net.op[4] self.assertEqual(op.type, placeholder_send) @@ -1093,8 +1219,12 @@ def test_inject_copy_placeholder_ops(self): def test_blob_inplace(self): net = core.Net("test") device_option = caffe2_pb2.DeviceOption() - device_option.device_type = caffe2_pb2.CUDA - device_option.cuda_gpu_id = 1 + if workspace.has_hip_support: + device_option.device_type = caffe2_pb2.HIP + device_option.hip_gpu_id = 1 + else: + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = 1 net.Adagrad(['param', 'moment', 'grad', 'lr'], ['param', 'moment']) 
with core.DeviceScope(device_option): @@ -1103,10 +1233,15 @@ def test_blob_inplace(self): op = net._net.op[1] self.assertEqual(op.type, 'CopyCPUToGPU') self.assertEqual(op.input[0], 'param') - self.assertEqual(op.output[0], 'param_cuda_1') + if workspace.has_hip_support: + self.assertEqual(op.output[0], 'param_hip_1') + else: + self.assertEqual(op.output[0], 'param_cuda_1') op = net._net.op[2] - self.assertEqual(op.input[0], 'param_cuda_1') - + if workspace.has_hip_support: + self.assertEqual(op.input[0], 'param_hip_1') + else: + self.assertEqual(op.input[0], 'param_cuda_1') net.Relu('nonsense_input', 'moment') # should not raise inplace error core.InjectCrossDeviceCopies(net) diff --git a/caffe2/python/functional_test.py b/caffe2/python/functional_test.py index e7803e829bb4..db252d8f704d 100644 --- a/caffe2/python/functional_test.py +++ b/caffe2/python/functional_test.py @@ -46,7 +46,7 @@ def _tensor_splits(draw, add_axis=False): class TestFunctional(hu.HypothesisTestCase): - @given(X=hu.tensor(), engine=st.sampled_from(["", "CUDNN"]), **hu.gcs) + @given(X=hu.tensor(), engine=st.sampled_from(["", "MIOPEN" if workspace.has_hip_support else "CUDNN"]), **hu.gcs) def test_relu(self, X, engine, gc, dc): X += 0.02 * np.sign(X) X[X == 0.0] += 0.02 diff --git a/caffe2/python/gradient_check_test.py b/caffe2/python/gradient_check_test.py index f1c190aa6efa..52da4e7dd493 100644 --- a/caffe2/python/gradient_check_test.py +++ b/caffe2/python/gradient_check_test.py @@ -23,9 +23,9 @@ import unittest -if workspace.has_gpu_support and workspace.NumCudaDevices() > 0: +if (workspace.has_gpu_support or workspace.has_hip_support) and workspace.NumGpuDevices() > 0: gpu_device_option = caffe2_pb2.DeviceOption() - gpu_device_option.device_type = caffe2_pb2.CUDA + gpu_device_option.device_type = caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA cpu_device_option = caffe2_pb2.DeviceOption() gpu_device_checker = device_checker.DeviceChecker( 0.01, [gpu_device_option] diff 
--git a/caffe2/python/helpers/conv.py b/caffe2/python/helpers/conv.py index bb88b2e3757f..fb1aabe153e6 100644 --- a/caffe2/python/helpers/conv.py +++ b/caffe2/python/helpers/conv.py @@ -5,7 +5,7 @@ from __future__ import print_function from __future__ import unicode_literals -from caffe2.python import core +from caffe2.python import core, workspace from caffe2.python.modeling import initializers from caffe2.python.modeling.parameter_info import ParameterTags @@ -23,9 +23,9 @@ def _ConvBase( BiasInitializer=None, group=1, transform_inputs=None, - use_cudnn=False, + use_gpu_engine=False, order="NCHW", - cudnn_exhaustive_search=False, + gpu_engine_exhaustive_search=False, ws_nbytes_limit=None, float16_compute=False, **kwargs @@ -45,18 +45,22 @@ def _ConvBase( requested_engine = kwargs.get('engine') if requested_engine is not None: - if use_cudnn and requested_engine != 'CUDNN': + if workspace.has_gpu_support and use_gpu_engine and requested_engine != 'CUDNN': raise ValueError( - 'When use_cudnn=True, the only engine you can specify is ' + 'When use_gpu_engine=True and has CUDA GPU, the only engine you can specify is ' '"CUDNN"') - elif not use_cudnn and requested_engine == 'CUDNN': + elif workspace.has_hip_support and use_gpu_engine and requested_engine != 'MIOPEN': raise ValueError( - 'When use_cudnn=False, the only engine you can specify is ' - '""') + 'When use_gpu_engine=True and has HIP GPU, the only engine you can specify is ' + '"MIOPEN"') + elif not use_gpu_engine and (requested_engine in {'CUDNN','MIOPEN'}): + raise ValueError( + 'When use_gpu_engine=False, the only engine you can specify is ' + '""') - if use_cudnn: - kwargs['engine'] = 'CUDNN' - kwargs['exhaustive_search'] = cudnn_exhaustive_search + if use_gpu_engine: + kwargs['engine'] = 'MIOPEN' if workspace.has_hip_support else 'CUDNN' + kwargs['exhaustive_search'] = gpu_engine_exhaustive_search if ws_nbytes_limit: kwargs['ws_nbytes_limit'] = ws_nbytes_limit @@ -195,9 +199,9 @@ def conv_transpose( 
kernel, weight_init=None, bias_init=None, - use_cudnn=False, + use_gpu_engine=False, order="NCHW", - cudnn_exhaustive_search=False, + gpu_engine_exhaustive_search=False, ws_nbytes_limit=None, **kwargs ): @@ -230,9 +234,9 @@ def conv_transpose( blob_out + '_b', model.param_init_net) model.AddParameter(weight, ParameterTags.WEIGHT) model.AddParameter(bias, ParameterTags.BIAS) - if use_cudnn: - kwargs['engine'] = 'CUDNN' - kwargs['exhaustive_search'] = cudnn_exhaustive_search + if use_gpu_engine: + kwargs['engine'] = 'MIOPEN' if workspace.has_hip_support else 'CUDNN' + kwargs['exhaustive_search'] = gpu_engine_exhaustive_search if ws_nbytes_limit: kwargs['ws_nbytes_limit'] = ws_nbytes_limit return model.net.ConvTranspose( @@ -276,9 +280,9 @@ def group_conv_deprecated( weight_init=None, bias_init=None, group=1, - use_cudnn=False, + use_gpu_engine=False, order="NCHW", - cudnn_exhaustive_search=False, + gpu_engine_exhaustive_search=False, ws_nbytes_limit=None, **kwargs ): @@ -290,9 +294,9 @@ def group_conv_deprecated( weight_init = weight_init if weight_init else ('XavierFill', {}) bias_init = bias_init if bias_init else ('ConstantFill', {}) use_bias = False if ("no_bias" in kwargs and kwargs["no_bias"]) else True - if use_cudnn: - kwargs['engine'] = 'CUDNN' - kwargs['exhaustive_search'] = cudnn_exhaustive_search + if use_gpu_engine: + kwargs['engine'] = 'MIOPEN' if workspace.has_hip_support else 'CUDNN' + kwargs['exhaustive_search'] = gpu_engine_exhaustive_search if ws_nbytes_limit: kwargs['ws_nbytes_limit'] = ws_nbytes_limit if dim_in % group: diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index cb9932bc4542..d10bfe209f7b 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -366,7 +366,7 @@ def test_recurrent(self, hidden_size, num_layers, bidirectional, rnn_mode, input_mode=input_mode, num_layers=num_layers, seed=seed, - engine="CUDNN") + engine="MIOPEN" if workspace.has_hip_support else "CUDNN") X = 
np.random.randn(T, N, D).astype(np.float32) self.ws.create_blob("INPUT").feed(X, device_option=hu.gpu_do) W = self.ws.blobs["WEIGHT"].fetch() diff --git a/caffe2/python/memonger_test.py b/caffe2/python/memonger_test.py index 6536280d8a60..cb5712f425e2 100644 --- a/caffe2/python/memonger_test.py +++ b/caffe2/python/memonger_test.py @@ -223,13 +223,13 @@ def test_gradient_optim(self, input_dim, output_dim, batch_size): np.testing.assert_almost_equal(loss, optimized_loss) np.testing.assert_almost_equal(grad, optimized_grad) - @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.") + @unittest.skipIf(not workspace.has_gpu_support and not workspace.has_hip_support, "No gpu support.") def test_memonger_mix_cpu_gpu(self): ''' Check that memonger does not make blobs cross CPU/GPU boundary ''' m = model_helper.ModelHelper() - with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)): + with core.DeviceScope(core.DeviceOption(caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA, 0)): fc1 = brew.fc(m, "data", "fc1", dim_in=2, dim_out=2) fc2 = brew.fc(m, fc1, "fc2", dim_in=2, dim_out=2) fc3 = brew.fc(m, fc2, "fc3", dim_in=2, dim_out=2) @@ -259,7 +259,10 @@ def test_memonger_mix_cpu_gpu(self): # Create set of blobs on CPU side and GPU side and check they don't # overlap - device_blobs = {caffe2_pb2.CPU: set(), caffe2_pb2.CUDA: set()} + if workspace.has_hip_support: + device_blobs = {caffe2_pb2.CPU: set(), caffe2_pb2.HIP: set()} + else: + device_blobs = {caffe2_pb2.CPU: set(), caffe2_pb2.CUDA: set()} for op in optim_proto.op: if op.type not in ['CopyCPUToGPU', "CopyGPUToCPU"]: dev = op.device_option.device_type @@ -267,7 +270,7 @@ def test_memonger_mix_cpu_gpu(self): device_blobs[dev].add(b) device_crossers = device_blobs[caffe2_pb2.CPU].intersection( - device_blobs[caffe2_pb2.CUDA] + device_blobs[caffe2_pb2.HIP] if workspace.has_hip_support else device_blobs[caffe2_pb2.CUDA] ) self.assertEquals(device_crossers, set()) diff --git 
a/caffe2/python/workspace.py b/caffe2/python/workspace.py index c033e0684bb5..79ac32b6269d 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -47,7 +47,7 @@ NumCudaDevices = C.num_cuda_devices GetCUDAVersion = C.get_cuda_version GetCuDNNVersion = C.get_cudnn_version - + NumGpuDevices = NumCudaDevices def GetCudaPeerAccessPattern(): return np.asarray(C.get_cuda_peer_access_pattern()) @@ -59,6 +59,18 @@ def GetCudaPeerAccessPattern(): GetCudaPeerAccessPattern = lambda: np.array([]) # noqa GetDeviceProperties = lambda x: None # noqa +if has_hip_support: + NumHipDevices = C.num_hip_devices + NumGpuDevices = NumHipDevices + def GetHipPeerAccessPattern(): + return np.asarray(C.get_hip_peer_access_pattern()) + + GetDeviceProperties = C.get_device_properties +else: + NumHipDevices = lambda: 0 # noqa + GetHipPeerAccessPattern = lambda: np.array([]) # noqa + GetDeviceProperties = lambda x: None # noqa + IsNUMAEnabled = C.is_numa_enabled GetNumNUMANodes = C.get_num_numa_nodes GetBlobNUMANode = C.get_blob_numa_node @@ -322,10 +334,10 @@ def FeedBlob(name, arr, device_option=None): if device_option is None: device_option = scope.CurrentDeviceScope() - if device_option and device_option.device_type == caffe2_pb2.CUDA: + if device_option and (device_option.device_type == caffe2_pb2.CUDA or device_option.device_type == caffe2_pb2.HIP): if arr.dtype == np.dtype('float64'): logger.warning( - "CUDA operators do not support 64-bit doubles, " + + "CUDA/HIP operators do not support 64-bit doubles, " + "please use arr.astype(np.float32) or np.int32 for ints." 
+ " Blob: {}".format(name) + " type: {}".format(str(arr.dtype)) diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py index 5da37c7f22ef..661a92d7c1f6 100644 --- a/caffe2/python/workspace_test.py +++ b/caffe2/python/workspace_test.py @@ -317,7 +317,7 @@ def testCreateWorkspace(self): self.assertTrue("test" in workspaces) -@unittest.skipIf(not workspace.has_gpu_support, "No gpu support.") +@unittest.skipIf(not workspace.has_gpu_support and not workspace.has_hip_support, "No gpu support.") class TestWorkspaceGPU(test_util.TestCase): def setUp(self): @@ -339,6 +339,7 @@ def testFetchBlobGPU(self): self.assertEqual(fetched_again.shape, (1, 2, 3, 4)) np.testing.assert_array_equal(fetched_again, 2.0) + @unittest.skipIf(not workspace.has_gpu_support, "No gpu support.") def testGetCudaPeerAccessPattern(self): pattern = workspace.GetCudaPeerAccessPattern() self.assertEqual(type(pattern), np.ndarray) @@ -346,6 +347,14 @@ def testGetCudaPeerAccessPattern(self): self.assertEqual(pattern.shape[0], pattern.shape[1]) self.assertEqual(pattern.shape[0], workspace.NumCudaDevices()) + @unittest.skipIf(not workspace.has_hip_support, "No hip support.") + def testGetHipPeerAccessPattern(self): + pattern = workspace.GetHipPeerAccessPattern() + self.assertEqual(type(pattern), np.ndarray) + self.assertEqual(pattern.ndim, 2) + self.assertEqual(pattern.shape[0], pattern.shape[1]) + self.assertEqual(pattern.shape[0], workspace.NumHipDevices()) + @unittest.skipIf(not workspace.C.has_mkldnn, "No MKLDNN support.") class TestWorkspaceMKLDNN(test_util.TestCase): @@ -580,8 +589,8 @@ def test_simple_transform(self, input_dim, output_dim, batch_size): conv = brew.conv(m, fc2, "conv", dim_in=output_dim, dim_out=output_dim, - use_cudnn=True, - engine="CUDNN", + use_gpu_engine=True, + engine="MIOPEN" if workspace.has_hip_support else "CUDNN", kernel=3) conv.Relu([], conv)\ @@ -622,8 +631,8 @@ def test_apply_transform_if_faster(self, value): dim_in=5, dim_out=5, kernel=3, - 
use_cudnn=True, - engine="CUDNN") + use_gpu_engine=True, + engine="MIOPEN" if workspace.has_hip_support else "CUDNN") conv.Relu([], conv)\ .Softmax([], "pred") \