Changes from all commits (28 commits)
1d7776f
Beginning of work to properly reuse the output given to quantize
ptrendx Apr 8, 2025
d207cea
Add current scaling
ptrendx Apr 9, 2025
fa28e49
Beginning of the other recipes
ptrendx Apr 14, 2025
49ad122
Added MXFP8 and cleanup
ptrendx Apr 26, 2025
17678d9
Fix
ptrendx May 30, 2025
b488101
Actually reuse tensors and get rid of the hack for MXFP8
ptrendx May 30, 2025
209cb9f
Small cleaning
ptrendx May 30, 2025
c61b14b
Make sure dgrad is not needed in the test during eval phase
ptrendx May 31, 2025
7561fb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2025
41b8fb4
Fixes
ptrendx Jun 2, 2025
20363a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2025
2acb07a
Fixes
ptrendx Jun 6, 2025
9550644
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2025
2f9f9ce
Merge branch 'main' into pr_quantize_output_respect_usages
ptrendx Jun 6, 2025
f0f96b9
Fix for integer overflow
ptrendx Jun 6, 2025
eb49987
Try copying the quantizer
ptrendx Jun 11, 2025
6dcd480
Fix
ptrendx Jun 11, 2025
b6f1aeb
Fix CUDA graphs test
ptrendx Jun 11, 2025
b92f3f5
Merge branch 'main' into pr_quantize_output_respect_usages
ptrendx Jun 12, 2025
1f4f894
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
53554f2
Fix
ptrendx Jun 12, 2025
343d43d
Fix the float8blockwise tests and MXFP8 cuda graphs tests
ptrendx Jun 13, 2025
817d8ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2025
b6b4af3
Merge branch 'main' into pr_quantize_output_respect_usages
ptrendx Jun 13, 2025
207e4b7
Fix issue from merge
ptrendx Jun 13, 2025
715cc53
Always use tex.quantize when updating cache to use proper quantizer
ptrendx Jun 13, 2025
d682178
Debug
ptrendx Jun 17, 2025
e6f38d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2025
2 changes: 1 addition & 1 deletion tests/pytorch/test_float8blockwisetensor.py
@@ -206,7 +206,7 @@ def test_quantize_dequantize_dims(
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("block_scaling_dim", [1])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.xfail(raises=NotImplementedError)
def test_quantize_dequantize_compact_format(
109 changes: 100 additions & 9 deletions tests/pytorch/test_sanity.py
@@ -104,7 +104,7 @@ def is_fp8_supported(self):

model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
"small": ModelConfig(2, 16, 2, 128, 1),
"weird": ModelConfig(2, 37, 3, 69, 3),
"large": ModelConfig(1, 128, 2, 512, 4, 128),
}
@@ -398,6 +398,34 @@ def _test_sanity_common(
loss.backward()
torch.cuda.synchronize()

# now try eval with weight caching
block.eval()
te_inp.requires_grad = False

with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp, is_first_microbatch=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp, is_first_microbatch=False)
torch.cuda.synchronize()

# now try regular execution again with weight caching
block.train()
te_inp.requires_grad = True

with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp, is_first_microbatch=True)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
loss.backward()
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp, is_first_microbatch=False)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()


def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
if skip_dgrad and skip_wgrad:
@@ -1124,27 +1152,31 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
print("regular")
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
print("checkpointed")
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# outputs_checkpoint_v1_6 = _run_attention_extra_state(
# dtype, config, mimic_v1_6=True, checkpoint=True
# )

# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
print(i)
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
print(f"Second loop {i}")
torch.testing.assert_close(
test,
ref,
@@ -1173,6 +1205,8 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
requires_grad=True,
)

torch.set_printoptions(threshold=100_000_000)

def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -1191,15 +1225,59 @@ def get_model(dtype, config):
params_dtype=dtype,
device="cuda",
)
# block = torch.nn.Sequential(
# Linear(config.hidden_size,
# config.hidden_size),
# Linear(config.hidden_size,
# config.hidden_size),
# Linear(config.hidden_size,
# config.hidden_size),
# Linear(config.hidden_size,
# config.hidden_size))
# block.to(dtype=dtype)
return block

block = get_model(dtype, config)
print("Before the first loop")
# for n,p in block.named_parameters():
# print(n)
# print(p)
print("data")
print(block.self_attention.proj.weight._data)
print("scale_inv")
print(block.self_attention.proj.weight._scale_inv)
print("transpose")
print(block.self_attention.proj.weight._transpose)

import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb

bbb.DEBUG_BLOCK = block
print("set!")

print("End before the first loop")
print(f"scale inv: {block.self_attention.proj.weight._scale_inv}")
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
print(f"scale inv 0: {block.self_attention.proj.weight._scale_inv}")
output = block(hidden_states)
print(f"scale inv 1: {block.self_attention.proj.weight._scale_inv}")
print(f"output {i}")
print(output)
loss = output.sum()
loss.backward()

loss.backward()
print(f"scale inv 2: {block.self_attention.proj.weight._scale_inv}")

print("Before the checkpoint")
# for n,p in block.named_parameters():
# print(n)
# print(p)
print("data")
print(block.self_attention.proj.weight._data)
print("scale_inv")
print(block.self_attention.proj.weight._scale_inv)
print("transpose")
print(block.self_attention.proj.weight._transpose)
print("End before the checkpoint")
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
@@ -1231,10 +1309,23 @@ def get_model(dtype, config):

for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
output = block(hidden_states)
print(f"after output {i}")
print(output)
loss = output.sum()
loss.backward()

print("After the checkpoint")
# for n,p in block.named_parameters():
# print(n)
# print(p)
print("data")
print(block.self_attention.proj.weight._data)
print("scale_inv")
print(block.self_attention.proj.weight._scale_inv)
print("transpose")
print(block.self_attention.proj.weight._transpose)
print("End after the checkpoint")
torch.cuda.synchronize()

if os.path.exists(path):
transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -57,6 +57,8 @@
AttentionLogging as attn_log,
)

DEBUG_BLOCK = None

# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
flash_attn_func = None
@@ -964,6 +966,8 @@ def forward(
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
print(f"Q quantizer scale: {q_fp8._quantizer.scale.shape}")
print(f"mixed quantizer scale: {qkv_fp8._quantizer.scale.shape}")
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -1190,6 +1194,11 @@ def backward(ctx, d_out):
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
print(DEBUG_BLOCK)
if DEBUG_BLOCK is not None:
print(
f"Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}"
)
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
@@ -1218,6 +1227,11 @@ def backward(ctx, d_out):
ctx.window_size,
ctx.deterministic,
)
if DEBUG_BLOCK is not None:
print(
"After Inside attention:"
f" {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}"
)

# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -445,6 +445,9 @@ def fused_attn_bwd(
len(aux_ctx_tensors) == 3
), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention."

import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb

debug = bbb.DEBUG_BLOCK.self_attention.proj.weight._scale_inv
output_tensors = tex.fused_attn_bwd(
max_seqlen_q,
max_seqlen_kv,
@@ -471,6 +474,7 @@
s_quantizer,
dp_quantizer,
dqkv_quantizer,
debug,
)

return output_tensors
7 changes: 1 addition & 6 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -67,7 +67,7 @@ TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantize
// also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer
for (auto [check_type, check_quantizer_type, create_tensor, _] :
detail::custom_types_converters) {
if (check_type(tensor.ptr())) {
if (check_type(tensor.ptr()) != PythonTensorType::INVALID) {
if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) {
continue;
}
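Side note on the condition change above: check_type now appears to return the matched Python tensor type, with PythonTensorType::INVALID as the "no match" sentinel, rather than a bool. Below is a rough, self-contained C++ sketch of that pattern; only PythonTensorType::INVALID is visible in this diff, so the other enum members and every other name here are assumptions, not the real converter table.

#include <cassert>
#include <functional>
#include <vector>

// Only INVALID appears in the diff; the remaining members are illustrative guesses.
enum class PythonTensorType { INVALID, FLOAT8, MXFP8 };

struct FakePyObject {  // stand-in for the PyObject* handed to check_type
  PythonTensorType kind;
};

int main() {
  // Each converter entry carries a checker that identifies its tensor type, or INVALID.
  std::vector<std::function<PythonTensorType(const FakePyObject&)>> checkers = {
      [](const FakePyObject& o) {
        return o.kind == PythonTensorType::FLOAT8 ? PythonTensorType::FLOAT8
                                                  : PythonTensorType::INVALID;
      },
      [](const FakePyObject& o) {
        return o.kind == PythonTensorType::MXFP8 ? PythonTensorType::MXFP8
                                                 : PythonTensorType::INVALID;
      }};

  FakePyObject obj{PythonTensorType::MXFP8};
  int matches = 0;
  for (const auto& check_type : checkers) {
    if (check_type(obj) != PythonTensorType::INVALID) {  // same shape as the updated condition
      ++matches;
    }
  }
  assert(matches == 1);  // exactly one converter recognizes the object
  return 0;
}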
@@ -286,9 +286,4 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

int roundup(const int value, const int multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple;
}

} // namespace transformer_engine::pytorch
23 changes: 16 additions & 7 deletions transformer_engine/pytorch/csrc/common.h
@@ -98,7 +98,7 @@ class Quantizer {
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

virtual std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
Comment on lines +101 to 102

Collaborator: Somewhat orthogonal, but since we're touching Quantizer::create_tensor, we should consider removing the rowwise_data arg. It was a UB-specific option that doesn't really make sense anymore. I believe all usages have been refactored away.

Member Author: ok, cool. I will do that - it will make the code nicer.

Member Author: Actually can't do that just yet. Attention also uses this unfortunately.

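For context on the signature discussed above, here is a minimal, self-contained sketch (mock types, not the real Transformer Engine classes) of the "optional pre-allocated output" pattern that the new output = py::none() parameter enables. MockQuantizer and MockTensor are illustrative stand-ins only.

// Illustrative sketch only: MockTensor / MockQuantizer are not TE types.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

struct MockTensor {
  std::vector<std::size_t> shape;  // stands in for the quantized tensor / TensorWrapper pair
};

struct MockQuantizer {
  // If `output` is provided, reuse its storage instead of allocating a new tensor,
  // mirroring the intent of the new `output` argument on Quantizer::create_tensor.
  MockTensor create_tensor(const std::vector<std::size_t>& shape,
                           std::optional<MockTensor> output = std::nullopt) const {
    if (output.has_value()) {
      assert(output->shape == shape);  // reuse path: the existing tensor must match
      return *output;
    }
    return MockTensor{shape};  // allocation path: fresh tensor
  }
};

int main() {
  MockQuantizer q;
  MockTensor first = q.create_tensor({128, 64});         // first call allocates
  MockTensor again = q.create_tensor({128, 64}, first);  // second call reuses `first`
  assert(again.shape == first.shape);
  return 0;
}

In the real header the reuse goes through py::object and TensorWrapper rather than a mock struct, and, per the thread above, rowwise_data stays for now because attention still depends on it.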

virtual ~Quantizer() = default;
@@ -121,7 +121,7 @@ class NoneQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override {}

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

@@ -139,7 +139,7 @@ class Float8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

@@ -161,7 +161,7 @@ class Float8CurrentScalingQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

@@ -195,7 +195,7 @@ class Float8BlockQuantizer : public Quantizer {
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

@@ -210,7 +210,7 @@ class MXFP8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
const std::vector<size_t>& shape, DType dtype, const py::object& output = py::none(),
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

@@ -354,7 +354,16 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);

std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);
template <typename T>
T divup(const T value, const T multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple);
}

template <typename T>
T roundup(const T value, const T multiple) {
return divup(value, multiple) * multiple;
}

NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch
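A quick standalone sanity check of the divup/roundup templates added above (a sketch that re-declares them rather than including common.h, so it is not part of the PR):

#include <cassert>
#include <cstddef>

// Same definitions as in common.h above, repeated to keep the sketch self-contained.
template <typename T>
T divup(const T value, const T multiple) {
  assert(multiple > 0);
  return (value + multiple - 1) / multiple;
}

template <typename T>
T roundup(const T value, const T multiple) {
  return divup(value, multiple) * multiple;
}

int main() {
  assert(divup(7, 4) == 2);    // ceiling division: 7/4 rounded up
  assert(roundup(7, 4) == 8);  // smallest multiple of 4 that is >= 7
  // A std::size_t instantiation handles values beyond INT_MAX, which the old
  // `int roundup(const int, const int)` removed from common.cpp could not.
  assert(roundup<std::size_t>(3'000'000'000ULL, 128) == 3'000'000'000ULL);
  return 0;
}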
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -62,7 +62,7 @@ std::vector<py::object> fused_attn_bwd(
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer);
py::handle dp_quantizer, py::handle dqkv_quantizer, at::Tensor debug);

at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);