2 changes: 1 addition & 1 deletion benchmarks/linear/benchmark_grouped_linear.py
@@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
num_gemms_list = [8]

if args.profile:
mkns = [(4096, 4096, 4096)]
mkns = [(4096 * 8, 4096, 4096)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
1 change: 1 addition & 0 deletions transformer_engine/common/common.cu
@@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_num_bits) {
cuda_driver::ensure_context_exists();
// Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
14 changes: 14 additions & 0 deletions transformer_engine/common/include/transformer_engine/swizzle.h
@@ -30,6 +30,20 @@ extern "C" {
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Swizzling scaling factors of multiple tensors into the required interleaved layout for GEMM
 *
 * \param[in] inputs Input tensors with non-swizzled scale_inv.
 * \param[in,out] outputs Output tensors which host the swizzled scale_inv.
 * \param[in] num_tensors Number of input/output tensors.
 * \param[in] stream CUDA stream used for the operation.
 *
 * Requirements:
 * - scale_inv is stored in row-major order.
 * - The scale_inv shape is padded to a multiple of 128x4 for rowwise scaling and 4x128 for columnwise scaling.
 * - Data is quantized along the K-dimension, i.e. each 1D scaling block lies along the K-dimension.
 */
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif
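An illustrative call sketch for this new API (added editorially, not part of the diff): it assumes the caller already owns NVTETensor handles that satisfy the requirements listed above, and the wrapper function name is hypothetical; tensor construction is omitted.

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>
#include <vector>

// Hypothetical helper: swizzle the scaling factors of a whole group of tensors at once.
// Assumes `inputs` hold non-swizzled scale_inv and `outputs` provide destination buffers.
void swizzle_all_scaling_factors(const std::vector<NVTETensor>& inputs,
                                 std::vector<NVTETensor>& outputs, cudaStream_t stream) {
  // One call launches the swizzle for every tensor in the group on `stream`.
  nvte_multi_tensor_swizzle_scaling_factors(inputs.data(), outputs.data(), inputs.size(), stream);
}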
439 changes: 396 additions & 43 deletions transformer_engine/common/swizzle/swizzle.cu

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/common/util/padding.cu
@@ -35,6 +35,7 @@ struct MultiPaddingArgs {
int padded_num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
int row_length_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
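A small worked illustration of that prefix-sum layout (hypothetical block counts, not taken from the diff): with three tensors needing 3, 5, and 2 CUDA blocks, block_range becomes {0, 3, 8, 10}, so tensor i is processed by global block indices in [block_range[i], block_range[i+1]).

// Illustration only: how block_range is laid out for hypothetical per-tensor block counts.
int blocks_needed[3] = {3, 5, 2};  // CUDA blocks required by each tensor
int block_range[4] = {0};          // prefix sum with a leading zero
for (int i = 0; i < 3; ++i) block_range[i + 1] = block_range[i] + blocks_needed[i];
// block_range == {0, 3, 8, 10}; tensor i owns block IDs [block_range[i], block_range[i+1]).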
10 changes: 2 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -398,11 +398,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}

// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto buffer = std::make_shared<at::Tensor>(
at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
@@ -441,11 +438,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}

// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto buffer = std::make_shared<at::Tensor>(
at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
19 changes: 10 additions & 9 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -326,10 +326,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
std::vector<at::Tensor> D_vectors;
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

auto none = py::none();

@@ -396,10 +394,6 @@
continue;
}

// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));

auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
@@ -419,18 +413,25 @@
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());

wrappers.emplace_back(std::move(te_A));
wrappers.emplace_back(std::move(te_B));
te_A_wrappers.emplace_back(std::move(te_A));
te_B_wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out));
}

// Optionally swizzle the scaling factors
// Keep the swizzled scaling factor tensors alive during the GEMMs.
auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);

for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
}

// For now, we only have multi-stream cublas backend.
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/csrc/quantizer.cpp
@@ -841,13 +841,13 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}
if (columnwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
columnwise_scale_inv_shape.end());
columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts);
columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts);
}

// Convert tensors to Python
@@ -939,7 +939,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts);
rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
@@ -966,7 +966,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts);
columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
95 changes: 95 additions & 0 deletions transformer_engine/pytorch/csrc/util.cpp
@@ -75,3 +75,98 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap

return swizzled_scale_inv;
}

std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
using namespace transformer_engine::pytorch;

if (tensors.empty()) {
return std::nullopt;
}

bool all_same_scaling_mode = std::all_of(
tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
return val.scaling_mode() == tensors.front().scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");

if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return std::nullopt;
}

std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors;

// Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
std::vector<std::vector<size_t>> scale_inv_shapes;
std::vector<void*> scale_inv_dptrs;
size_t buffer_size = 0;
std::vector<size_t> scale_inv_offsets;
constexpr size_t scale_elem_size = 1;
for (auto& tensor : tensors) {
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = tensor.get_rowwise_scale_inv();
} else {
scale_inv = tensor.get_columnwise_scale_inv();
}
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_inv_offsets.push_back(buffer_size);
buffer_size += product(scale_inv_shape) * scale_elem_size;
scale_inv_shapes.emplace_back(scale_inv_shape);
scale_inv_dptrs.push_back(scale_inv.data_ptr);
}

// Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));

for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
auto input_shape = nvte_shape_to_vector(tensor.shape());

    // Reconstruct the input so that only the required direction is swizzled.
    // Any 8-bit dtype works here; the exact type is irrelevant.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
} else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shapes[i]);
output_cu.set_columnwise_data(tensor.columnwise_dptr(),
transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
}

input_tensors.emplace_back(input_cu.data());
output_tensors.emplace_back(output_cu.data());
wrappers.emplace_back(std::move(input_cu));
wrappers.emplace_back(std::move(output_cu));
}

// Launch kernel
nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
input_tensors.size(), at::cuda::getCurrentCUDAStream());

return buffer;
}
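To make the packing arithmetic above concrete (editorial sketch with hypothetical sizes; `pack_offsets` and `roundup_bytes` are stand-ins for the diff's inline loop and `roundup` helper): scale_inv slices of 100, 512, and 40 bytes land at offsets 0, 112, and 624, for a 664-byte shared buffer.

#include <cstddef>
#include <vector>

// 16-byte alignment, mirroring the `roundup(buffer_size, 16)` step above.
static size_t roundup_bytes(size_t x, size_t align) { return (x + align - 1) / align * align; }

// Returns the per-tensor byte offsets into one shared buffer and writes its total size.
std::vector<size_t> pack_offsets(const std::vector<size_t>& sizes, size_t* total) {
  std::vector<size_t> offsets;
  size_t buffer_size = 0;
  for (size_t s : sizes) {
    buffer_size = roundup_bytes(buffer_size, 16);  // align each slice to 16 bytes
    offsets.push_back(buffer_size);
    buffer_size += s;
  }
  *total = buffer_size;
  return offsets;  // e.g. {100, 512, 40} -> offsets {0, 112, 624}, total 664
}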
11 changes: 9 additions & 2 deletions transformer_engine/pytorch/csrc/util.h
@@ -13,11 +13,18 @@

#include "transformer_engine/transformer_engine.h"

/* Swizzle the scaling factor of the input tensor.
/*! \brief Swizzle the scaling factor of the input tensor.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans);
bool rowwise);

/*! \brief Swizzle the scaling factor of the input tensors.
*
* The returned swizzled scaling factor tensors should be kept alive during the GEMMs.
*/
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);

#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
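A condensed usage sketch of the keep-alive contract documented above, mirroring the te_general_grouped_gemm change in this diff (illustrative only; the GEMM call and the includes that util.h already pulls in are elided).

#include <optional>
#include <vector>

// Sketch: the returned at::Tensor owns the swizzled scale_inv storage, so it must
// stay in scope until the grouped GEMM that reads the scaling factors has been launched.
void grouped_gemm_like(std::vector<transformer_engine::TensorWrapper>& te_A_wrappers, bool transa) {
  std::optional<at::Tensor> swizzled_scale_inv_A =
      multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
  // ... enqueue nvte_multi_stream_cublas_gemm(...) using the updated wrappers here ...
}  // `swizzled_scale_inv_A` is released only after the GEMMs above have been enqueued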