Skip to content

Commit 536983b

Browse files
committed
feat(disable_tf32): Add a new API to disable TF32
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 1660633 commit 536983b

File tree

13 files changed

+44
-2
lines changed

13 files changed

+44
-2
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace conversion {
1212
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
1313
os << "Settings requested for TensorRT engine:" \
1414
<< "\n Operating Precision: " << s.op_precision \
15+
<< "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
1516
<< "\n Make Refittable Engine: " << s.refit \
1617
<< "\n Debuggable Engine: " << s.debug \
1718
<< "\n Strict Types: " << s.strict_types \
@@ -77,6 +78,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
7778
}
7879
op_precision = settings.op_precision;
7980

81+
if (settings.disable_tf32) {
82+
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
83+
}
84+
8085
if (settings.refit) {
8186
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
8287
}

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Device {
2424

2525
struct BuilderSettings {
2626
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
27+
bool disable_tf32 = false;
2728
bool refit = false;
2829
bool debug = false;
2930
bool strict_types = false;

cpp/api/include/trtorch/trtorch.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,15 @@ struct TRTORCH_API CompileSpec {
239239
*/
240240
DataType op_precision = DataType::kFloat;
241241

242+
/**
243+
* Prevent Float32 layers from using TF32 data format
244+
*
245+
* TF32 computes inner products by rounding the inputs to 10-bit mantissas
246+
* before multiplying, but accumulates the sum using 23-bit mantissas.
247+
* This TF32 rounding mode is the default behavior of FP32 layers.
248+
*/
249+
bool disable_tf32 = false;
250+
242251
/**
243252
* Build a refittable engine
244253
*/

cpp/api/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
8989
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
9090
}
9191

92+
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
9293
internal.convert_info.engine_settings.refit = external.refit;
9394
internal.convert_info.engine_settings.debug = external.debug;
9495
internal.convert_info.engine_settings.strict_types = external.strict_types;

cpp/trtorchc/main.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ int main(int argc, char** argv) {
163163
"(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
164164
{"allow-gpu-fallback"});
165165

166+
args::Flag disable_tf32(
167+
parser,
168+
"disable-tf32",
169+
"Prevent Float32 layers from using the TF32 data format",
170+
{"disable-tf32"});
171+
166172
args::ValueFlag<std::string> op_precision(
167173
parser,
168174
"precision",
@@ -263,6 +269,10 @@ int main(int argc, char** argv) {
263269
compile_settings.device.allow_gpu_fallback = true;
264270
}
265271

272+
if (disable_tf32) {
273+
compile_settings.disable_tf32 = true;
274+
}
275+
266276
std::string calibration_cache_file_path = "";
267277
if (calibration_cache_file) {
268278
calibration_cache_file_path = resolve_path(args::get(calibration_cache_file));

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
135135
if "op_precision" in compile_spec:
136136
info.op_precision = _parse_op_precision(compile_spec["op_precision"])
137137

138+
if "disable_tf32" in compile_spec:
139+
assert isinstance(compile_spec["disable_tf32"], bool)
140+
info.disable_tf32 = compile_spec["disable_tf32"]
141+
138142
if "refit" in compile_spec:
139143
assert isinstance(compile_spec["refit"], bool)
140144
info.refit = compile_spec["refit"]
@@ -201,6 +205,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
201205
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
202206
},
203207
"op_precision": torch.half, # Operating precision set to FP16
208+
"disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
204209
"refit": False, # enable refit
205210
"debug": False, # enable debuggable engine
206211
"strict_types": False, # kernels should strictly run in operating precision
@@ -239,6 +244,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
239244

240245
backend_spec.set_device(d)
241246
backend_spec.set_op_precision(int(parsed_spec.op_precision))
247+
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
242248
backend_spec.set_refit(parsed_spec.refit)
243249
backend_spec.set_debug(parsed_spec.debug)
244250
backend_spec.set_refit(parsed_spec.refit)

py/trtorch/_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
9999
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
100100
},
101101
"op_precision": torch.half, # Operating precision set to FP16
102+
"disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
102103
"refit": false, # enable refit
103104
"debug": false, # enable debuggable engine
104105
"strict_types": false, # kernels should strictly run in operating precision

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void RegisterTRTCompileSpec() {
3232
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
3333

3434
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
35+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
3536
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
3637
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
3738
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
9999
}
100100
auto info = core::CompileSpec(internal_input_ranges);
101101
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
102+
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
102103
info.convert_info.engine_settings.refit = refit;
103104
info.convert_info.engine_settings.debug = debug;
104105
info.convert_info.engine_settings.strict_types = strict_types;
@@ -128,6 +129,7 @@ std::string CompileSpec::stringify() {
128129
}
129130
ss << " ]" << std::endl;
130131
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
132+
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
131133
ss << " \"Refit\": " << refit << std::endl;
132134
ss << " \"Debug\": " << debug << std::endl;
133135
ss << " \"Strict Types\": " << strict_types << std::endl;

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct CompileSpec : torch::CustomClassHolder {
9999
}
100100

101101
ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
102+
ADD_FIELD_GET_SET(disable_tf32, bool);
102103
ADD_FIELD_GET_SET(refit, bool);
103104
ADD_FIELD_GET_SET(debug, bool);
104105
ADD_FIELD_GET_SET(strict_types, bool);
@@ -111,6 +112,7 @@ struct CompileSpec : torch::CustomClassHolder {
111112

112113
std::vector<InputRange> input_ranges;
113114
DataType op_precision = DataType::kFloat;
115+
bool disable_tf32 = false;
114116
bool refit = false;
115117
bool debug = false;
116118
bool strict_types = false;

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ PYBIND11_MODULE(_C, m) {
103103
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
104104
.def_readwrite("op_precision", &CompileSpec::op_precision)
105105
.def_readwrite("refit", &CompileSpec::refit)
106+
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
106107
.def_readwrite("debug", &CompileSpec::debug)
107108
.def_readwrite("strict_types", &CompileSpec::strict_types)
108109
.def_readwrite("device", &CompileSpec::device)

tests/py/test_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_compile_traced(self):
2020
"device_type": trtorch.DeviceType.GPU,
2121
"gpu_id": 0,
2222
"dla_core": 0,
23-
"allow_gpu_fallback": False
23+
"allow_gpu_fallback": False,
24+
"disable_tf32": False
2425
}
2526
}
2627

@@ -35,7 +36,8 @@ def test_compile_script(self):
3536
"device_type": trtorch.DeviceType.GPU,
3637
"gpu_id": 0,
3738
"dla_core": 0,
38-
"allow_gpu_fallback": False
39+
"allow_gpu_fallback": False,
40+
"disable_tf32": False
3941
}
4042
}
4143

tests/py/test_to_backend_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def setUp(self):
2929
"num_min_timing_iters": 2,
3030
"num_avg_timing_iters": 1,
3131
"max_batch_size": 0,
32+
"disable_tf32": False,
3233
})
3334
}
3435

0 commit comments

Comments
 (0)