diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index fa2f58bd95..05b9ece53c 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -62,5 +62,5 @@ jobs: pip install ${{ matrix.torch-spec }} pip install -r requirements.txt pip install -r dev-requirements.txt - pip install . + python setup.py install pytest test --verbose -s diff --git a/dev-requirements.txt b/dev-requirements.txt index 74f75e9093..8a8ed1e491 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,7 +3,8 @@ expecttest parameterized packaging transformers -bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers +bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers matplotlib # needed for triton benchmarking pandas # also for triton benchmarking -transformers #for galore testing \ No newline at end of file +transformers #for galore testing +ninja diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..edf5df1398 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel", "ninja", "torch"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 27c1f260e8..eb5ba9e3be 100644 --- a/setup.py +++ b/setup.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. import os +import glob from datetime import datetime from setuptools import find_packages, setup current_date = datetime.now().strftime("%Y.%m.%d") - def read_requirements(file_path): with open(file_path, "r") as file: return file.read().splitlines() @@ -22,6 +22,60 @@ def read_requirements(file_path): # Version is year.month.date if using nightlies version = current_date if package_name == "torchao-nightly" else "0.1" +import torch + +from torch.utils.cpp_extension import ( + CppExtension, + CUDAExtension, + BuildExtension, + CUDA_HOME, +) + + +def get_extensions(): + debug_mode = os.getenv('DEBUG', '0') == '1' + if debug_mode: + print("Compiling in debug mode") + + # TODO: And cudatoolkit is available + use_cuda = torch.cuda.is_available() and CUDA_HOME is not None + extension = CUDAExtension if use_cuda else CppExtension + + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + ], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + ] + } + if debug_mode: + extra_compile_args["cxx"].append("-g") + extra_compile_args["nvcc"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + + this_dir = os.path.dirname(os.path.curdir) + extensions_dir = os.path.join(this_dir, "torchao", "csrc") + sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) + + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + + if use_cuda: + sources += cuda_sources + + ext_modules = [ + extension( + "torchao._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] + + return ext_modules setup( name=package_name, @@ -31,10 +85,12 @@ def read_requirements(file_path): package_data={ "torchao.kernel.configs": ["*.pkl"], }, + ext_modules=get_extensions(), install_requires=read_requirements("requirements.txt"), extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/pytorch-labs/ao", + cmdclass={"build_ext": BuildExtension}, ) diff --git a/test/test_ops.py b/test/test_ops.py new file mode 100644 index 0000000000..6e84d138ad --- /dev/null +++ b/test/test_ops.py @@ -0,0 +1,46 @@ +import torch +from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.optests import opcheck +import torchao +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +import unittest + + +# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): +# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) +class TestOps(TestCase): + def _create_tensors_with_iou(self, N, iou_thresh): + # force last box to have a pre-defined iou with the first box + # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], + # then, in order to satisfy ops.iou(b0, b1) == iou_thresh, + # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh + # Adjust the threshold upward a bit with the intent of creating + # at least one box that exceeds (barely) the threshold and so + # should be suppressed. + boxes = torch.rand(N, 4) * 100 + boxes[:, 2:] += boxes[:, :2] + boxes[-1, :] = boxes[0, :] + x0, y0, x1, y1 = boxes[-1].tolist() + iou_thresh += 1e-5 + boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh + scores = torch.rand(N) + return boxes, scores + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + def test_nms(self): + iou = 0.2 + boxes, scores = self._create_tensors_with_iou(1000, iou) + boxes = boxes.cuda() + scores = scores.cuda() + + # smoke test + _ = torchao.ops.nms(boxes, scores, iou) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/__init__.py b/torchao/__init__.py index ecd2ccf4b9..d9b73e3583 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -4,6 +4,9 @@ autoquant, ) from . import dtypes +import torch +from . import _C +from . import ops __all__ = [ "dtypes", diff --git a/torchao/csrc/cuda/nms.cu b/torchao/csrc/cuda/nms.cu new file mode 100644 index 0000000000..5bbbff8d79 --- /dev/null +++ b/torchao/csrc/cuda/nms.cu @@ -0,0 +1,181 @@ +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ + for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ + i += (blockDim.x * gridDim.x)) + +#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int) + +template +constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +template +__device__ inline bool devIoU( + T const* const a, + T const* const b, + const float threshold) { + T left = max(a[0], b[0]), right = min(a[2], b[2]); + T top = max(a[1], b[1]), bottom = min(a[3], b[3]); + T width = max(right - left, (T)0), height = max(bottom - top, (T)0); + using acc_T = at::acc_type; + acc_T interS = (acc_T)width * height; + acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); + acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]); + return (interS / (Sa + Sb - interS)) > threshold; +} + +template +__global__ void nms_kernel_impl( + int n_boxes, + double iou_threshold, + const T* dev_boxes, + unsigned long long* dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + if (row_start > col_start) + return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ T block_boxes[threadsPerBlock * 4]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 4 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; + block_boxes[threadIdx.x * 4 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; + block_boxes[threadIdx.x * 4 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; + block_boxes[threadIdx.x * 4 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 4; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 4, iou_threshold)) { + t |= 1ULL << i; + } + } + const int col_blocks = ceil_div(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +at::Tensor nms_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); + TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); + + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)) + + at::cuda::CUDAGuard device_guard(dets.device()); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + + int dets_num = dets.size(0); + + const int col_blocks = ceil_div(dets_num, threadsPerBlock); + + at::Tensor mask = + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dets_sorted.scalar_type(), "nms_kernel", [&] { + nms_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = + (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = + at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + int num_to_keep = 0; + for (int i = 0; i < dets_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + AT_CUDA_CHECK(cudaGetLastError()); + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) + .to(order_t.device(), keep.scalar_type())}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::nms", &nms_kernel); +} + +} // namespace torchao diff --git a/torchao/csrc/init.cpp b/torchao/csrc/init.cpp new file mode 100644 index 0000000000..cb2ec42a45 --- /dev/null +++ b/torchao/csrc/init.cpp @@ -0,0 +1,3 @@ +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/torchao/csrc/nms.cpp b/torchao/csrc/nms.cpp new file mode 100644 index 0000000000..5cc26d1593 --- /dev/null +++ b/torchao/csrc/nms.cpp @@ -0,0 +1,8 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); +} diff --git a/torchao/ops.py b/torchao/ops.py new file mode 100644 index 0000000000..0931d32026 --- /dev/null +++ b/torchao/ops.py @@ -0,0 +1,23 @@ +import torch +from torch import Tensor + +def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: + """ + See https://pytorch.org/vision/main/generated/torchvision.ops.nms.html + """ + return torch.ops.torchao.nms.default(boxes, scores, iou_threshold) + + +# Defines the meta kernel / fake kernel / abstract impl +@torch.library.impl_abstract("torchao::nms") +def _(dets, scores, iou_threshold): + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long)