@@ -134,8 +134,9 @@ def get_extensions():
134
134
this_dir = os .path .dirname (os .path .abspath (__file__ ))
135
135
extensions_dir = os .path .join (this_dir , 'torchvision' , 'csrc' )
136
136
137
- main_file = glob .glob (os .path .join (extensions_dir , '*.cpp' ))
138
- source_cpu = glob .glob (os .path .join (extensions_dir , 'cpu' , '*.cpp' ))
137
+ main_file = glob .glob (os .path .join (extensions_dir , '*.cpp' )) + glob .glob (os .path .join (extensions_dir , 'ops' ,
138
+ '*.cpp' ))
139
+ source_cpu = glob .glob (os .path .join (extensions_dir , 'ops' , 'cpu' , '*.cpp' ))
139
140
140
141
is_rocm_pytorch = False
141
142
if torch .__version__ >= '1.5' :
@@ -146,17 +147,19 @@ def get_extensions():
146
147
hipify_python .hipify (
147
148
project_directory = this_dir ,
148
149
output_directory = this_dir ,
149
- includes = "torchvision/csrc/cuda/*" ,
150
+ includes = "torchvision/csrc/ops/ cuda/*" ,
150
151
show_detailed = True ,
151
152
is_pytorch_extension = True ,
152
153
)
153
- source_cuda = glob .glob (os .path .join (extensions_dir , 'hip' , '*.hip' ))
154
+ source_cuda = glob .glob (os .path .join (extensions_dir , 'ops' , ' hip' , '*.hip' ))
154
155
# Copy over additional files
155
- for file in glob .glob (r"torchvision/csrc/cuda/*.h" ):
156
- shutil .copy (file , "torchvision/csrc/hip" )
156
+ for file in glob .glob (r"torchvision/csrc/ops/ cuda/*.h" ):
157
+ shutil .copy (file , "torchvision/csrc/ops/ hip" )
157
158
158
159
else :
159
- source_cuda = glob .glob (os .path .join (extensions_dir , 'cuda' , '*.cu' ))
160
+ source_cuda = glob .glob (os .path .join (extensions_dir , 'ops' , 'cuda' , '*.cu' ))
161
+
162
+ source_cuda += glob .glob (os .path .join (extensions_dir , 'ops' , 'autocast' , '*.cpp' ))
160
163
161
164
sources = main_file + source_cpu
162
165
extension = CppExtension
@@ -309,8 +312,8 @@ def get_extensions():
309
312
image_library += [jpeg_lib ]
310
313
image_include += [jpeg_include ]
311
314
312
- image_path = os .path .join (extensions_dir , 'cpu ' , 'image' )
313
- image_src = glob .glob (os .path .join (image_path , '*.cpp' ))
315
+ image_path = os .path .join (extensions_dir , 'io ' , 'image' )
316
+ image_src = glob .glob (os .path .join (image_path , '*.cpp' )) + glob . glob ( os . path . join ( image_path , 'cpu' , '*.cpp' ))
314
317
315
318
if png_found or jpeg_found :
316
319
ext_modules .append (extension (
@@ -377,13 +380,13 @@ def get_extensions():
377
380
print ("ffmpeg library_dir: {}" .format (ffmpeg_library_dir ))
378
381
379
382
# TorchVision base decoder + video reader
380
- video_reader_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'cpu ' , 'video_reader' )
383
+ video_reader_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'io ' , 'video_reader' )
381
384
video_reader_src = glob .glob (os .path .join (video_reader_src_dir , "*.cpp" ))
382
- base_decoder_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'cpu ' , 'decoder' )
385
+ base_decoder_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'io ' , 'decoder' )
383
386
base_decoder_src = glob .glob (
384
387
os .path .join (base_decoder_src_dir , "*.cpp" ))
385
388
# Torchvision video API
386
- videoapi_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'cpu ' , 'video' )
389
+ videoapi_src_dir = os .path .join (this_dir , 'torchvision' , 'csrc' , 'io ' , 'video' )
387
390
videoapi_src = glob .glob (os .path .join (videoapi_src_dir , "*.cpp" ))
388
391
# exclude tests
389
392
base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x ]
0 commit comments