Skip to content

Commit d63c7fa

Browse files
allow for bf16 training
1 parent bf1383e commit d63c7fa

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

torchao/float8/benchmarking/float8_training_benchmark.sh

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# with the given parameters,
44

55
# script arguments
6-
RECIPE=${RECIPE:-"tensorwise"}
76
BATCH_SIZE=${BATCH_SIZE:-1}
87
STEPS=${STEPS:-100}
98

@@ -15,26 +14,32 @@ if [ -z "${TORCHTITAN_ROOT}" ]; then
1514
echo "Error: TORCHTITAN environment variable is not set. Please set it before running this script."
1615
echo "Usage: TORCHTITAN_ROOT=<directory> ./float8_training_benchmark.sh"
1716
echo "Optional parameters configurable via environment variables:"
18-
echo " * RECIPE: rowwise|tensorwise. defaults to tensorwise."
17+
echo " * FLOAT8_RECIPE: "rowwise" or "tensorwise". if set, use float8 training with the specified recipe. otherwise, use bf16 mixed precision training."
1918
echo " * BATCH_SIZE: defaults to 1."
2019
echo " * STEPS: defaults to 100."
2120
exit 1
2221
fi
2322

2423
# validate recipe name
25-
if [ "$RECIPE" != "rowwise" ] && [ "$RECIPE" != "tensorwise" ]; then
26-
echo "Error: RECIPE must be 'rowwise' or 'tensorwise'"
27-
exit 1
24+
if [ -n "${FLOAT8_RECIPE}" ]; then
25+
if [ "$FLOAT8_RECIPE" != "rowwise" ] && [ "$FLOAT8_RECIPE" != "tensorwise" ]; then
26+
echo "Error: RECIPE must be 'rowwise' or 'tensorwise'"
27+
exit 1
28+
fi
29+
FLOAT8_ARGS="--model.converters="float8" --float8.recipe_name=${FLOAT8_RECIPE}"
2830
fi
2931

32+
3033
# remember current directory to return to it later
3134
original_dir=$(pwd)
3235

3336
# navigate to torchtitan root dir
3437
cd ${TORCHTITAN_ROOT}
3538

39+
echo "float8 args: ${FLOAT8_ARGS}"
40+
3641
# run the command with the specified arguments
37-
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=$BATCH_SIZE --training.compile --model.converters="float8" --float8.recipe_name=$RECIPE 2>&1 | tee ${LOG_FILE}
42+
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} 2>&1 | tee ${LOG_FILE}
3843

3944
# return to original working directory
4045
cd $original_dir

0 commit comments

Comments
 (0)