diff --git a/.github/scripts/validate_test_ops.sh b/.github/scripts/validate_test_ops.sh index d8031c071..00bca8d79 100644 --- a/.github/scripts/validate_test_ops.sh +++ b/.github/scripts/validate_test_ops.sh @@ -6,12 +6,19 @@ retry () { $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) } +BRANCH = "@main" +if [[ ${MATRIX_CHANNEL} == "test" ]] + SHORT_VERSION=${MATRIX_STABLE_VERSION%.*} + BRANCH="@release/${SHORT_VERSION}" +fi + + # Clone the Pytorch branch -retry git clone --depth 1 https://github.com/pytorch/pytorch.git +retry git clone --depth 1 https://github.com/pytorch/pytorch.git${BRANCH} retry git submodule update --init --recursive pushd pytorch -pip install expecttest pyyaml jinja2 +pip install expecttest pyyaml jinja2 packaging # Run test_ops validation export CUDA_LAUNCH_BLOCKING=1