
BF16 stochastic rounding does not work with distributed training (FSDP) #2296

Open
@nathan-az

Description


The error arose while using torchtune; the config and stacktrace are below. Testing was done on a feature branch, but with no non-standard features enabled. Training runs fine with nproc=1, but breaks as soon as FSDP or TP is enabled.
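
For anyone who wants to reproduce this without torchtune, here is a minimal standalone sketch that should hit the same code path (untested; assumes the FSDP2 `fully_shard` API and torchao's `AdamW8bit`, launched via torchrun):

```python
# Hypothetical minimal repro, not a verified script: FSDP2-sharded bf16 params
# become DTensors, and torchao's AdamW8bit step with bf16_stochastic_round=True
# should then fail inside _fp32_to_bf16_sr.
# Launch with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # FSDP2 API (recent PyTorch)
from torchao.optim import AdamW8bit

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
fully_shard(model)  # parameters are now DTensors with Shard(0) placement

opt = AdamW8bit(model.parameters(), lr=4e-5, bf16_stochastic_round=True)

out = model(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
opt.step()  # expected to raise the mixed torch.Tensor / DTensor RuntimeError
```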

torchtune config

```yaml
batch_size: 1
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: models/llama_3_1_8b
clip_grad_norm: null
compile: false
custom_sharded_layers: []
data_parallel_shard_dim: 2
data_parallel_replicate_dim: 1
tensor_parallel_dim: 1
enable_fp8_training: false
enable_loss_parallel: true
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: true
  train_on_input: true
  split: train[:10%]
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
activation_offloading_use_streams: false
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 10
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
  experiment_name: llama_3_1_8b_debug
  run_name: debug_stochastic_round_dp1
optimizer:
  _component_: torchao.optim.AdamW8bit
  bf16_stochastic_round: true
  lr: 4.0e-05
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 8192
  path: models/llama_3_1_8b/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
```

Stacktrace

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 1118, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 1113, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/workspace/torchtune/recipes/full_finetune_distributed.py", line 982, in train
[rank0]:     self._optimizer.step()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/optim/optimizer.py", line 506, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchao/optim/adam.py", line 134, in step
[rank0]:     torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1469, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 625, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1092, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 779, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 818, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 265, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 743, in transform
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3484, in run
[rank0]:     super().run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1359, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1263, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 831, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2903, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2897, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1189, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
[rank0]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 516, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 291, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1206, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3678, in inline_call
[rank0]:     return tracer.inline_call_()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3881, in inline_call_
[rank0]:     self.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1359, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1263, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2161, in COMPARE_OP
[rank0]:     self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 1168, in call_function
[rank0]:     return handler(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 1128, in _handle_insert_op_in_graph
[rank0]:     return wrap_fx_proxy(tx, proxy)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2500, in wrap_fx_proxy
[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2566, in wrap_fx_proxy_cls
[rank0]:     return _wrap_fx_proxy(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2664, in _wrap_fx_proxy
[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3347, in get_fake_value
[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3245, in get_fake_value
[rank0]:     ret_val = wrap_fake_exception(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2745, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:            ^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3246, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3454, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3413, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 51, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 157, in dispatch
[rank0]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 356, in unwrap_to_op_info
[rank0]:     self._try_replicate_spec_for_scalar_tensor(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 458, in _try_replicate_spec_for_scalar_tensor
[rank0]:     raise RuntimeError(
[rank0]: torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function lt>(*(FakeTensor(..., device='cuda:0', size=(128256, 4096), dtype=torch.int32), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(64128, 4096), dtype=torch.int32), device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('dp_shard',)), placements=(Shard(dim=0),))), **{}): got RuntimeError('aten.lt.Tensor: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!')

[rank0]: from user code:
[rank0]:    File "/opt/conda/lib/python3.11/site-packages/torchao/optim/adam.py", line 199, in single_param_adam
[rank0]:     p.copy_(_fp32_to_bf16_sr(p_f32))
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchao/optim/quant_utils.py", line 138, in _fp32_to_bf16_sr
[rank0]:     rand_16bit < x_fraction,  # this is True with the probability of p_fraction
```
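
From the last two frames, the parameter `p` is a DTensor under FSDP, while the tensor it is compared against in `_fp32_to_bf16_sr` materializes as a plain `torch.Tensor`: note the full (128256, 4096) shape on the left-hand side of the `lt` versus the `Shard(dim=0)` DTensor on the right. Presumably `rand_16bit` is created from the DTensor's (global) shape via a plain factory function, so it comes out as a regular tensor. A sketch of that mixing in isolation, with shapes shrunk for brevity (this is my reading of the traceback, not torchao's actual code):

```python
# Hypothetical demonstration of the failure mode only; the issue hits it with
# (128256, 4096) int32 tensors.
# Launch with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Stands in for the fraction bits of the FSDP-sharded parameter (a DTensor).
x_fraction = distribute_tensor(
    torch.randint(0, 1 << 16, (8, 4), dtype=torch.int32, device="cuda"),
    mesh,
    [Shard(0)],
)
# Stands in for the random 16-bit draws, created as a plain tensor.
rand_16bit = torch.randint(0, 1 << 16, (8, 4), dtype=torch.int32, device="cuda")

# aten.lt.Tensor with one plain tensor and one DTensor raises:
# "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor
#  to DTensor before calling distributed operators!"
rand_16bit < x_fraction
```

If that reading is right, a fix presumably needs `_fp32_to_bf16_sr` to either generate the random bits as a DTensor on the same mesh or do the rounding on the local shard (e.g. via `to_local()`); I haven't verified either approach.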
