From 53c93d873b71ff714e1413e90b74957eb5194d58 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 1 Aug 2018 11:05:10 -0700
Subject: [PATCH 1/3] docs for building caffe2

---
 rocm-docs/caffe2-build.md | 103 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 103 insertions(+)
 create mode 100644 rocm-docs/caffe2-build.md

diff --git a/rocm-docs/caffe2-build.md b/rocm-docs/caffe2-build.md
new file mode 100644
index 00000000000000..4c188700af5fbd
--- /dev/null
+++ b/rocm-docs/caffe2-build.md
@@ -0,0 +1,103 @@
+# Caffe2: Building From Source on ROCm Platform
+
+## Intro
+This guide provides a starting point for building Caffe2 (the Caffe2 ROCm port) from source on AMD GPUs.
+*Note*: it is recommended to start with a clean Ubuntu 16.04 system.
+
+## Install docker
+
+If your machine doesn't have docker installed, follow the steps [here](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce) to install it.
+
+## Install ROCm
+
+If your machine doesn't have ROCm already, install the ROCm stack by following the steps at [this link](https://github.com/RadeonOpenCompute/ROCm/blob/master/README.md).
+
+Once the machine is set up with the ROCm stack, there are two ways to use Caffe2:
+* Run a docker container with Caffe2 pre-installed.
+
+* Build Caffe2 from source inside a docker container that has all the dependencies.
+
+## Launch docker container with Caffe2 pre-installed
+```
+docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video rocm/caffe2:rocm1.8.2
+```
+
+To run benchmarks, skip directly to the benchmarks section of this document.
+
+## Build Caffe2 from source
+### Pull the docker image
+```
+docker pull rocm/caffe2:unbuilt-rocm1.8.2
+```
+This docker image has all of Caffe2's dependencies pre-installed.
+
+### Pull the latest Caffe2 source
+* Using https
+```
+git clone --recurse-submodules https://github.com/ROCmSoftwarePlatform/pytorch.git
+```
+* Using ssh
+```
+git clone --recurse-submodules git@github.com:ROCmSoftwarePlatform/pytorch.git
+```
+Navigate to the repository directory:
+```
+cd pytorch
+```
+
+### Launch the docker container
+```
+docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video -v $PWD:/pytorch rocm/caffe2:unbuilt-rocm1.8.2
+```
+Inside the container, navigate to the pytorch directory: `cd /pytorch`.
+
+### Build the Caffe2 project from source
+
+* Run the command
+
+  `.jenkins/caffe2/build.sh`
+
+
+* Test the rocm-caffe2 installation
+
+  Before running the tests, make sure that the required environment variables are set:
+  ```
+  export LD_LIBRARY_PATH=/usr/local/caffe2/lib:$LD_LIBRARY_PATH
+  export PYTHONPATH=/usr/local/caffe2/lib/python2.7/dist-packages:$PYTHONPATH
+  ```
+
+  Run the binaries under `/pytorch/build_caffe2/bin`.
+
+## Run benchmarks
+
+Navigate to the build directory, `cd /pytorch/build_caffe2`, to run the benchmarks.
+
+The Caffe2 benchmarking script supports the following networks:
+1. MLP
+2. AlexNet
+3. OverFeat
+4. VGGA
+5. Inception
+6. Inception_v2
+7. Resnet50
+
+*Special case:* Inception_v2 and Resnet50 need their corresponding protobuf files to run the benchmarks. The protobufs can be downloaded from the Caffe2 model zoo with the command below, substituting `<model_name>` with `inception_v2` or `resnet50`:
+
+```
+python caffe2/python/models/download.py <model_name>
+```
+This downloads the protobufs to the current working directory.
+
+To run benchmarks for the networks MLP, AlexNet, OverFeat, VGGA, or Inception, run the following command, replacing `<name_of_the_network>` with one of those networks:
+
+```
+python caffe2/python/convnet_benchmarks.py --batch_size 64 --model <name_of_the_network> --engine MIOPEN --layer_wise_benchmark True --net_type simple
+```
+To run Inception_v2 or Resnet50, add the additional argument `--model_path` to the above command, pointing it at the model directory downloaded above:
+
+```
+python caffe2/python/convnet_benchmarks.py --batch_size 64 --model <name_of_the_network> --engine MIOPEN --layer_wise_benchmark True --net_type simple --model_path <path_to_model_protobufs>
+```
+
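
Before moving on to the benchmark patches, it can help to confirm from Python that the HIP backend is reachable at all. The sketch below is not part of the patch series; it is a minimal smoke test, assuming the ROCm port exposes `caffe2_pb2.HIP` and the `hip_gpu_id` field (both are used by the changes in the following patches). The `Relu` operator and the blob names `x`/`y` are only illustrative.

```python
# Minimal smoke test (assumption: caffe2_pb2.HIP and hip_gpu_id exist in this
# ROCm port, as used by the patches below). Runs a single Relu on HIP device 0.
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

device_opts = caffe2_pb2.DeviceOption()
device_opts.device_type = caffe2_pb2.HIP
device_opts.hip_gpu_id = 0

# Feed a random blob onto the HIP device and run one operator on it.
workspace.FeedBlob("x", np.random.randn(4, 4).astype(np.float32),
                   device_option=device_opts)
op = core.CreateOperator("Relu", ["x"], ["y"], device_option=device_opts)
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("y"))  # fetched back to host memory
```

If this prints an array of non-negative values, the Python bindings and the HIP device are wired up, and the benchmarks described above should be able to run.
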
From 89a268e7a288c7edf4d734e7236edc1d5a4d57bc Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 1 Aug 2018 11:05:30 -0700
Subject: [PATCH 2/3] update benchmarks script

---
 caffe2/python/convnet_benchmarks.py | 101 +++++++++++++++++++++++++---
 1 file changed, 92 insertions(+), 9 deletions(-)

diff --git a/caffe2/python/convnet_benchmarks.py b/caffe2/python/convnet_benchmarks.py
index 3aac78c18df16b..6b8af12f5cec71 100644
--- a/caffe2/python/convnet_benchmarks.py
+++ b/caffe2/python/convnet_benchmarks.py
@@ -59,11 +59,12 @@
 """
 
 import argparse
+import os
 
-from caffe2.python import workspace, brew, model_helper
+from caffe2.python import workspace, brew, model_helper, core
+from caffe2.proto import caffe2_pb2
 
-
-def MLP(order, cudnn_ws):
+def MLP(order, cudnn_ws, model_path=""):
     model = model_helper.ModelHelper(name="MLP")
     d = 256
     depth = 20
@@ -98,7 +99,7 @@ def MLP(order, cudnn_ws):
     return model, d
 
 
-def AlexNet(order, cudnn_ws):
+def AlexNet(order, cudnn_ws, model_path=""):
     my_arg_scope = {
         'order': order,
         'use_cudnn': True,
@@ -191,7 +192,7 @@ def AlexNet(order, cudnn_ws):
     return model, 224
 
 
-def OverFeat(order, cudnn_ws):
+def OverFeat(order, cudnn_ws, model_path=""):
     my_arg_scope = {
         'order': order,
         'use_cudnn': True,
@@ -277,7 +278,7 @@ def OverFeat(order, cudnn_ws):
     return model, 231
 
 
-def VGGA(order, cudnn_ws):
+def VGGA(order, cudnn_ws, model_path=""):
     my_arg_scope = {
         'order': order,
         'use_cudnn': True,
@@ -475,7 +476,7 @@ def _InceptionModule(
     return output
 
 
-def Inception(order, cudnn_ws):
+def Inception(order, cudnn_ws, model_path=""):
     my_arg_scope = {
         'order': order,
         'use_cudnn': True,
@@ -562,6 +563,84 @@ def Inception(order, cudnn_ws):
     model.net.AveragedLoss(xent, "loss")
     return model, 224
 
+def Resnet50(order, cudnn_ws, model_path=""):
+    if model_path == "":
+        print("ERROR: please specify paths to init_net and predict_net protobufs for Resnet50")
+        exit(1)
+    device_opts = caffe2_pb2.DeviceOption()
+    device_opts.device_type = caffe2_pb2.HIP
+    device_opts.hip_gpu_id = 0
+
+    INIT_NET_PB = os.path.join(model_path, "init_net.pb")
+    PREDICT_NET_PB = os.path.join(model_path, "predict_net.pb")
+    init_def = caffe2_pb2.NetDef()
+    with open(INIT_NET_PB, 'rb') as f:
+        init_def.ParseFromString(f.read())
+        init_def.device_option.CopyFrom(device_opts)
+
+    net_def = caffe2_pb2.NetDef()
+    with open(PREDICT_NET_PB, 'rb') as f:
+        net_def.ParseFromString(f.read())
+        net_def.device_option.CopyFrom(device_opts)
+
+    init_net = core.Net(init_def)
+    predict_net = core.Net(net_def)
+    for op in init_net.Proto().op:
+        op.device_option.CopyFrom(device_opts)
+    for op in predict_net.Proto().op:
+        op.device_option.CopyFrom(device_opts)
+    my_arg_scope = {
+        'order': order,
+    }
+    model = model_helper.ModelHelper(
+        name="resnet50",
+        arg_scope=my_arg_scope,
+    )
+
+    model.param_init_net = init_net
+    model.net = predict_net
+    xent = model.net.LabelCrossEntropy(["gpu_0/softmax", "label"], "xent")
+    model.net.AveragedLoss(xent, "loss")
+    return model, 224
+
+def Inception_v2(order, cudnn_ws, model_path=""):
+    if model_path == "":
+        print("ERROR: please specify paths to init_net and predict_net protobufs for Inception_v2")
+        exit(1)
+    device_opts = caffe2_pb2.DeviceOption()
+    device_opts.device_type = caffe2_pb2.HIP
+    device_opts.hip_gpu_id = 0
+
+    INIT_NET_PB = os.path.join(model_path, "init_net.pb")
+    PREDICT_NET_PB = os.path.join(model_path, "predict_net.pb")
+    init_def = caffe2_pb2.NetDef()
+    with open(INIT_NET_PB, 'rb') as f:
+        init_def.ParseFromString(f.read())
+        init_def.device_option.CopyFrom(device_opts)
+
+    net_def = caffe2_pb2.NetDef()
+    with open(PREDICT_NET_PB, 'rb') as f:
+        net_def.ParseFromString(f.read())
+        net_def.device_option.CopyFrom(device_opts)
+
+    init_net = core.Net(init_def)
+    predict_net = core.Net(net_def)
+
+    my_arg_scope = {
+        'order': order,
+    }
+
+    model = model_helper.ModelHelper(
+        name="GoogleNet",
+        arg_scope=my_arg_scope,
+    )
+
+    model.param_init_net = init_net
+    model.net = predict_net
+    xent = model.net.LabelCrossEntropy(["prob", "label"], "xent")
+    model.net.AveragedLoss(xent, "loss")
+    return model, 224
+
 
 def AddParameterUpdate(model):
     """ Simple plain SGD update -- not tuned to actually train the models """
@@ -575,7 +654,7 @@ def AddParameterUpdate(model):
 
 
 def Benchmark(model_gen, arg):
-    model, input_size = model_gen(arg.order, arg.cudnn_ws)
+    model, input_size = model_gen(arg.order, arg.cudnn_ws, arg.model_path)
     model.Proto().type = arg.net_type
     model.Proto().num_workers = arg.num_workers
 
@@ -590,7 +669,7 @@ def Benchmark(model_gen, arg):
 
     model.param_init_net.GaussianFill(
         [],
-        "data",
+        "gpu_0/data" if arg.model == "Resnet50" else "data",
         shape=input_shape,
         mean=0.0,
         std=1.0
@@ -701,6 +780,7 @@ def GetArgumentParser():
     parser.add_argument("--num_workers", type=int, default=2)
     parser.add_argument("--use-nvtx", default=False, action='store_true')
    parser.add_argument("--htrace_span_log_path", type=str)
+    parser.add_argument("--model_path", type=str, default="", help="set path to init_net and predict_net protobufs")
     return parser
 
 
@@ -723,5 +803,8 @@ def GetArgumentParser():
         'VGGA': VGGA,
         'Inception': Inception,
         'MLP': MLP,
+        'Resnet50': Resnet50,
+        'Inception_v2': Inception_v2,
+
     }
     Benchmark(model_map[args.model], args)
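
For readers who want to exercise the downloaded models outside of `convnet_benchmarks.py`, the sketch below mirrors the protobuf-loading pattern used by the new `Resnet50`/`Inception_v2` entries in this patch and drives the nets with the standard `workspace` calls. The directory path is a placeholder, and the input blob name `gpu_0/data` plus the 224x224 input shape are taken from the Resnet50 code above; treat them as assumptions rather than a guaranteed interface.

```python
# Sketch of running a downloaded model directly, mirroring the Resnet50 loader above.
# Assumptions: model_path holds init_net.pb / predict_net.pb from download.py, and
# the predict net reads its input from the "gpu_0/data" blob, as in the patch.
import os
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import workspace

model_path = "/path/to/resnet50"  # placeholder

device_opts = caffe2_pb2.DeviceOption()
device_opts.device_type = caffe2_pb2.HIP
device_opts.hip_gpu_id = 0

init_def = caffe2_pb2.NetDef()
with open(os.path.join(model_path, "init_net.pb"), "rb") as f:
    init_def.ParseFromString(f.read())

net_def = caffe2_pb2.NetDef()
with open(os.path.join(model_path, "predict_net.pb"), "rb") as f:
    net_def.ParseFromString(f.read())

# Pin every operator to the HIP device, as the benchmark functions do.
for op in list(init_def.op) + list(net_def.op):
    op.device_option.CopyFrom(device_opts)

workspace.FeedBlob("gpu_0/data",
                   np.random.rand(1, 3, 224, 224).astype(np.float32),
                   device_option=device_opts)
workspace.RunNetOnce(init_def)   # materializes the weights
workspace.CreateNet(net_def)     # validates inputs and builds the net once
workspace.RunNet(net_def.name)   # after that it can be run repeatedly
```
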
From 00938a90a4dcfb4361d6410ac422f359cea6acb9 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 1 Aug 2018 11:05:50 -0700
Subject: [PATCH 3/3] add hip device support

---
 caffe2/python/core.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index 1423cfea3a6c09..7594491e092900 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -85,6 +85,7 @@ def IsOperatorWithEngine(op_type, engine):
 def DeviceOption(
     device_type,
     cuda_gpu_id=0,
+    hip_gpu_id=0,
     random_seed=None,
     node_name=None,
     numa_node_id=None,
@@ -93,6 +94,7 @@ def DeviceOption(
     option = caffe2_pb2.DeviceOption()
     option.device_type = device_type
     option.cuda_gpu_id = cuda_gpu_id
+    option.hip_gpu_id = hip_gpu_id
     if node_name is not None:
         option.node_name = node_name
     if random_seed is not None:
@@ -2022,8 +2024,9 @@ def DeduplicateGradientSlices(self, g, aggregator='sum'):
     def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
         """A convenient function to run everything on the GPU."""
         device_option = caffe2_pb2.DeviceOption()
-        device_option.device_type = caffe2_pb2.CUDA
+        device_option.device_type = caffe2_pb2.CUDA if workspace.has_gpu_support else caffe2_pb2.HIP
         device_option.cuda_gpu_id = gpu_id
+        device_option.hip_gpu_id = gpu_id
         self._net.device_option.CopyFrom(device_option)
         if use_cudnn:
             for op in self._net.op:
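
With this third patch applied, the `core.DeviceOption` helper accepts a HIP GPU id and `Net.RunAllOnGPU()` falls back to HIP when the build has no CUDA support, so the same script can run unchanged on NVIDIA and AMD machines. A small usage sketch follows; the toy net, its operators, and the blob names are illustrative only, while `DeviceOption` and `RunAllOnGPU` are the helpers touched by this patch.

```python
# Usage sketch for the extended helpers from this patch; the toy net is illustrative.
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace

# DeviceOption now carries both GPU id fields; on a ROCm build the HIP one matters.
hip_opt = core.DeviceOption(caffe2_pb2.HIP, hip_gpu_id=0)
print(hip_opt)

# RunAllOnGPU picks CUDA when workspace.has_gpu_support is True, HIP otherwise.
net = core.Net("toy")
net.ConstantFill([], "ones", shape=[4, 4], value=1.0)
net.Relu("ones", "relu_out")
net.RunAllOnGPU(gpu_id=0)

workspace.RunNetOnce(net)
print(workspace.FetchBlob("relu_out"))
```
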