diff --git a/.github/workflows/prototype-transforms-tests-linux-gpu.yml b/.github/workflows/prototype-transforms-tests-linux-gpu.yml new file mode 100644 index 00000000000..e5740886b22 --- /dev/null +++ b/.github/workflows/prototype-transforms-tests-linux-gpu.yml @@ -0,0 +1,70 @@ +name: Prototype transforms unit-tests on Linux GPU + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +env: + CHANNEL: "nightly" + +jobs: + tests: + strategy: + matrix: + python_version: ["3.8"] + cuda_arch_version: ["11.6"] + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.4xlarge.nvidia.gpu + repository: pytorch/vision + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 + script: | + # Mark Build Directory Safe + git config --global --add safe.directory /__w/vision/vision + + # Set up Environment Variables + export PYTHON_VERSION="${{ matrix.python_version }}" + export VERSION="${{ matrix.cuda_arch_version }}" + export CUDATOOLKIT="pytorch-cuda=${VERSION}" + + # Set CHANNEL + if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then + export CHANNEL=test + else + export CHANNEL=nightly + fi + + # Create Conda Env + conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy + conda activate /work/ci_env + + # Install PyTorch, Torchvision, and testing libraries + set -ex + conda install \ + --yes \ + -c "pytorch-${CHANNEL}" \ + -c nvidia \ + pytorch \ + torchdata \ + "${CUDATOOLKIT}" + + python3 -c "import torch; exit(not torch.cuda.is_available())" + + python3 setup.py develop + python3 -m pip install pytest pytest-mock pytest-cov + + # Run Tests + python3 -m torch.utils.collect_env + python3 -m pytest \ + --durations=20 \ + --cov=torchvision/prototype/transforms \ + --cov-report=term-missing \ + test/test_prototype_transforms*.py diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 9d97b6ca701..ded888a4a00 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1940,11 +1940,13 @@ def sample_inputs_adjust_contrast_video(): closeness_kwargs={ **pil_reference_pixel_difference(), **float32_vs_uint8_pixel_difference(2), + **cuda_vs_cpu_pixel_difference(), }, ), KernelInfo( F.adjust_contrast_video, sample_inputs_fn=sample_inputs_adjust_contrast_video, + closeness_kwargs=cuda_vs_cpu_pixel_difference(), ), ] )