Skip to content

Commit b5db97b

Browse files
ShahriarSSfmassa
authored andcommitted
C++ Models (#728)
* Added the existing code * Added squeezenet and fixed some stuff in the other models * Wrote DenseNet and a part of InceptionV3 Going to clean and check all of the models and finish inception * Fixed some errors in the models Next step is writing inception and comparing with python code again. * Completed inception and changed models directory * Fixed and wrote some stuff * fixed maxpoool2d and avgpool2d and adaptiveavgpool2d * Fixed a few stuff Moved cmakelists to root and changed the namespace to vision and wrote weight initialization in inception * Added models namespace and changed cmakelists the project is now installable * Removed some comments * Changed style to pytorch style, added some comments and fixed some minor errors * Removed truncated normal init * Changed classes to structs and fixed a few errors * Replaced modelsimpl structs with functional wherever possible * Changed adaptive average pool from struct to function * Wrote a max_pool2d wrapper and added some comments * Replaced xavier init with kaiming init * Fixed an error in kaiming inits * Added model conversion and tests * Fixed a typo in alexnet and removed tests from cmake * Made an extension of tests and added module names to Densenet * Added python tests * Added MobileNet and GoogLeNet models * Added tests and conversions for new models and fixed a few errors * Updated Alexnet ad VGG * Updated Densenet, Squeezenet and Inception * Added ResNexts and their conversions * Added tests for ResNexts * Wrote tools nessesary to write ShuffleNet * Added ShuffleNetV2 * Fixed some errors in ShuffleNetV2 * Added conversions for shufflenetv2 * Fixed the errors in test_models.cpp * Updated setup.py * Fixed flake8 error on test_cpp_models.py * Changed view to reshape in forward of ResNet * Updated ShuffleNetV2 * Split extensions to tests and ops * Fixed test extension * Fixed image path in test_cpp_models.py * Fixed image path in test_cpp_models.py * Fixed a few things in test_cpp_models.py * Put the test models in evaluation mode * Fixed registering error in GoogLeNet * Updated setup.py * write test_cpp_models.py with unittest * Fixed a problem with pytest in test_cpp_models.py * Fixed a lint problem
1 parent 394de98 commit b5db97b

26 files changed

+2779
-0
lines changed

CMakeLists.txt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
cmake_minimum_required(VERSION 2.8)
2+
project(torchvision)
3+
set(CMAKE_CXX_STANDARD 11)
4+
5+
find_package(Torch REQUIRED)
6+
7+
file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h)
8+
file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h)
9+
file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)
10+
11+
add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES})
12+
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")
13+
14+
add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp)
15+
target_link_libraries(convertmodels "${PROJECT_NAME}")
16+
target_link_libraries(convertmodels "${TORCH_LIBRARIES}")
17+
18+
#add_executable(testmodels test/test_models.cpp)
19+
#target_link_libraries(testmodels "${PROJECT_NAME}")
20+
#target_link_libraries(testmodels "${TORCH_LIBRARIES}")
21+
22+
install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
23+
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME})
24+
install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models)

setup.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ def get_extensions():
8989
sources = main_file + source_cpu
9090
extension = CppExtension
9191

92+
test_dir = os.path.join(this_dir, 'test')
93+
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
94+
test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
95+
source_models = glob.glob(os.path.join(models_dir, '*.cpp'))
96+
97+
test_file = [os.path.join(test_dir, s) for s in test_file]
98+
source_models = [os.path.join(models_dir, s) for s in source_models]
99+
tests = test_file + source_models
100+
92101
define_macros = []
93102

94103
extra_compile_args = {}
@@ -109,6 +118,7 @@ def get_extensions():
109118
sources = [os.path.join(extensions_dir, s) for s in sources]
110119

111120
include_dirs = [extensions_dir]
121+
tests_include_dirs = [test_dir, models_dir]
112122

113123
ext_modules = [
114124
extension(
@@ -117,6 +127,13 @@ def get_extensions():
117127
include_dirs=include_dirs,
118128
define_macros=define_macros,
119129
extra_compile_args=extra_compile_args,
130+
),
131+
extension(
132+
'torchvision._C_tests',
133+
tests,
134+
include_dirs=tests_include_dirs,
135+
define_macros=define_macros,
136+
extra_compile_args=extra_compile_args,
120137
)
121138
]
122139

test/test_cpp_models.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import torch
2+
import os
3+
import unittest
4+
from torchvision import models, transforms, _C_tests
5+
6+
from PIL import Image
7+
import torchvision.transforms.functional as F
8+
9+
10+
def process_model(model, tensor, func, name):
11+
model.eval()
12+
traced_script_module = torch.jit.trace(model, tensor)
13+
traced_script_module.save("model.pt")
14+
15+
py_output = model.forward(tensor)
16+
cpp_output = func("model.pt", tensor)
17+
18+
assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models'
19+
20+
21+
def read_image1():
22+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
23+
image = Image.open(image_path)
24+
image = image.resize((224, 224))
25+
x = F.to_tensor(image)
26+
return x.view(1, 3, 224, 224)
27+
28+
29+
def read_image2():
30+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
31+
image = Image.open(image_path)
32+
image = image.resize((299, 299))
33+
x = F.to_tensor(image)
34+
x = x.view(1, 3, 299, 299)
35+
return torch.cat([x, x], 0)
36+
37+
38+
class Tester(unittest.TestCase):
39+
pretrained = False
40+
image = read_image1()
41+
42+
def test_alexnet(self):
43+
process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet')
44+
45+
def test_vgg11(self):
46+
process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11')
47+
48+
def test_vgg13(self):
49+
process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13')
50+
51+
def test_vgg16(self):
52+
process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16')
53+
54+
def test_vgg19(self):
55+
process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19')
56+
57+
def test_vgg11_bn(self):
58+
process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN')
59+
60+
def test_vgg13_bn(self):
61+
process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN')
62+
63+
def test_vgg16_bn(self):
64+
process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN')
65+
66+
def test_vgg19_bn(self):
67+
process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN')
68+
69+
def test_resnet18(self):
70+
process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18')
71+
72+
def test_resnet34(self):
73+
process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34')
74+
75+
def test_resnet50(self):
76+
process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50')
77+
78+
def test_resnet101(self):
79+
process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101')
80+
81+
def test_resnet152(self):
82+
process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152')
83+
84+
def test_resnext50_32x4d(self):
85+
process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d')
86+
87+
def test_resnext101_32x8d(self):
88+
process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')
89+
90+
def test_squeezenet1_0(self):
91+
process_model(models.squeezenet1_0(self.pretrained), self.image,
92+
_C_tests.forward_squeezenet1_0, 'Squeezenet1.0')
93+
94+
def test_squeezenet1_1(self):
95+
process_model(models.squeezenet1_1(self.pretrained), self.image,
96+
_C_tests.forward_squeezenet1_1, 'Squeezenet1.1')
97+
98+
def test_densenet121(self):
99+
process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121')
100+
101+
def test_densenet169(self):
102+
process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169')
103+
104+
def test_densenet201(self):
105+
process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201')
106+
107+
def test_densenet161(self):
108+
process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161')
109+
110+
def test_mobilenet_v2(self):
111+
process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet')
112+
113+
def test_googlenet(self):
114+
process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet')
115+
116+
def test_inception_v3(self):
117+
self.image = read_image2()
118+
process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3')
119+
120+
121+
if __name__ == '__main__':
122+
unittest.main()

test/test_models.cpp

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#include <torch/script.h>
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
5+
#include "../torchvision/csrc/models/models.h"
6+
7+
using namespace vision::models;
8+
9+
template <typename Model>
10+
torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) {
11+
Model network;
12+
torch::load(network, input_path);
13+
network->eval();
14+
return network->forward(x);
15+
}
16+
17+
torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) {
18+
return forward_model<AlexNet>(input_path, x);
19+
}
20+
21+
torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) {
22+
return forward_model<VGG11>(input_path, x);
23+
}
24+
torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) {
25+
return forward_model<VGG13>(input_path, x);
26+
}
27+
torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) {
28+
return forward_model<VGG16>(input_path, x);
29+
}
30+
torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) {
31+
return forward_model<VGG19>(input_path, x);
32+
}
33+
34+
torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) {
35+
return forward_model<VGG11BN>(input_path, x);
36+
}
37+
torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) {
38+
return forward_model<VGG13BN>(input_path, x);
39+
}
40+
torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) {
41+
return forward_model<VGG16BN>(input_path, x);
42+
}
43+
torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) {
44+
return forward_model<VGG19BN>(input_path, x);
45+
}
46+
47+
torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) {
48+
return forward_model<ResNet18>(input_path, x);
49+
}
50+
torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) {
51+
return forward_model<ResNet34>(input_path, x);
52+
}
53+
torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) {
54+
return forward_model<ResNet50>(input_path, x);
55+
}
56+
torch::Tensor forward_resnet101(
57+
const std::string& input_path,
58+
torch::Tensor x) {
59+
return forward_model<ResNet101>(input_path, x);
60+
}
61+
torch::Tensor forward_resnet152(
62+
const std::string& input_path,
63+
torch::Tensor x) {
64+
return forward_model<ResNet152>(input_path, x);
65+
}
66+
torch::Tensor forward_resnext50_32x4d(
67+
const std::string& input_path,
68+
torch::Tensor x) {
69+
return forward_model<ResNext50_32x4d>(input_path, x);
70+
}
71+
torch::Tensor forward_resnext101_32x8d(
72+
const std::string& input_path,
73+
torch::Tensor x) {
74+
return forward_model<ResNext101_32x8d>(input_path, x);
75+
}
76+
77+
torch::Tensor forward_squeezenet1_0(
78+
const std::string& input_path,
79+
torch::Tensor x) {
80+
return forward_model<SqueezeNet1_0>(input_path, x);
81+
}
82+
torch::Tensor forward_squeezenet1_1(
83+
const std::string& input_path,
84+
torch::Tensor x) {
85+
return forward_model<SqueezeNet1_1>(input_path, x);
86+
}
87+
88+
torch::Tensor forward_densenet121(
89+
const std::string& input_path,
90+
torch::Tensor x) {
91+
return forward_model<DenseNet121>(input_path, x);
92+
}
93+
torch::Tensor forward_densenet169(
94+
const std::string& input_path,
95+
torch::Tensor x) {
96+
return forward_model<DenseNet169>(input_path, x);
97+
}
98+
torch::Tensor forward_densenet201(
99+
const std::string& input_path,
100+
torch::Tensor x) {
101+
return forward_model<DenseNet201>(input_path, x);
102+
}
103+
torch::Tensor forward_densenet161(
104+
const std::string& input_path,
105+
torch::Tensor x) {
106+
return forward_model<DenseNet161>(input_path, x);
107+
}
108+
109+
torch::Tensor forward_mobilenetv2(
110+
const std::string& input_path,
111+
torch::Tensor x) {
112+
return forward_model<MobileNetV2>(input_path, x);
113+
}
114+
115+
torch::Tensor forward_googlenet(
116+
const std::string& input_path,
117+
torch::Tensor x) {
118+
GoogLeNet network;
119+
torch::load(network, input_path);
120+
network->eval();
121+
return network->forward(x).output;
122+
}
123+
torch::Tensor forward_inceptionv3(
124+
const std::string& input_path,
125+
torch::Tensor x) {
126+
InceptionV3 network;
127+
torch::load(network, input_path);
128+
network->eval();
129+
return network->forward(x).output;
130+
}
131+
132+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
133+
m.def("forward_alexnet", &forward_alexnet, "forward_alexnet");
134+
135+
m.def("forward_vgg11", &forward_vgg11, "forward_vgg11");
136+
m.def("forward_vgg13", &forward_vgg13, "forward_vgg13");
137+
m.def("forward_vgg16", &forward_vgg16, "forward_vgg16");
138+
m.def("forward_vgg19", &forward_vgg19, "forward_vgg19");
139+
140+
m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn");
141+
m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn");
142+
m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn");
143+
m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn");
144+
145+
m.def("forward_resnet18", &forward_resnet18, "forward_resnet18");
146+
m.def("forward_resnet34", &forward_resnet34, "forward_resnet34");
147+
m.def("forward_resnet50", &forward_resnet50, "forward_resnet50");
148+
m.def("forward_resnet101", &forward_resnet101, "forward_resnet101");
149+
m.def("forward_resnet152", &forward_resnet152, "forward_resnet152");
150+
m.def(
151+
"forward_resnext50_32x4d",
152+
&forward_resnext50_32x4d,
153+
"forward_resnext50_32x4d");
154+
m.def(
155+
"forward_resnext101_32x8d",
156+
&forward_resnext101_32x8d,
157+
"forward_resnext101_32x8d");
158+
159+
m.def(
160+
"forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
161+
m.def(
162+
"forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1");
163+
164+
m.def("forward_densenet121", &forward_densenet121, "forward_densenet121");
165+
m.def("forward_densenet169", &forward_densenet169, "forward_densenet169");
166+
m.def("forward_densenet201", &forward_densenet201, "forward_densenet201");
167+
m.def("forward_densenet161", &forward_densenet161, "forward_densenet161");
168+
169+
m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2");
170+
171+
m.def("forward_googlenet", &forward_googlenet, "forward_googlenet");
172+
m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3");
173+
}

0 commit comments

Comments
 (0)