Nightly pip wheels incompatible with pytorch-triton workflow #1318

Closed · ptrblck opened this issue Feb 19, 2023 · 6 comments

@ptrblck (Contributor) commented Feb 19, 2023

Description

Based on pytorch/pytorch#94818 (comment), ptxas should be bundled with "triton" (I assume it should ship in the pytorch-triton wheel), but this does not seem to be the case with the latest nightly binary.

Setup info

Collecting environment information...
PyTorch version: 2.0.0.dev20230218+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:35)  [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 525.60.11
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

[removing CPU info as it's not interesting]

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] numpydoc==1.5.0
[pip3] pytorch-triton==2.0.0+c8bfe3f548
[pip3] torch==2.0.0.dev20230218+cu118
[pip3] torchaudio==2.0.0.dev20230218+cu118
[pip3] torchvision==0.15.0.dev20230218+cu118
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] numpydoc                  1.5.0                    pypi_0    pypi
[conda] pytorch-triton            2.0.0+c8bfe3f548          pypi_0    pypi
[conda] torch                     2.0.0.dev20230218+cu118          pypi_0    pypi
[conda] torchaudio                2.0.0.dev20230218+cu118          pypi_0    pypi
[conda] torchvision               0.15.0.dev20230218+cu118          pypi_0    pypi

ptxas in pytorch-triton

The latest nightly pins pytorch-triton==2.0.0+c8bfe3f548, which is correct according to .github/ci_commit_pins/triton.txt.

pytorch-triton searches for the ptxas binary first via TRITON_PTXAS_PATH, if specified, and otherwise falls back to triton/third_party/cuda/bin/ptxas, as seen in: https://github.com/openai/triton/blob/c8bfe3f548b164f745ada620a560f87f41ab8465/python/triton/compiler.py#L1066-L1067.
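
For reference, the lookup order boils down to something like the following (a minimal paraphrase of path_to_ptxas; the function name find_ptxas and the exact structure are illustrative, not the upstream code):

import os

def find_ptxas():
    # 1) Explicit override via the environment variable
    env_path = os.environ.get("TRITON_PTXAS_PATH")
    if env_path is not None:
        return env_path
    # 2) Fall back to the copy bundled inside the triton package
    import triton
    bundled = os.path.join(os.path.dirname(triton.__file__),
                           "third_party", "cuda", "bin", "ptxas")
    if os.path.exists(bundled):
        return bundled
    raise RuntimeError("Cannot find ptxas")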

However, it seems triton/third_party/cuda does not contain the expected bin folder, as seen in:
https://github.com/openai/triton/tree/c8bfe3f548b164f745ada620a560f87f41ab8465/python/triton/third_party/cuda

Example code snippet with failure

Using a simple RN50 with torch.compile:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet50().cuda()
model = torch.compile(model)  # uses TorchInductor, which compiles Triton kernels

x = torch.randn(1, 3, 224, 224).cuda()
out = model(x)  # the first forward pass triggers compilation and thus the ptxas lookup
print(out.shape)

fails with:

  File "/usr/local/lib/python3.8/dist-packages/triton/compiler.py", line 1078, in path_to_ptxas
    raise RuntimeError("Cannot find ptxas")
RuntimeError: Cannot find ptxas

Workaround with TRITON_PTXAS_PATH

Setting TRITON_PTXAS_PATH to a valid ptxas location (from a locally installed CUDA toolkit) still fails, either with this error on bare metal:

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python resnet_compile.py 
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 560, in _worker_compile
    kernel.precompile(warm_cache_only_with_cc=cc)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 69, in precompile
    self.launchers = [
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 70, in <listcomp>
    self._precompile_config(c, warm_cache_only_with_cc)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 83, in _precompile_config
    triton.compile(
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1586, in compile
    so_path = make_stub(name, signature, constants)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1475, in make_stub
    so = _build(name, src_path, tmpdir)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1390, in _build
    ret = subprocess.check_call(cc_cmd)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/subprocess.py", line 364, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp44lexpnp/main.c', '-O3', '-I/usr/local/cuda/include', '-I/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/include/python3.8', '-I/tmp/tmp44lexpnp', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmp44lexpnp/triton_.cpython-38-x86_64-linux-gnu.so']' returned non-zero exit status 1.

or with this error in a Docker container:

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python tmp.py 
LLVM ERROR: Can't find libdevice at neither /usr/local/lib/python3.8/dist-packages/triton/third_party/cuda/lib/libdevice.10.bc nor /tmp/tmp0fe53ca9/triton/python/triton/third_party/cuda/lib/libdevice.10.bc
LLVM ERROR: Can't find libdevice at neither /usr/local/lib/python3.8/dist-packages/triton/third_party/cuda/lib/libdevice.10.bc nor /tmp/tmp0fe53ca9/triton/python/triton/third_party/cuda/lib/libdevice.10.bc

Missing third_party dependency

The second error is strange, as it claims that not even the libdevice.10.bc file can be found; indeed, the entire third_party folder seems to be missing:

ls /usr/local/lib/python3.8/dist-packages/triton/
_C  __init__.py  __pycache__  compiler.py  impl  language  ops  runtime  testing.py  tools  utils.py

To double-check, I downloaded the wheel manually via:

wget https://download.pytorch.org/whl/nightly/pytorch_triton-2.0.0%2Bc8bfe3f548-cp38-cp38-linux_x86_64.whl

and after unzipping it, I also cannot find any ptxas, libdevice*, or third_party files.
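
The same check can be scripted, since a wheel is just a zip archive (a minimal sketch, assuming the wheel downloaded above sits in the current directory):

import zipfile

whl = "pytorch_triton-2.0.0+c8bfe3f548-cp38-cp38-linux_x86_64.whl"
with zipfile.ZipFile(whl) as zf:
    hits = [n for n in zf.namelist()
            if "ptxas" in n or "libdevice" in n or "third_party" in n]
print(hits)  # empty on the broken wheel, confirming the files are absent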

Possible fixes

My best guess right now would be:

  • openai/triton needs to be updated with the bin/ptxas file, as it's missing from third_party/cuda
  • the pytorch-triton wheel needs to be rebuilt, as it's missing the entire third_party folder
  • I don't know how to interpret the /usr/bin/ld: cannot find -lcuda error, as gcc -lcuda tries to use my local CUDA toolkit (should this be the case?)

Side note on the last point: my local CUDA toolkit can properly build PyTorch from source as well as a simple CUDA driver API example with -lcuda, so even if a dependency on a locally installed CUDA toolkit is expected, I'm still unsure why it's failing.

Let me know if I'm missing something.

CC @malfet @atalman @ngimel

@dllehr-amd (Contributor) commented Feb 20, 2023

I can confirm the latest pytorch_triton wheels are being built devoid of any third_party folder. This also means libdevice.10.bc is missing, and unit tests will fail with missing libraries, for example:

python3 test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_cos_cuda_int64
LLVM ERROR: Can't find libdevice at neither /usr/local/lib/python3.8/dist-packages/triton/third_party/cuda/lib/libdevice.10.bc nor /tmp/tmp0fe53ca9/triton/python/triton/third_party/cuda/lib/libdevice.10.bc
E
======================================================================
ERROR: test_comprehensive_cos_cuda_int64 (__main__.TestInductorOpInfoCUDA)
----------------------------------------------------------------------

I did some research into the build-triton-wheel.yml workflow, and it appears setuptools (or something related) is too old in the pytorch/manylinux-builder:cpu image, so when Triton tries to include the third_party directory as package_data, it gets ignored:
https://github.com/openai/triton/blob/main/python/setup.py#L223
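
The pattern in question looks roughly like this (an illustrative sketch of the package_data mechanism, not the exact upstream setup.py):

from setuptools import setup

setup(
    name="triton",
    packages=["triton"],
    # Older setuptools releases (e.g. 49.2.1) can silently drop these
    # data files from the built wheel; newer ones (67.x) include them.
    package_data={"triton": ["third_party/cuda/bin/*", "third_party/cuda/lib/*"]},
)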

For base images, I've tried:

Image                                    | Setuptools version | Used by                  | Includes third_party
pytorch/manylinux-builder:cpu            | 49.2.1             | PyTorch                  | NO
quay.io/pypa/manylinux2014_x86_64:latest | 67.2.0             | Triton PyPI              | YES
rocm/dev-manylinux2014_x86_64:5.4.2      | 67.2.0             | pytorch-triton-rocm PyPI | YES

@weiwangmeta (Contributor) commented:

@dllehr-amd @ptrblck thank you for your great analysis! Based on that, pytorch/pytorch#95265 might fix this issue. I have verified that the triton wheel build now packages those third_party/cuda files. See https://github.com/pytorch/pytorch/actions/runs/4239265379/jobs/7367146618#step:6:1547

@weiwangmeta (Contributor) commented:

Re-opening until @ptrblck confirms this is resolved, or until we have nightly pip wheels that are confirmed to resolve it. Currently, the nightly job is not done; worse, we do not have a Feb 21 nightly...

@weiwangmeta reopened this Feb 22, 2023
atalman pushed a commit to atalman/pytorch that referenced this issue Feb 22, 2023
weiwangmeta added a commit to pytorch/pytorch that referenced this issue Feb 22, 2023
* [BE] Cleanup triton builds (#95026)

Remove Python-3.7 clause
Do not install llvm-11, as llvm-14 is installed by triton/python/setup.py script

Pull Request resolved: #95026
Approved by: https://github.com/osalpekar, https://github.com/weiwangmeta

* Upgrade setuptools before building wheels (#95265)

Should fix pytorch/builder#1318

Pull Request resolved: #95265
Approved by: https://github.com/ngimel

---------

Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: Wei Wang <[email protected]>
@ptrblck (Contributor, Author) commented Feb 24, 2023

The latest nightly fixes the original issue! 🎉
Output:

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-triton==2.0.0+d54c04abe2
[pip3] torch==2.0.0.dev20230224+cu118
[pip3] torchaudio==2.0.0.dev20230223+cu118
[pip3] torchvision==0.15.0.dev20230224+cu118

ls /usr/local/lib/python3.10/site-packages/triton/third_party/cuda/lib/
libdevice.10.bc
ls /usr/local/lib/python3.10/site-packages/triton/third_party/cuda/bin
ptxas

In another setup, I'm still running into the gcc -lcuda compile error:

python resnet_compile.py 
/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:90: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/tmp/tmpua10w7u7/main.c:2:10: fatal error: cuda.h: No such file or directory
    2 | #include "cuda.h"
      |          ^~~~~~~~
...
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp43bp1id7/main.c', '-O3', '-I/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/include', '-I/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/include/python3.8', '-I/tmp/tmp43bp1id7', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmp43bp1id7/triton_.cpython-38-x86_64-linux-gnu.so', '-L/usr/lib/x86_64-linux-gnu']' returned non-zero exit status 1.

This issue seems to be unrelated to the missing ptxas binary, so I'll dig into it more and will create another issue if needed.

@ptrblck closed this as completed Feb 24, 2023
@weiwangmeta (Contributor) commented:

I have encountered this too but thought it was my own setup issue. Looks like something else might be going on... Looking forward to your analysis, thanks!

@weiwangmeta (Contributor) commented:

> (quoting the gcc / cuda.h error from the comment above)

I had been testing with export CUDA_HOME=/usr/local/cuda-11.7, and it turns out this was masking the problem. The above error occurs if CUDA_HOME is not set (or if C_INCLUDE_PATH does not contain cuda.h).
Thanks to @ngimel and locuslab/pytorch_fft#21 (comment), we have identified a new issue: the bundled third_party/cuda/cuda.h is not used, so the triton package errors out if no cuda.h can be found on the include path.
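
A quick way to check for this condition up front (a minimal sketch; treating /usr/local/cuda as the default toolkit location is an assumption):

import os

cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")  # assumed default install path
header = os.path.join(cuda_home, "include", "cuda.h")
if not os.path.exists(header):
    raise RuntimeError(
        f"cuda.h not found at {header}; set CUDA_HOME or C_INCLUDE_PATH "
        "so Triton's stub compilation can find the CUDA driver header"
    )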

cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Mar 5, 2023
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue May 3, 2023
jhavukainen pushed a commit to kulinseth/pytorch that referenced this issue Mar 15, 2024