diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 71159eb2b5..c9a76602c2 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -124,13 +124,13 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision"); cfg->setDLACore(settings.device.dla_core); - if (settings.dla_sram_size != 1048576) { + if (settings.dla_sram_size != DLA_SRAM_SIZE) { cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_MANAGED_SRAM, settings.dla_sram_size); } - if (settings.dla_local_dram_size != 1073741824) { + if (settings.dla_local_dram_size != DLA_LOCAL_DRAM_SIZE) { cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_LOCAL_DRAM, settings.dla_local_dram_size); } - if (settings.dla_global_dram_size != 536870912) { + if (settings.dla_global_dram_size != DLA_GLOBAL_DRAM_SIZE) { cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_GLOBAL_DRAM, settings.dla_global_dram_size); } } diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index 988a58bf49..5f8d6e955b 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -35,9 +35,9 @@ struct BuilderSettings { nvinfer1::IInt8Calibrator* calibrator = nullptr; uint64_t num_avg_timing_iters = 1; uint64_t workspace_size = 0; - uint64_t dla_sram_size = 1048576; - uint64_t dla_local_dram_size = 1073741824; - uint64_t dla_global_dram_size = 536870912; + uint64_t dla_sram_size = DLA_SRAM_SIZE; + uint64_t dla_local_dram_size = DLA_LOCAL_DRAM_SIZE; + uint64_t dla_global_dram_size = DLA_GLOBAL_DRAM_SIZE; BuilderSettings() = default; BuilderSettings(const BuilderSettings& other) = default; diff --git a/core/util/macros.h b/core/util/macros.h index e58b5f0daf..c05822c975 100644 --- a/core/util/macros.h +++ b/core/util/macros.h @@ -4,6 +4,11 @@ #define GET_MACRO(_1, _2, NAME, ...) NAME +// DLA Memory related macros +#define DLA_SRAM_SIZE 1048576 +#define DLA_LOCAL_DRAM_SIZE 1073741824 +#define DLA_GLOBAL_DRAM_SIZE 536870912 + #define TORCHTRT_LOG(l, sev, msg) \ do { \ std::stringstream ss{}; \