Skip to content

Commit dc3ac29

Browse files
authored
Add C++ ops to torchvision (#826)
* Initial layout for layers with cpp extensions * Move files around * Fix import after move * Add support for multiple types to ROIAlign * Different organization CUDA extensions work now * Cleanups * Reduce memory requirements for backwards * Replace runtime_error by AT_ERROR * Add nms test * Add support for compilation using CPP extensions * Change folder structure * Add ROIPool cuda * Cleanups * Add roi_pool.py * Fix lint * Add initial structures folder for bounding boxes * Assertion macros compatible with pytorch master (#540) * Support for ROI Pooling (#592) * ROI Pooling with tests. Fix for cuda context in ROI Align. * renamed bottom and top to follow torch conventions * remove .type().tensor() calls in favor of the new approach to tensor initialization (#626) * Consistent naming for rois variable (#627) * remove .type().tensor() calls in favor of the new approach to tensor initialization * Consistent naming for rois variable in ROIPool * ROIPool: Support for all datatypes (#632) * Use of torch7 naming scheme for ROIAlign forward and backward * use common cuda helpers in ROIAlign * use .options() in favor of .type() where applicable * Added tests for forward pass of ROIAlign, as well as more consistent naming scheme for CPU vs CUDA * working ROIAlign cuda backwards pass * working ROIAlign backwards pass for CPU * added relevant headers for ROIAlign backwards * tests for ROIAlign layer * replace .type() with .options() for tensor initialization in ROIAlign layers * support for Half types in ROIAlign * gradcheck tests for ROIAlign * updated ROIPool on CPU to work with all datatypes * updated and cleaned tests for ROI Pooling * Fix rebase problem * Remove structures folder * Improve cleanup and bugfix in test_layers * Update C++ headers * Add CUDAGuard to cu files * Add more checks to layers * Add CUDA NMS and tests * Add multi-type support for NMS CUDA * Avoid using THCudaMalloc * Add clang-format and reformat c++ code * Remove THC includes * Rename layers to ops * Add documentation and rename functions * Improve the documentation a bit * Fix some lint errors * Fix remaining lint inssues * Area computation doesn't add +1 in NMS * Update CI to use PyTorch nightly * Make NMS return indices sorted according to the score * Address reviewer comments * Lint fixes * Improve doc for roi_align and roi_pool * move to xenial * Fix bug pointed by @lopuhin * Fix RoIPool reference implementation in Python 2 Also fixes a bug in the clip_boxes_to_image -- this function needs a test! * Remove change in .travis
1 parent 0564df4 commit dc3ac29

23 files changed

+2806
-0
lines changed

.clang-format

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
---
2+
AccessModifierOffset: -1
3+
AlignAfterOpenBracket: AlwaysBreak
4+
AlignConsecutiveAssignments: false
5+
AlignConsecutiveDeclarations: false
6+
AlignEscapedNewlinesLeft: true
7+
AlignOperands: false
8+
AlignTrailingComments: false
9+
AllowAllParametersOfDeclarationOnNextLine: false
10+
AllowShortBlocksOnASingleLine: false
11+
AllowShortCaseLabelsOnASingleLine: false
12+
AllowShortFunctionsOnASingleLine: Empty
13+
AllowShortIfStatementsOnASingleLine: false
14+
AllowShortLoopsOnASingleLine: false
15+
AlwaysBreakAfterReturnType: None
16+
AlwaysBreakBeforeMultilineStrings: true
17+
AlwaysBreakTemplateDeclarations: true
18+
BinPackArguments: false
19+
BinPackParameters: false
20+
BraceWrapping:
21+
AfterClass: false
22+
AfterControlStatement: false
23+
AfterEnum: false
24+
AfterFunction: false
25+
AfterNamespace: false
26+
AfterObjCDeclaration: false
27+
AfterStruct: false
28+
AfterUnion: false
29+
BeforeCatch: false
30+
BeforeElse: false
31+
IndentBraces: false
32+
BreakBeforeBinaryOperators: None
33+
BreakBeforeBraces: Attach
34+
BreakBeforeTernaryOperators: true
35+
BreakConstructorInitializersBeforeComma: false
36+
BreakAfterJavaFieldAnnotations: false
37+
BreakStringLiterals: false
38+
ColumnLimit: 80
39+
CommentPragmas: '^ IWYU pragma:'
40+
#CompactNamespaces: false
41+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
42+
ConstructorInitializerIndentWidth: 4
43+
ContinuationIndentWidth: 4
44+
Cpp11BracedListStyle: true
45+
DerivePointerAlignment: false
46+
DisableFormat: false
47+
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
48+
IncludeCategories:
49+
- Regex: '^<.*\.h(pp)?>'
50+
Priority: 1
51+
- Regex: '^<.*'
52+
Priority: 2
53+
- Regex: '.*'
54+
Priority: 3
55+
IndentCaseLabels: true
56+
IndentWidth: 2
57+
IndentWrappedFunctionNames: false
58+
KeepEmptyLinesAtTheStartOfBlocks: false
59+
MacroBlockBegin: ''
60+
MacroBlockEnd: ''
61+
MaxEmptyLinesToKeep: 1
62+
NamespaceIndentation: None
63+
ObjCBlockIndentWidth: 2
64+
ObjCSpaceAfterProperty: false
65+
ObjCSpaceBeforeProtocolList: false
66+
PenaltyBreakBeforeFirstCallParameter: 1
67+
PenaltyBreakComment: 300
68+
PenaltyBreakFirstLessLess: 120
69+
PenaltyBreakString: 1000
70+
PenaltyExcessCharacter: 1000000
71+
PenaltyReturnTypeOnItsOwnLine: 2000000
72+
PointerAlignment: Left
73+
ReflowComments: true
74+
SortIncludes: true
75+
SpaceAfterCStyleCast: false
76+
SpaceBeforeAssignmentOperators: true
77+
SpaceBeforeParens: ControlStatements
78+
SpaceInEmptyParentheses: false
79+
SpacesBeforeTrailingComments: 1
80+
SpacesInAngles: false
81+
SpacesInContainerLiterals: true
82+
SpacesInCStyleCastParentheses: false
83+
SpacesInParentheses: false
84+
SpacesInSquareBrackets: false
85+
Standard: Cpp11
86+
TabWidth: 8
87+
UseTab: Never
88+
...

.gitignore

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,19 @@ dist/
33
torchvision.egg-info/
44
torchvision/version.py
55
*/**/__pycache__
6+
*/__pycache__
7+
*/*.pyc
68
*/**/*.pyc
9+
*/**/**/*.pyc
710
*/**/*~
811
*~
912
docs/build
1013
.coverage
1114
htmlcov
1215
.*.swp
16+
*.so*
17+
*.dylib*
18+
*/*.so*
19+
*/*.dylib*
20+
*.swp
21+
*.swo

setup.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
from setuptools import setup, find_packages
77
from pkg_resources import get_distribution, DistributionNotFound
88
import subprocess
9+
import distutils.command.clean
10+
import glob
11+
import shutil
12+
13+
import torch
14+
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
915

1016

1117
def read(*names, **kwargs):
@@ -69,6 +75,55 @@ def write_version_file():
6975
requirements.append(pillow_req + pillow_ver)
7076

7177

78+
def get_extensions():
79+
this_dir = os.path.dirname(os.path.abspath(__file__))
80+
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
81+
82+
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
83+
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
84+
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
85+
86+
sources = main_file + source_cpu
87+
extension = CppExtension
88+
89+
define_macros = []
90+
91+
if torch.cuda.is_available() and CUDA_HOME is not None:
92+
extension = CUDAExtension
93+
sources += source_cuda
94+
define_macros += [('WITH_CUDA', None)]
95+
96+
sources = [os.path.join(extensions_dir, s) for s in sources]
97+
98+
include_dirs = [extensions_dir]
99+
100+
ext_modules = [
101+
extension(
102+
'torchvision._C',
103+
sources,
104+
include_dirs=include_dirs,
105+
define_macros=define_macros,
106+
)
107+
]
108+
109+
return ext_modules
110+
111+
112+
class clean(distutils.command.clean.clean):
113+
def run(self):
114+
with open('.gitignore', 'r') as f:
115+
ignores = f.read()
116+
for wildcard in filter(None, ignores.split('\n')):
117+
for filename in glob.glob(wildcard):
118+
try:
119+
os.remove(filename)
120+
except OSError:
121+
shutil.rmtree(filename, ignore_errors=True)
122+
123+
# It's an old-style class in Python 2.7...
124+
distutils.command.clean.clean.run(self)
125+
126+
72127
setup(
73128
# Metadata
74129
name=package_name,
@@ -88,4 +143,6 @@ def write_version_file():
88143
extras_require={
89144
"scipy": ["scipy"],
90145
},
146+
ext_modules=get_extensions(),
147+
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean}
91148
)

0 commit comments

Comments
 (0)