[Caffe2] Support for HIP in python operator tests #91

Merged: 9 commits, Aug 2, 2018
2 changes: 1 addition & 1 deletion caffe2/python/core.py
@@ -2190,7 +2190,7 @@ def extend_ops(self, new_ops):
def copy_func_between_devices(src, dst):
CPU = caffe2_pb2.CPU
if workspace.has_hip_support:
GPU = caffe2_pb2.HIP`
GPU = caffe2_pb2.HIP
else:
GPU = caffe2_pb2.CUDA

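Note: the one-line core.py change above removes a stray backtick that broke device selection on HIP builds. The HIP-or-CUDA choice it restores is the pattern the rest of this PR repeats; a minimal sketch of the idea, with a hypothetical helper name that is not part of the PR:

from caffe2.proto import caffe2_pb2
from caffe2.python import workspace

def gpu_device_type():
    # Hypothetical helper: the GPU DeviceType for this build.
    # HIP on ROCm builds, CUDA everywhere else.
    return caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA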
6 changes: 3 additions & 3 deletions caffe2/python/operator_test/activation_ops_test.py
@@ -17,7 +17,7 @@

class TestActivations(hu.HypothesisTestCase):
@given(X=hu.tensor(), in_place=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **mu.gcs)
engine=st.sampled_from(["", "MIOPEN" if workspace.has_hip_support else "CUDNN"]), **mu.gcs)
def test_relu(self, X, in_place, engine, gc, dc):
if gc == mu.mkl_do:
in_place = False
@@ -43,7 +43,7 @@ def relu_ref(X):
@unittest.skipIf(not workspace.has_gpu_support,
"Relu for float16 can only run on GPU now.")
@given(X=hu.tensor(dtype=np.float16), in_place=st.booleans(),
engine=st.sampled_from(["", "CUDNN"]), **hu.gcs_gpu_only)
engine=st.sampled_from([""] if workspace.has_hip_support else ["", "CUDNN"]), **hu.gcs_gpu_only)
def test_relu_fp16(self, X, in_place, engine, gc, dc):
op = core.CreateOperator(
"Relu",
@@ -102,7 +102,7 @@ def relu_n_ref(X):

@given(X=hu.tensor(),
alpha=st.floats(min_value=0.1, max_value=2.0),
in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]),
in_place=st.booleans(), engine=st.sampled_from([""] if workspace.has_hip_support else ["", "CUDNN"]),
**hu.gcs)
def test_elu(self, X, alpha, in_place, engine, gc, dc):
op = core.CreateOperator(
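Two engine-sampling idioms appear above and recur in the files below: swap MIOPEN in for CUDNN when the op has a MIOPEN path (test_relu), and drop CUDNN entirely on HIP builds when it does not (test_relu_fp16, test_elu). A sketch that factors out the choice, assuming a hypothetical helper:

import hypothesis.strategies as st
from caffe2.python import workspace

def engine_strategy(has_miopen_path):
    # Hypothetical helper: engines to sample for a GPU-accelerated op test.
    if workspace.has_hip_support:
        # Sample MIOPEN only where the op has a MIOPEN implementation.
        return st.sampled_from(["", "MIOPEN"] if has_miopen_path else [""])
    return st.sampled_from(["", "CUDNN"])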
8 changes: 4 additions & 4 deletions caffe2/python/operator_test/boolean_mask_test.py
@@ -7,7 +7,7 @@
import hypothesis.strategies as st

from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu


@@ -47,10 +47,10 @@ def ref(x, mask):

@staticmethod
def _dtype_conversion(x, dtype, gc, dc):
"""SequenceMask only supports fp16 with CUDA."""
"""SequenceMask only supports fp16 with CUDA/HIP."""
if dtype == np.float16:
assume(gc.device_type == caffe2_pb2.CUDA)
dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
            assume(gc.device_type == (caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA))
dc = [d for d in dc if d.device_type == (caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA)]
x = x.astype(dtype)
return x, dc

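One subtlety in the assume call above: == binds tighter than a conditional expression in Python, so the parentheses around the ternary are required (added in the fix above). A self-contained illustration with made-up device values, not the real proto constants:

has_hip = False
HIP, CUDA = 6, 1  # illustrative values only, not the real caffe2_pb2 constants
device = CUDA

wrong = device == HIP if has_hip else CUDA    # parses as (device == HIP) if has_hip else CUDA
right = device == (HIP if has_hip else CUDA)  # the intended comparison

assert wrong == CUDA  # a truthy int slipped through, not a boolean check
assert right is True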
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/ceil_op_test.py
@@ -3,7 +3,7 @@
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python import core, workspace
from hypothesis import given
import hypothesis.strategies as st
import caffe2.python.hypothesis_test_util as hu
@@ -15,7 +15,7 @@
class TestCeil(hu.HypothesisTestCase):

@given(X=hu.tensor(),
engine=st.sampled_from(["", "CUDNN"]),
engine=st.sampled_from([""] if workspace.has_hip_support else ["", "CUDNN"]),
**hu.gcs)
def test_ceil(self, X, gc, dc, engine):
op = core.CreateOperator("Ceil", ["X"], ["Y"], engine=engine)
61 changes: 41 additions & 20 deletions caffe2/python/operator_test/conv_test.py
@@ -37,6 +37,15 @@ def _cudnn_supports(
return False
return True

def _miopen_supports(
dilation=False,
nhwc=False,
backward=False,
):
"""Return True if MIOPEN supports this configuration."""
if nhwc or dilation:
return False
return True

def _cudnn_convolution_algo_count(direction):
try:
@@ -192,7 +201,7 @@ def test_convolution_separate_stride_pad_layout(self, op_type,
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "MKLDNN"]),
engine=st.sampled_from(["", "MIOPEN" if workspace.has_hip_support else "CUDNN", "MKLDNN"]),
use_bias=st.booleans(),
force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
@@ -209,6 +218,10 @@ def test_convolution_gradients(self, op_type, stride, pad, kernel, dilation,
assume(_cudnn_supports(dilation=(dilation > 1),
nhwc=(order == 'NHWC'),
backward=True))
        if engine == 'MIOPEN':
            assume(_miopen_supports(dilation=(dilation > 1),
                                    nhwc=(order == 'NHWC'),
                                    backward=True))

assume(engine != "MKLDNN" or use_bias is True)

@@ -451,8 +464,12 @@ def test_convolution_layout(self, op_type, stride, pad, kernel, dilation,

for order in ["NCHW", "NHWC"]:
engine_list = ['']
if _cudnn_supports(dilation=(dilation > 1), nhwc=(order == 'NHWC')):
engine_list.append('CUDNN')
if workspace.has_hip_support:
if _miopen_supports(dilation=(dilation > 1), nhwc=(order == 'NHWC')):
engine_list.append('MIOPEN')
else:
if _cudnn_supports(dilation=(dilation > 1), nhwc=(order == 'NHWC')):
engine_list.append('CUDNN')

for engine in engine_list:
op = core.CreateOperator(
@@ -504,7 +521,7 @@ def canonical(o):
["simple", "dag"] +
(["async_dag"] if workspace.has_gpu_support or workspace.has_hip_support else [])),
do=st.sampled_from(hu.device_options),
engine=st.sampled_from(["CUDNN", ""]))
engine=st.sampled_from(["MIOPEN" if workspace.has_hip_support else "CUDNN", ""]))
def test_convolution_sync(self, net_type, num_workers, do, engine):
m = ModelHelper(name="test_model")
n = 1
@@ -515,7 +532,7 @@ def test_convolution_sync(self, net_type, num_workers, do, engine):
w = 5
workspace.ResetWorkspace()

use_cudnn = (engine == 'CUDNN')
use_gpu_engine = (engine == 'CUDNN' or engine == 'MIOPEN')

np.random.seed(1701)
# Build a binary tree of conv layers, summing at each node.
@@ -537,7 +554,7 @@ def test_convolution_sync(self, net_type, num_workers, do, engine):
stride=1,
pad=1,
deterministic=1,
use_cudnn=use_cudnn,
use_gpu_engine=use_gpu_engine,
engine=engine)
brew.conv(
m, bottom_2, mid_2,
@@ -549,7 +566,7 @@ def test_convolution_sync(self, net_type, num_workers, do, engine):
bias_init=('ConstantFill', dict(value=b2)),
deterministic=1,
cudnn_state=np.random.randint(0, 3),
use_cudnn=use_cudnn,
use_gpu_engine=use_gpu_engine,
engine=engine)
m.net.Sum([mid_1, mid_2], top)

@@ -588,37 +605,41 @@ def run():
1763719461732352.0,
rtol=1e-5)

def test_use_cudnn_engine_interactions(self):
"""Make sure the use_cudnn and engine kwargs work as expected."""
def test_use_gpu_engine_interactions(self):
"""Make sure the use_gpu_engine and engine kwargs work as expected."""
for model_default in [None, True, False]:
arg_scope = {}
if model_default is not None:
arg_scope['use_cudnn'] = model_default
arg_scope['use_gpu_engine'] = model_default
else:
model_default = True # the default

model = ModelHelper(arg_scope=arg_scope)
self.assertEqual(model.arg_scope['use_cudnn'], model_default)
self.assertEqual(model.arg_scope['use_gpu_engine'], model_default)
f = functools.partial(brew.conv, model,
'conv_in', 'conv_out', 10, 10, 5)

for op_cudnn in [None, True, False]:
for op_engine in [None, '', 'CUDNN']:
for op_gpu_engine in [None, True, False]:
for op_engine in [None, '', 'MIOPEN' if workspace.has_hip_support else 'CUDNN']:
kwargs = {}
if op_cudnn is not None:
kwargs['use_cudnn'] = op_cudnn
if op_gpu_engine is not None:
kwargs['use_gpu_engine'] = op_gpu_engine
else:
op_cudnn = False # the default
op_gpu_engine = False # the default
if op_engine is not None:
kwargs['engine'] = op_engine

calculated_cudnn = kwargs.get('use_cudnn', model_default)
calculated_gpu_engine = kwargs.get('use_gpu_engine', model_default)
if calculated_gpu_engine:
expected_engine_default = 'MIOPEN' if workspace.has_hip_support else 'CUDNN'
else:
expected_engine_default = ''
expected_engine = kwargs.get(
'engine',
'CUDNN' if calculated_cudnn else '')
expected_engine_default)

if ((calculated_cudnn is True and op_engine == '') or
(calculated_cudnn is False and op_engine == 'CUDNN')):
if ((calculated_gpu_engine is True and op_engine == '') or
(calculated_cudnn is False and op_engine == ('MIOPEN' if workspace.has_hip_support else 'CUDNN'))):
with self.assertRaises(ValueError):
f(**kwargs)
else:
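The use_gpu_engine/engine interaction the last test checks reduces to a small rule: an explicit engine kwarg wins, otherwise use_gpu_engine selects the build's GPU engine, and contradictory combinations raise ValueError. A standalone sketch of that rule, with a hypothetical function name:

from caffe2.python import workspace

def resolve_engine(use_gpu_engine, engine=None):
    # Hypothetical helper mirroring the kwarg rules exercised above.
    gpu_engine = 'MIOPEN' if workspace.has_hip_support else 'CUDNN'
    if engine is None:
        return gpu_engine if use_gpu_engine else ''
    if ((use_gpu_engine and engine == '') or
            (not use_gpu_engine and engine == gpu_engine)):
        raise ValueError("engine kwarg contradicts use_gpu_engine")
    return engine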
30 changes: 19 additions & 11 deletions caffe2/python/operator_test/copy_ops_test.py
@@ -40,21 +40,29 @@ def run_test_copy_gradient(self, device_opt):
def test_copy_gradient_cpu(self):
self.run_test_copy_gradient(core.DeviceOption(caffe2_pb2.CPU, 0))

@unittest.skipIf(workspace.NumCudaDevices() < 1, "Need at least 1 GPU.")
num_gpu = 0
if workspace.has_hip_support:
num_gpu = workspace.NumHipDevices()
else:
num_gpu = workspace.NumCudaDevices()

@unittest.skipIf(num_gpu < 1, "Need at least 1 GPU.")
def test_copy_gradient_gpu(self):
self.run_test_copy_gradient(core.DeviceOption(caffe2_pb2.CUDA, 0))
self.run_test_copy_gradient(core.DeviceOption(caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA, 0))

@unittest.skipIf(workspace.NumCudaDevices() < 2, "Need at least 2 GPU.")
@unittest.skipIf(num_gpu < 2, "Need at least 2 GPU.")
def test_copy_gradient_multiple_gpus(self):
model = model_helper.ModelHelper(name="copy_test")

with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU, 0)):
x_cpu = model.net.AddExternalInputs("x_cpu")

with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
gpu_device = caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA

with core.DeviceScope(core.DeviceOption(gpu_device, 0)):
x_gpu_1 = model.CopyCPUToGPU(x_cpu, "x_gpu_1")

with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 1)):
with core.DeviceScope(core.DeviceOption(gpu_device, 1)):
x_gpu_2 = model.Copy(x_gpu_1, "x_gpu_2")
loss = model.AveragedLoss(x_gpu_2, "loss")
gradient_map = model.AddGradientOperators([loss])
@@ -80,20 +88,20 @@ def get_op_with_output(model, output_blob_name):

self.assertEqual(
get_op_with_output(model, "x_gpu_2_grad").device_option,
core.DeviceOption(caffe2_pb2.CUDA, 1),
core.DeviceOption(gpu_device, 1),
)
self.assertEqual(
get_op_with_output(model, "x_cpu_grad").device_option,
core.DeviceOption(caffe2_pb2.CUDA, 0),
core.DeviceOption(gpu_device, 0),
)

@unittest.skipIf(workspace.NumCudaDevices() < 1, "Need at least 1 GPU.")
@unittest.skipIf(num_gpu < 1, "Need at least 1 GPU.")
def test_cpu2gpu_gpu2cpu_sparse_gradients(self):
model = model_helper.ModelHelper(name="copy_test")
v = model.param_init_net.UniformFill([], ["v"], shape=[16, 4])
indices = model.param_init_net.UniformFill([], ["v"], shape=[16, 4])
cpu_opt = core.DeviceOption(caffe2_pb2.CPU, 0)
gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, 0)
gpu_opt = core.DeviceOption(caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA, 0)

with core.DeviceScope(gpu_opt):
vcpu = model.CopyGPUToCPU(v, "vcpu")
@@ -112,13 +120,13 @@ def test_cpu2gpu_gpu2cpu_sparse_gradients(self):
self.assertTrue("v" in gradient_map)
self.assertTrue(isinstance(gradient_map['v'], core.GradientSlice))

@unittest.skipIf(workspace.NumCudaDevices() < 1, "Need at least 1 GPU.")
@unittest.skipIf(num_gpu < 1, "Need at least 1 GPU.")
def test_cpu2gpu_gpu2cpu_gradients(self):
model = model_helper.ModelHelper(name="copy_test")

batch = 32
cpu_opt = core.DeviceOption(caffe2_pb2.CPU, 0)
gpu_opt = core.DeviceOption(caffe2_pb2.CUDA, 0)
gpu_opt = core.DeviceOption(caffe2_pb2.HIP if workspace.has_hip_support else caffe2_pb2.CUDA, 0)

with core.NameScope("cpu"):
with core.DeviceScope(cpu_opt):
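The module-level num_gpu computation above is the device-count analogue of the HIP/CUDA proto choice; as a reusable sketch (helper name hypothetical, not part of the PR):

from caffe2.python import workspace

def num_gpu_devices():
    # Hypothetical helper: number of visible GPUs for this build.
    if workspace.has_hip_support:
        return workspace.NumHipDevices()
    return workspace.NumCudaDevices()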
4 changes: 2 additions & 2 deletions caffe2/python/operator_test/elementwise_op_broadcast_test.py
@@ -405,8 +405,8 @@ def test_sum_reduce(self, gc, dc):
np.testing.assert_array_almost_equal(out, res)
self.assertDeviceChecks(dc, op, [X, Y], [0])

# fp64 is not supported with the CUDA op
dc_cpu_only = [d for d in dc if d.device_type != caffe2_pb2.CUDA]
# fp64 is not supported with the CUDA/HIP op
        dc_cpu_only = [d for d in dc if (d.device_type != caffe2_pb2.CUDA and d.device_type != caffe2_pb2.HIP)]
self.assertDeviceChecks(dc_cpu_only, op, [X, Y], [0])

@unittest.skipIf(not workspace.has_gpu_support, "No gpu support")
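The corrected filter above must exclude both GPU device types (the or-form is always true); a membership test sidesteps the and/or pitfall entirely. A sketch, not part of the PR:

from caffe2.proto import caffe2_pb2

GPU_DEVICE_TYPES = (caffe2_pb2.CUDA, caffe2_pb2.HIP)

def cpu_only(dc):
    # Keep only non-GPU device options (sketch).
    return [d for d in dc if d.device_type not in GPU_DEVICE_TYPES]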