60 changes: 46 additions & 14 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -14,6 +14,7 @@
from transformer_engine.common.recipe import Recipe
from .base import (
get_multi_stream_cublas_workspace,
get_dummy_wgrad,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
@@ -80,6 +81,7 @@ def forward(
module,
skip_fp8_weight_update,
save_original_input,
fine_grained_activation_offloading,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
@@ -209,6 +211,30 @@ def forward(
if isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)

for i in range(num_gemms):
weights[i].offloading_activation = False
weights_fp8[i].offloading_activation = False
biases[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same"
f" time."
)

if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list

tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
@@ -271,11 +297,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
if (
ctx.cpu_offloading or ctx.fine_grained_activation_offloading
) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w
if not ctx.cpu_offloading:
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
weights[i] = w
weights[i].main_grad = main_grads[i]
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]

# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -426,18 +456,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
zero=True,
)
else:
wgrad = torch.empty(
weight.main_grad.shape,
dtype=weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
@@ -484,6 +511,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
None,
None,
None,
None,
*wgrad_list,
*grad_biases,
)
@@ -565,6 +593,7 @@ def __init__(
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
delay_wgrad_compute: bool = False,
save_original_input: bool = False,
) -> None:
@@ -588,6 +617,8 @@ def __init__(
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name

self.fine_grained_activation_offloading = fine_grained_activation_offloading

self.wgrad_store = WeightGradStore(delay_wgrad_compute)

self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
@@ -806,6 +837,7 @@ def forward(
self,
skip_fp8_weight_update,
self.save_original_input,
self.fine_grained_activation_offloading,
*weight_tensors,
*bias_tensors,
)
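For context, the grouped_linear.py changes above expose the new switch as a GroupedLinear constructor keyword (default False) and reject combining it with CPU offloading. A minimal usage sketch, assuming the usual transformer_engine.pytorch import path; the sizes, splits, and device below are made up for illustration:

```python
# Sketch only: the new keyword and the ValueError on combining it with CPU
# offloading come from the diff above; shapes, splits, and device are assumptions.
import torch
import transformer_engine.pytorch as te

grouped = te.GroupedLinear(
    num_gemms=4,
    in_features=1024,
    out_features=1024,
    fine_grained_activation_offloading=True,  # new kwarg added in this PR
)

inp = torch.randn(512, 1024, device="cuda")
out = grouped(inp, m_splits=[128, 128, 128, 128])  # rows handled by each GEMM
```

The same file also swaps the per-call torch.zeros/torch.empty dummy-wgrad allocations for the shared get_dummy_wgrad helper imported from .base.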
43 changes: 38 additions & 5 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -122,6 +122,7 @@ def forward(
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
@@ -424,10 +425,37 @@ def forward(
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

# Do not offload weights and biases
weight.offloading_activation = False
weightmat.offloading_activation = False
if bias is not None:
bias.offloading_activation = False
ln_weight.offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same"
f" time."
)

if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
if ctx.has_grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes the user
@@ -560,9 +588,11 @@ def backward(

# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# so we need to connect them back into one.
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
if ctx.has_grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad

@@ -1021,6 +1051,7 @@ def wgrad_gemm(
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fsdp_group
None, # debug
None, # module
@@ -1156,6 +1187,7 @@ def __init__(
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: str = None,
fine_grained_activation_offloading: bool = False,
) -> None:
super().__init__()

@@ -1172,7 +1204,7 @@ def __init__(
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type

self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name

@@ -1575,6 +1607,7 @@ def forward(
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fine_grained_activation_offloading,
self.fsdp_group,
self,
skip_fp8_weight_update,
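Both _LayerNormLinear.forward above and _Linear.forward below temporarily set weight.grad_added_to_main_grad to True under fine-grained activation offloading, stash the original value on ctx, and restore it on the preserved weight object in backward. The toy autograd function below is a self-contained sketch of just that bookkeeping; the names, shapes, and matmul are illustrative and not TE's implementation:

```python
# Simplified illustration of the save/restore pattern in the diff above.
import torch

class _ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        if hasattr(weight, "grad_added_to_main_grad"):
            ctx.has_grad_added_to_main_grad = True
            ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
            weight.grad_added_to_main_grad = True   # temporarily overridden
            ctx.weight_object = weight              # keep the Parameter wrapper
        else:
            ctx.has_grad_added_to_main_grad = False
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        if ctx.has_grad_added_to_main_grad:
            # Restore the user-visible flag on the original Parameter object.
            ctx.weight_object.grad_added_to_main_grad = ctx.grad_added_to_main_grad
        return grad_out @ weight, grad_out.t() @ inp

w = torch.nn.Parameter(torch.randn(8, 16))
w.grad_added_to_main_grad = False   # attribute set by Megatron-style optimizers
x = torch.randn(4, 16, requires_grad=True)
_ToyLinear.apply(x, w).sum().backward()
assert w.grad_added_to_main_grad is False  # flag restored after backward
```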
42 changes: 37 additions & 5 deletions transformer_engine/pytorch/module/linear.py
@@ -109,6 +109,7 @@ def forward(
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fine_grained_activation_offloading: bool,
fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
@@ -395,17 +396,43 @@ def forward(
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

ctx.fine_grained_activation_offloading = fine_grained_activation_offloading

if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same"
f" time."
)

if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
if ctx.has_grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes the user
# sets for the weights. Because of this, it is not recommended to offload
# weights if weights are externally touched outside this module
ctx.weight_object = weight

# Do not offload weights and biases
weight.offloading_activation = False
weightmat.offloading_activation = False
if bias is not None:
bias.offloading_activation = False
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
@@ -493,9 +520,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
else None
)

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
if ctx.has_grad_added_to_main_grad:
weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad

@@ -968,6 +997,7 @@ def wgrad_gemm(
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fine_grained_activation_offloading
None, # fp8_output
None, # fsdp_group
None, # module
@@ -1090,6 +1120,7 @@ def __init__(
symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False,
name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
) -> None:
super().__init__()

@@ -1105,7 +1136,7 @@ def __init__(
self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
self.name = name

self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

if device == "meta":
@@ -1452,6 +1483,7 @@ def forward(
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
self.fine_grained_activation_offloading,
fp8_output,
self.fsdp_group,
self,
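linear.py follows the same pattern as layernorm_linear.py: the flag is threaded from the Linear constructor into _Linear.forward, the weight and bias tensors are marked with offloading_activation = False so they are never offloaded, and mixing the flag with CPU offloading raises a ValueError. A usage sketch under the same assumptions as the GroupedLinear example above (shapes and device are illustrative):

```python
# Sketch only: the keyword comes from the diff; shapes and device are assumptions.
import torch
import transformer_engine.pytorch as te

linear = te.Linear(
    in_features=1024,
    out_features=1024,
    fine_grained_activation_offloading=True,  # mutually exclusive with cpu_offloading
)

y = linear(torch.randn(32, 1024, device="cuda"))
```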