Skip to content

Commit cccb89d

Browse files
authored
Merge pull request #71 from rohithkrn/rocaffe2-docs
Docs to build caffe2 on ROCm
2 parents 3d77f62 + 3bbb853 commit cccb89d

File tree

3 files changed

+198
-10
lines changed

3 files changed

+198
-10
lines changed

caffe2/python/convnet_benchmarks.py

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@
5959
"""
6060

6161
import argparse
62+
import os
6263

63-
from caffe2.python import workspace, brew, model_helper
64+
from caffe2.python import workspace, brew, model_helper, core
65+
from caffe2.proto import caffe2_pb2
6466

65-
66-
def MLP(order, cudnn_ws):
67+
def MLP(order, cudnn_ws, model_path=""):
6768
model = model_helper.ModelHelper(name="MLP")
6869
d = 256
6970
depth = 20
@@ -98,7 +99,7 @@ def MLP(order, cudnn_ws):
9899
return model, d
99100

100101

101-
def AlexNet(order, cudnn_ws):
102+
def AlexNet(order, cudnn_ws, model_path=""):
102103
my_arg_scope = {
103104
'order': order,
104105
'use_cudnn': True,
@@ -191,7 +192,7 @@ def AlexNet(order, cudnn_ws):
191192
return model, 224
192193

193194

194-
def OverFeat(order, cudnn_ws):
195+
def OverFeat(order, cudnn_ws, model_path=""):
195196
my_arg_scope = {
196197
'order': order,
197198
'use_cudnn': True,
@@ -277,7 +278,7 @@ def OverFeat(order, cudnn_ws):
277278
return model, 231
278279

279280

280-
def VGGA(order, cudnn_ws):
281+
def VGGA(order, cudnn_ws, model_path=""):
281282
my_arg_scope = {
282283
'order': order,
283284
'use_cudnn': True,
@@ -475,7 +476,7 @@ def _InceptionModule(
475476
return output
476477

477478

478-
def Inception(order, cudnn_ws):
479+
def Inception(order, cudnn_ws, model_path=""):
479480
my_arg_scope = {
480481
'order': order,
481482
'use_cudnn': True,
@@ -562,6 +563,84 @@ def Inception(order, cudnn_ws):
562563
model.net.AveragedLoss(xent, "loss")
563564
return model, 224
564565

566+
def Resnet50(order, cudnn_ws, model_path=""):
    """Build the Resnet50 benchmark model from serialized protobufs.

    The init and predict nets are read from ``init_net.pb`` and
    ``predict_net.pb`` inside ``model_path``, pinned to HIP GPU 0, and a
    cross-entropy loss is appended to the predict net.

    Args:
        order: storage order, forwarded to the ModelHelper arg scope.
        cudnn_ws: unused here; accepted so the signature matches the other
            model generators.
        model_path: directory containing ``init_net.pb`` and
            ``predict_net.pb``. Must be non-empty.

    Returns:
        A ``(model, 224)`` tuple; the caller treats the second element as
        the input size.

    Exits:
        Terminates the process with status 1 when ``model_path`` is empty.
    """
    if not model_path:
        # Use sys.exit rather than the site-provided ``exit`` builtin, which
        # is not guaranteed to exist (e.g. ``python -S``), and report the
        # error on stderr.
        import sys
        sys.stderr.write(
            "ERROR: please specify paths to init_net and predict_net protobufs for Resnet50\n"
        )
        sys.exit(1)

    # NOTE(review): the device is hard-coded to HIP GPU 0, so this generator
    # presumably only works on ROCm builds -- confirm before using elsewhere.
    device_opts = caffe2_pb2.DeviceOption()
    device_opts.device_type = caffe2_pb2.HIP
    device_opts.hip_gpu_id = 0

    INIT_NET_PB = os.path.join(model_path, "init_net.pb")
    PREDICT_NET_PB = os.path.join(model_path, "predict_net.pb")

    init_def = caffe2_pb2.NetDef()
    with open(INIT_NET_PB, 'rb') as f:
        init_def.ParseFromString(f.read())
        init_def.device_option.CopyFrom(device_opts)

    net_def = caffe2_pb2.NetDef()
    with open(PREDICT_NET_PB, 'rb') as f:
        net_def.ParseFromString(f.read())
        net_def.device_option.CopyFrom(device_opts)

    init_net = core.Net(init_def)
    predict_net = core.Net(net_def)

    # Overwrite the device option of every operator as well, not just the
    # net-level one.
    for op in init_net.Proto().op:
        op.device_option.CopyFrom(device_opts)
    for op in predict_net.Proto().op:
        op.device_option.CopyFrom(device_opts)

    my_arg_scope = {
        'order': order,
    }
    model = model_helper.ModelHelper(
        name="resnet50",
        arg_scope=my_arg_scope,
    )

    # Swap the freshly created (empty) nets for the loaded ones.
    model.param_init_net = init_net
    model.net = predict_net

    # Attach a loss on top of the loaded net's softmax output.
    xent = model.net.LabelCrossEntropy(["gpu_0/softmax", "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
605+
606+
def Inception_v2(order, cudnn_ws, model_path=""):
    """Build the Inception_v2 benchmark model from serialized protobufs.

    The init and predict nets are read from ``init_net.pb`` and
    ``predict_net.pb`` inside ``model_path``, their net-level device option
    is set to HIP GPU 0, and a cross-entropy loss is appended to the
    predict net.

    Args:
        order: storage order, forwarded to the ModelHelper arg scope.
        cudnn_ws: unused here; accepted so the signature matches the other
            model generators.
        model_path: directory containing ``init_net.pb`` and
            ``predict_net.pb``. Must be non-empty.

    Returns:
        A ``(model, 224)`` tuple; the caller treats the second element as
        the input size.

    Exits:
        Terminates the process with status 1 when ``model_path`` is empty.
    """
    if not model_path:
        # Use sys.exit rather than the site-provided ``exit`` builtin, which
        # is not guaranteed to exist (e.g. ``python -S``), and report the
        # error on stderr.
        import sys
        sys.stderr.write(
            "ERROR: please specify paths to init_net and predict_net protobufs for Inception_v2\n"
        )
        sys.exit(1)

    # NOTE(review): the device is hard-coded to HIP GPU 0, so this generator
    # presumably only works on ROCm builds -- confirm before using elsewhere.
    device_opts = caffe2_pb2.DeviceOption()
    device_opts.device_type = caffe2_pb2.HIP
    device_opts.hip_gpu_id = 0

    INIT_NET_PB = os.path.join(model_path, "init_net.pb")
    PREDICT_NET_PB = os.path.join(model_path, "predict_net.pb")

    init_def = caffe2_pb2.NetDef()
    with open(INIT_NET_PB, 'rb') as f:
        init_def.ParseFromString(f.read())
        init_def.device_option.CopyFrom(device_opts)

    net_def = caffe2_pb2.NetDef()
    with open(PREDICT_NET_PB, 'rb') as f:
        net_def.ParseFromString(f.read())
        net_def.device_option.CopyFrom(device_opts)

    # Only the net-level device option is set here; per-operator device
    # options from the protobufs are left untouched.
    init_net = core.Net(init_def)
    predict_net = core.Net(net_def)

    my_arg_scope = {
        'order': order,
    }

    model = model_helper.ModelHelper(
        name="GoogleNet",
        arg_scope=my_arg_scope,
    )

    # Swap the freshly created (empty) nets for the loaded ones.
    model.param_init_net = init_net
    model.net = predict_net

    # Attach a loss on top of the loaded net's "prob" output.
    xent = model.net.LabelCrossEntropy(["prob", "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
643+
565644

566645
def AddParameterUpdate(model):
567646
""" Simple plain SGD update -- not tuned to actually train the models """
@@ -575,7 +654,7 @@ def AddParameterUpdate(model):
575654

576655

577656
def Benchmark(model_gen, arg):
578-
model, input_size = model_gen(arg.order, arg.cudnn_ws)
657+
model, input_size = model_gen(arg.order, arg.cudnn_ws, arg.model_path)
579658
model.Proto().type = arg.net_type
580659
model.Proto().num_workers = arg.num_workers
581660

@@ -590,7 +669,7 @@ def Benchmark(model_gen, arg):
590669

591670
model.param_init_net.GaussianFill(
592671
[],
593-
"data",
672+
"gpu_0/data" if arg.model == "Resnet50" else "data",
594673
shape=input_shape,
595674
mean=0.0,
596675
std=1.0
@@ -701,6 +780,7 @@ def GetArgumentParser():
701780
parser.add_argument("--num_workers", type=int, default=2)
702781
parser.add_argument("--use-nvtx", default=False, action='store_true')
703782
parser.add_argument("--htrace_span_log_path", type=str)
783+
parser.add_argument("--model_path", type=str, default="", help="set path to init net and predict_net protobufs")
704784
return parser
705785

706786

@@ -723,5 +803,8 @@ def GetArgumentParser():
723803
'VGGA': VGGA,
724804
'Inception': Inception,
725805
'MLP': MLP,
806+
'Resnet50': Resnet50,
807+
'Inception_v2':Inception_v2
808+
726809
}
727810
Benchmark(model_map[args.model], args)

caffe2/python/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def IsOperatorWithEngine(op_type, engine):
8585
def DeviceOption(
8686
device_type,
8787
cuda_gpu_id=0,
88+
hip_gpu_id=0,
8889
random_seed=None,
8990
node_name=None,
9091
numa_node_id=None,
@@ -93,6 +94,7 @@ def DeviceOption(
9394
option = caffe2_pb2.DeviceOption()
9495
option.device_type = device_type
9596
option.cuda_gpu_id = cuda_gpu_id
97+
option.hip_gpu_id = hip_gpu_id
9698
if node_name is not None:
9799
option.node_name = node_name
98100
if random_seed is not None:
@@ -2022,8 +2024,9 @@ def DeduplicateGradientSlices(self, g, aggregator='sum'):
20222024
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
20232025
"""A convenient function to run everything on the GPU."""
20242026
device_option = caffe2_pb2.DeviceOption()
2025-
device_option.device_type = caffe2_pb2.CUDA
2027+
device_option.device_type = caffe2_pb2.CUDA if workspace.has_gpu_support else caffe2_pb2.HIP
20262028
device_option.cuda_gpu_id = gpu_id
2029+
device_option.hip_gpu_id = gpu_id
20272030
self._net.device_option.CopyFrom(device_option)
20282031
if use_cudnn:
20292032
for op in self._net.op:

rocm-docs/caffe2-build.md

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Caffe2: Building From Source on ROCm Platform
2+
3+
## Intro
4+
These instructions provide a starting point for building Caffe2 on AMD GPUs (Caffe2 ROCm port) from source.
5+
*Note*: it is recommended to start with a clean Ubuntu 16.04 system
6+
7+
## Install docker
8+
9+
If your machine doesn't have docker installed, follow the steps [here](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce) to install docker.
10+
11+
## Install ROCm
12+
13+
If your machine doesn't already have ROCm, install the ROCm stack by following the steps in the [ROCm README](https://github.com/RadeonOpenCompute/ROCm/blob/master/README.md).
14+
15+
Once the machine is set up with the ROCm stack, there are two ways to use caffe2:
16+
* Run the docker container with caffe2 installed in it.
17+
18+
* Build caffe2 from source inside a docker with all the dependencies.
19+
20+
## Launch docker container with caffe2 pre-installed
21+
```
22+
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video rocm/caffe2:rocm1.8.2
23+
```
24+
25+
To run benchmarks, skip directly to the "Run benchmarks" section of this document.
26+
27+
## Build Caffe2 from source
28+
### Pull the docker image
29+
```
30+
docker pull rocm/caffe2:unbuilt-rocm1.8.2
31+
```
32+
This docker image has all the dependencies for caffe2 pre-installed.
33+
34+
### Pull the latest caffe2 source:
35+
* Using https
36+
```
37+
git clone --recurse-submodules https://github.com/ROCmSoftwarePlatform/pytorch.git
38+
```
39+
* Using ssh
40+
```
41+
git clone --recurse-submodules [email protected]:ROCmSoftwarePlatform/pytorch.git
42+
```
43+
Navigate to repo directory
44+
```
45+
cd pytorch
46+
```
47+
48+
### Launch the docker container
49+
```
50+
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add video -v $PWD:/pytorch rocm/caffe2:unbuilt-rocm1.8.2
51+
```
52+
Navigate to pytorch directory `cd /pytorch` inside the container.
53+
54+
### Build caffe2 Project from source
55+
56+
* Run the command
57+
58+
`.jenkins/caffe2/build.sh`
59+
60+
61+
* Test the rocm-caffe2 Installation
62+
63+
Before running the tests, make sure that the required environment variables are set:
64+
```
65+
export LD_LIBRARY_PATH=/usr/local/caffe2/lib:$LD_LIBRARY_PATH
66+
export PYTHONPATH=/usr/local/caffe2/lib/python2.7/dist-packages:$PYTHONPATH
67+
```
68+
69+
Run the binaries under `/pytorch/build_caffe2/bin`
70+
71+
## Run benchmarks
72+
73+
Navigate to build directory, `cd /pytorch/build_caffe2` to run benchmarks.
74+
75+
Caffe2 benchmarking script supports the following networks.
76+
1. MLP
77+
2. AlexNet
78+
3. OverFeat
79+
4. VGGA
80+
5. Inception
81+
6. Inception_v2
82+
7. Resnet50
83+
84+
*Special case:* Inception_v2 and Resnet50 need their corresponding protobuf files to run the benchmarks. The protobufs can be downloaded from the caffe2 model zoo using the command below. Substitute `<model_name>` with `inception_v2` or `resnet50`.
85+
86+
```
87+
python caffe2/python/models/download.py <model_name>
88+
```
89+
This will download the protobufs to the current working directory.
90+
91+
To run benchmarks for the networks MLP, AlexNet, OverFeat, VGGA, or Inception, run the command below, replacing `<name_of_the_network>` with one of these networks.
92+
93+
```
94+
python caffe2/python/convnet_benchmarks.py --batch_size 64 --model <name_of_the_network> --engine MIOPEN --layer_wise_benchmark True --net_type simple
95+
96+
```
97+
To run Inception_v2 or Resnet50, add the additional argument `--model_path` to the above command; it should point to the model directory downloaded above.
98+
99+
```
100+
python caffe2/python/convnet_benchmarks.py --batch_size 64 --model <name_of_the_network> --engine MIOPEN --layer_wise_benchmark True --net_type simple --model_path <path_to_model_protobufs>
101+
102+
```

0 commit comments

Comments
 (0)