Skip to content

Commit f9e1f2b

Browse files
peri044 authored and narendasan committed
feat: Enable sparsity support in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2f23d6e commit f9e1f2b

File tree

11 files changed

+32
-8
lines changed

11 files changed

+32
-8
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
8989
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
9090
}
9191

92+
if (settings.sparse_weights) {
93+
cfg->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
94+
}
95+
9296
if (settings.refit) {
9397
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
9498
}

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct Device {
2525

2626
struct BuilderSettings {
2727
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
28-
std::vector<nvinfer1::DataType> input_dtypes;
28+
bool sparse_weights = false;
2929
bool disable_tf32 = false;
3030
bool refit = false;
3131
bool debug = false;

cpp/api/include/trtorch/trtorch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ struct TRTORCH_API CompileSpec {
687687
*/
688688
bool disable_tf32 = false;
689689

690+
/**
691+
* Enable sparsity for weights of conv and FC layers
692+
*/
693+
bool sparse_weights = false;
694+
690695
/**
691696
* Build a refitable engine
692697
*/

cpp/api/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
364364
}
365365
}
366366

367+
internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
367368
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
368369
internal.convert_info.engine_settings.refit = external.refit;
369370
internal.convert_info.engine_settings.debug = external.debug;

cpp/trtorchexec/main.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ int main(int argc, const char* argv[]) {
5757

5858
auto compile_spec = trtorch::CompileSpec(dims);
5959
compile_spec.workspace_size = 1 << 24;
60+
compile_spec.sparse_weights = true;
6061

61-
std::cout << "Checking operator support" << std::endl;
62-
if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
63-
std::cerr << "Method is not currently supported by TRTorch" << std::endl;
64-
return -1;
65-
}
62+
// std::cout << "Checking operator support" << std::endl;
63+
// if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
64+
// std::cerr << "Method is not currently supported by TRTorch" << std::endl;
65+
// return -1;
66+
// }
6667

6768
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
6869
auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
199199
if "calibrator" in compile_spec:
200200
info.ptq_calibrator = compile_spec["calibrator"]
201201

202+
if "sparse_weights" in compile_spec:
203+
assert isinstance(compile_spec["sparse_weights"], bool)
204+
info.sparse_weights = compile_spec["sparse_weights"]
205+
202206
if "disable_tf32" in compile_spec:
203207
assert isinstance(compile_spec["disable_tf32"], bool)
204208
info.disable_tf32 = compile_spec["disable_tf32"]
@@ -282,8 +286,8 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
282286
"dla_core": 0, # (DLA only) Target dla core id to run engine
283287
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
284288
},
285-
"op_precision": torch.half, # Operating precision set to FP16
286-
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
289+
"enabled_precisions": {torch.half}, # Operating precision set to FP16
290+
"sparse_weights": Enable sparsity for convolution and fully connected layers.
287291
"disable_tf32": False, # Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior, which rounds the inputs to 10-bit mantissas before multiplying but accumulates the sum using 23-bit mantissas
288292
"refit": False, # enable refit
289293
"debug": False, # enable debuggable engine

py/trtorch/_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
4444
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
4545
},
4646
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
47+
"disable_tf32": False, # Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior, which rounds the inputs to 10-bit mantissas before multiplying but accumulates the sum using 23-bit mantissas
48+
"sparse_weights": False, # Enable sparsity for convolution and fully connected layers
4749
"refit": false, # enable refit
4850
"debug": false, # enable debuggable engine
4951
"strict_types": false, # kernels should strictly run in operating precision
@@ -113,6 +115,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
113115
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
114116
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
115117
"disable_tf32": False, # Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior, which rounds the inputs to 10-bit mantissas before multiplying but accumulates the sum using 23-bit mantissas
118+
"sparse_weights": False, # Enable sparsity for convolution and fully connected layers
116119
"refit": false, # enable refit
117120
"debug": false, # enable debuggable engine
118121
"strict_types": false, # kernels should strictly run in operating precision

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ void RegisterTRTCompileSpec() {
4848
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
4949
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
5050

51+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
5152
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
5253
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
5354
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
183183
}
184184

185185
info.convert_info.engine_settings.calibrator = ptq_calibrator;
186+
info.convert_info.engine_settings.sparse_weights = sparse_weights;
186187
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
187188
info.convert_info.engine_settings.refit = refit;
188189
info.convert_info.engine_settings.debug = debug;
@@ -222,6 +223,7 @@ std::string CompileSpec::stringify() {
222223
}
223224
ss << " ]" << std::endl;
224225
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
226+
ss << " \"Sparsity\": " << sparse_weights << std::endl;
225227
ss << " \"Refit\": " << refit << std::endl;
226228
ss << " \"Debug\": " << debug << std::endl;
227229
ss << " \"Strict Types\": " << strict_types << std::endl;

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ struct CompileSpec : torch::CustomClassHolder {
139139
}
140140

141141
ADD_FIELD_GET_SET(disable_tf32, bool);
142+
ADD_FIELD_GET_SET(sparse_weights, bool);
142143
ADD_FIELD_GET_SET(refit, bool);
143144
ADD_FIELD_GET_SET(debug, bool);
144145
ADD_FIELD_GET_SET(strict_types, bool);
@@ -155,6 +156,7 @@ struct CompileSpec : torch::CustomClassHolder {
155156
std::vector<Input> inputs;
156157
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
157158
std::set<DataType> enabled_precisions = {DataType::kFloat};
159+
bool sparse_weights = false;
158160
bool disable_tf32 = false;
159161
bool refit = false;
160162
bool debug = false;

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ PYBIND11_MODULE(_C, m) {
255255
.def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions)
256256
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
257257
.def_readwrite("refit", &CompileSpec::refit)
258+
.def_readwrite("sparse_weights", &CompileSpec::sparse_weights)
258259
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
259260
.def_readwrite("debug", &CompileSpec::debug)
260261
.def_readwrite("strict_types", &CompileSpec::strict_types)

0 commit comments

Comments
 (0)