diff --git a/fastfold/habana/fastnn/__init__.py b/fastfold/habana/fastnn/__init__.py index 8993c9d8..5ab4330b 100644 --- a/fastfold/habana/fastnn/__init__.py +++ b/fastfold/habana/fastnn/__init__.py @@ -12,6 +12,7 @@ from .ops import Linear, OutProductMean from .triangle import PairStack +import habana_frameworks.torch.core as htcore class Evoformer(nn.Module): @@ -90,7 +91,6 @@ def forward( m = m[:, :-padding_size, :] z = z[:-padding_size, :-padding_size, :] - import habana_frameworks.torch.core as htcore htcore.mark_step() return m, z @@ -220,7 +220,6 @@ def forward( s = self.linear(m[..., 0, :, :]) - import habana_frameworks.torch.core as htcore htcore.mark_step() return m, z, s @@ -254,7 +253,6 @@ def forward( _mask_trans: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: - import habana_frameworks.torch.core as htcore htcore.mark_step() dap_size = dist.get_world_size()
diff --git a/fastfold/habana/fastnn/custom_op/README.md b/fastfold/habana/fastnn/custom_op/README.md new file mode 100644 index 00000000..4e283d7e --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/README.md @@ -0,0 +1,49 @@ +# CustomOp API Usage in PyTorch + +This README provides an example of how to write custom PyTorch Ops using a TPC Kernel supported on an HPU device. For more details, refer to the [PyTorch CustomOp API](https://docs.habana.ai/en/latest/PyTorch/PyTorch_CustomOp_API/page_index.html) documentation. + +For further information on training deep learning models using Gaudi, refer to [developer.habana.ai](https://developer.habana.ai/resources/). + +## Table of Contents + +* [Model-References](../../../README.md) +* [Prerequisites](#prerequisites) +* [Content](#content) +* [Build and Run with Custom Kernels](#build-and-run-with-custom-kernels) +* [Important to Know](#important-to-know) +* [Applying CustomOps to a Real Training Model Example](#applying-customops-to-a-real-training-model-example) +* [Known Issues](#known-issues) + + +## Prerequisites + +- A TPC kernel on which the HpuKernel will run. To write a CustomOp, you must first define the TPC kernel that the HpuKernel will run on. This document provides the required steps for using the custom TPC kernels `fusedsoftmax_fwd_f32` and `fusedsoftmax_bias_fwd_f32`, as well as the existing default kernel `softmax_bwd_f32` for the backward pass, to implement the `custom_op::fusedsoftmax` and `custom_op::fusedsoftmax_bias` CustomOps. For further information on how to write TPC kernels, refer to the [Habana Custom Kernel GitHub page](https://github.com/HabanaAI/Habana_Custom_Kernel). + +- The **habana-torch-plugin** Python package must be installed. Install it by following the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html). + +## Content + +- C++ file with the **custom_op::fusedsoftmax** and **custom_op::fusedsoftmax_bias** definitions and their kernel implementation on HPU: + - `fusedsoftmax` performs a fused softmax on input and mask. + - `fusedsoftmax_bias` performs a fused softmax on input, mask, and bias. +- `setup.py` file for building the solution: + - To compile the Op for Gaudi, run ```python setup.py build```. + - To compile the Op for Gaudi2, run ```python setup2.py build```. + +- Python test to run and validate `fusedsoftmax` and `fusedsoftmax_bias`: + - ```python hpu_fusedsoftmax_test.py``` + +## Build and Run with Custom Kernels + +To build `fusedsoftmax` and `fusedsoftmax_bias`, run the following (use `setup2.py` on Gaudi2): +```python setup.py build``` + +## Important to Know + +This is an example of an Op implementing both the forward and backward passes.
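For instance, once the extension has been built, the ops can be exercised from Python roughly as follows (a minimal sketch modelled on `hpu_fusedsoftmax_test.py`; the tensor shapes are illustrative only):

```python
# Minimal usage sketch -- assumes `python setup.py build` (or setup2.py on Gaudi2)
# has already produced the hpu_fusedsoftmax extension.
import torch
import habana_frameworks.torch.core  # brings up the HPU backend
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias

logits = torch.randn(1, 512, 4, 512, 512).to('hpu')   # attention logits
mask   = torch.randn(1, 512, 1, 1, 512).to('hpu')     # broadcastable mask term
bias   = torch.randn(1, 1, 4, 512, 512).to('hpu')     # broadcastable pair bias

out  = fused_softmax(logits, mask, -1)                 # softmax(logits + mask) on HPU
out2 = fused_softmax_bias(logits, mask, bias, -1)      # softmax(logits + mask + bias)
```

On shapes the fused kernel does not cover, the Python wrappers simply fall back to adding the mask (and bias) to the input and calling `torch.softmax`.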
+The forward and backward CustomOp is used for training the model by extending the [torch.autograd](https://pytorch.org/docs/stable/notes/extending.html) package. + +## Known Issues + +BF16 or HMP is not supported yet. To use CustomOp in topology, run FP32 variant only. + diff --git a/fastfold/habana/fastnn/custom_op/__init__.py b/fastfold/habana/fastnn/custom_op/__init__.py new file mode 100644 index 00000000..ec5c81f2 --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/__init__.py @@ -0,0 +1,8 @@ +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company +############################################################################### + +from .fusedsoftmax import fused_softmax, fused_softmax_bias + +__all__ = [fused_softmax, fused_softmax_bias] + diff --git a/fastfold/habana/fastnn/custom_op/fusedsoftmax.py b/fastfold/habana/fastnn/custom_op/fusedsoftmax.py new file mode 100644 index 00000000..8f2166e1 --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/fusedsoftmax.py @@ -0,0 +1,81 @@ +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company +############################################################################### + +import torch +import os +import habana_frameworks.torch.core + +custom_fusedsoftmax_op_lib_path = "./build/lib.linux-x86_64-3.8/hpu_fusedsoftmax.cpython-38-x86_64-linux-gnu.so" +my_dir = os.path.realpath(__file__) +my_len = my_dir.rfind('/') +base_dir = my_dir[:my_len] +torch.ops.load_library(os.path.join(base_dir, custom_fusedsoftmax_op_lib_path)) + +class FusedSoftmaxFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mask, dim): + # ctx is a context object that can be used to stash information + # for backward computation + tensor = torch.ops.custom_op.fusedsoftmax(input, mask, dim) + ctx.y = tensor + ctx.dim = dim + return tensor + + @staticmethod + def backward(ctx, grad_output): + if grad_output is None: + return None + y = ctx.y + ctx.y = None + dim = ctx.dim + ctx.dim = None + grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim) + return grad_input, None, None + +class FusedSoftmaxBiasFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mask, bias, dim): + # ctx is a context object that can be used to stash information + # for backward computation + tensor = torch.ops.custom_op.fusedsoftmax_bias(input, mask, bias, dim) + ctx.y = tensor + ctx.dim = dim + ctx.use_bias = False + if bias is not None: + ctx.use_bias = True + return tensor + + @staticmethod + def backward(ctx, grad_output): + if grad_output is None: + return None + y = ctx.y + ctx.y = None + dim = ctx.dim + ctx.dim = None + grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim) + + grad_bias = None + if ctx.use_bias: + grad_bias = torch.sum(grad_input, dim=-4, keepdim=True) + + return grad_input, None, grad_bias, None + + +ENABLE_OPT = True + +def fused_softmax(input, mask, dim): + if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape: + return FusedSoftmaxFunction.apply(input, mask, dim) + else: + input += mask + return torch.softmax(input, dim=dim) + +def fused_softmax_bias(input, mask, bias, dim): + if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape and input[..., :1, :, :, :].shape == bias.shape: + return FusedSoftmaxBiasFunction.apply(input, mask, bias, dim) + else: + input += mask + input += bias + return 
torch.softmax(input, dim=dim) diff --git a/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax.cpp b/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax.cpp new file mode 100644 index 00000000..7a947d2e --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax.cpp @@ -0,0 +1,241 @@ +/****************************************************************************** +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company +############################################################################### +*******************************************************************************/ + +#include "hpu_custom_op.h" +#include +#include + +struct SoftMaxParam +{ + int32_t axis; + bool with_bias; +}; + +bool register_fusedsoftmax() { + // Registering custom_op::fusedsoftmax + // inputs desc + habana::custom_op::InputDesc input_a_desc{ + habana::custom_op::input_type::TENSOR, 0}; + habana::custom_op::InputDesc input_b_desc{ + habana::custom_op::input_type::TENSOR, 1}; + habana::custom_op::InputDesc input_d_desc{ + habana::custom_op::input_type::USER_PARAMS, 2}; + std::vector inputs_desc{ + input_a_desc, input_b_desc, input_d_desc}; + + // output desc + // output shape callback + auto output_size_lambda = + [](const at::Stack& inputs) -> std::vector { + auto self = inputs[0].toTensor(); // input + std::vector result_sizes = self.sizes().vec(); + return result_sizes; + }; + + habana::custom_op::OutputDesc output_desc{ + 0, c10::ScalarType::Float, output_size_lambda}; + + std::vector outputs_desc{ + output_desc}; + + // user param callback + auto user_params_lambda = [](const at::Stack& inputs, size_t& size) { + HPU_PARAMS_STUB(SoftMaxParam); + params->with_bias = false; + int dim = inputs[2].toInt(); + if (dim > 0) + params->axis = inputs[0].toTensor().dim() - dim - 1; + else + params->axis = - dim - 1; + + return params; + }; + + // actual register + REGISTER_CUSTOM_OP_ATTRIBUTES( + "custom_op::fusedsoftmax", //schema name +#ifdef GAUDI2 + "fusedsoftmax_fwd_f32_gaudi2", // guid +#else + "fusedsoftmax_fwd_f32", // guid +#endif + inputs_desc, + outputs_desc, + user_params_lambda); + std::cout << "cpp registered custom_op::fusedsoftmax\n"; + return true; +} + +bool register_fusedsoftmax_bias() { + // Registering custom_op::fusedsoftmax + // inputs desc + habana::custom_op::InputDesc input_a_desc{ + habana::custom_op::input_type::TENSOR, 0}; + habana::custom_op::InputDesc input_b_desc{ + habana::custom_op::input_type::TENSOR, 1}; + habana::custom_op::InputDesc input_c_desc{ + habana::custom_op::input_type::TENSOR, 2}; + habana::custom_op::InputDesc input_d_desc{ + habana::custom_op::input_type::USER_PARAMS, 3}; + std::vector inputs_desc{ + input_a_desc, input_b_desc, input_c_desc, input_d_desc}; + + // output desc + // output shape callback + auto output_size_lambda = + [](const at::Stack& inputs) -> std::vector { + auto self = inputs[0].toTensor(); // input + std::vector result_sizes = self.sizes().vec(); + return result_sizes; + }; + + habana::custom_op::OutputDesc output_desc{ + 0, c10::ScalarType::Float, output_size_lambda}; + + std::vector outputs_desc{ + output_desc}; + + // user param callback + auto user_params_lambda = [](const at::Stack& inputs, size_t& size) { + HPU_PARAMS_STUB(SoftMaxParam); + params->with_bias = true; + int dim = inputs[3].toInt(); + if (dim > 0) + params->axis = inputs[0].toTensor().dim() - dim - 1; + else + params->axis = - dim - 1; + + return params; + }; + + // actual register + 
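  // Note: REGISTER_CUSTOM_OP_ATTRIBUTES binds the torch schema name below to the
  // TPC kernel GUID together with the input/output descriptors and the user-params
  // callback. The GUID is chosen at compile time: building with setup2.py defines
  // GAUDI2 and selects the *_gaudi2 kernel variant; otherwise the Gaudi kernel is used.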
REGISTER_CUSTOM_OP_ATTRIBUTES( + "custom_op::fusedsoftmax_bias", //schema name +#ifdef GAUDI2 + "fusedsoftmax_bias_fwd_f32_gaudi2", // guid +#else + "fusedsoftmax_bias_fwd_f32", // guid +#endif + inputs_desc, + outputs_desc, + user_params_lambda); + std::cout << "cpp registered custom_op::fusedsoftmax_bias\n"; + return true; +} + +bool register_custom_fusedsoftmax_backward() { + // inputs desc + habana::custom_op::InputDesc y_desc{ + habana::custom_op::input_type::TENSOR, 0}; + habana::custom_op::InputDesc grad_desc{ + habana::custom_op::input_type::TENSOR, 1}; + habana::custom_op::InputDesc dim_desc{ + habana::custom_op::input_type::USER_PARAMS, 2}; + + std::vector inputs_desc{ + y_desc, grad_desc, dim_desc}; + + auto output_input_size_lambda = + [](const at::Stack& inputs) -> std::vector { + auto self = inputs[0].toTensor(); // input + std::vector result_sizes = self.sizes().vec(); + return result_sizes; + }; + + habana::custom_op::OutputDesc input_grad_desc{ + 0, c10::ScalarType::Float, output_input_size_lambda}; + + std::vector outputs_desc{ + input_grad_desc}; + + // user param callback + auto user_params_lambda = [](const at::Stack& inputs, size_t& size) { + HPU_PARAMS_STUB(ns_Softmax::Params); + params->dim = 0; + return params; + }; + + // actual register + REGISTER_CUSTOM_OP_ATTRIBUTES( + "custom_op::fusedsoftmax_backward", //schema name +#ifdef GAUDI2 + "softmax_bwd_f32", // guid +#else + "softmax_bwd_f32", // guid +#endif + inputs_desc, + outputs_desc, + user_params_lambda); + std::cout << "cpp registered custom_op::fusedsoftmax_backward\n"; + return true; +} + +at::Tensor fusedsoftmax_execute( + torch::Tensor input, + torch::Tensor mask, + at::Scalar dim) { + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor"); + // Registering the custom op, need to be called only once + static bool registered = register_fusedsoftmax(); + TORCH_CHECK(registered, "fusedsoftmax kernel not registered" ); + std::vector inputs{input, mask, dim}; + // Get custom op descriptor from registry + auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax"); + // Actual call for op execution + std::vector output = op_desc.execute(inputs); + // op_desc.execute will always return a vector + return output[0]; +} + +at::Tensor fusedsoftmax_bias_execute( + torch::Tensor input, + torch::Tensor mask, + torch::Tensor bias, + at::Scalar dim) { + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor"); + // Registering the custom op, need to be called only once + static bool registered = register_fusedsoftmax_bias(); + TORCH_CHECK(registered, "fusedsoftmax_bias kernel not registered" ); + std::vector inputs{input, mask, bias, dim}; + // Get custom op descriptor from registry + auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_bias"); + // Actual call for op execution + std::vector output = op_desc.execute(inputs); + // op_desc.execute will always return a vector + return output[0]; +} + +at::Tensor fusedsoftmax_backward_execute( + torch::Tensor y, + torch::Tensor grad, + at::Scalar dim) { + TORCH_CHECK(y.scalar_type() == c10::ScalarType::Float, "Input y expected to be Float tensor"); + TORCH_CHECK(grad.scalar_type() == c10::ScalarType::Float, "Input grad expected to be Float tensor"); + + // Registering the custom op, need to be called only once + static bool registered = register_custom_fusedsoftmax_backward(); + 
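  // The function-local static above makes registration lazy and one-time: the
  // register_* call runs only on the first invocation (thread-safe per C++11 static
  // initialization); later calls skip re-registration and just look the op up in the
  // registry. The same pattern is used in fusedsoftmax_execute and fusedsoftmax_bias_execute.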
TORCH_CHECK(registered, "custom_fusedsoftmax_backward kernel not registered" ); + std::vector inputs{y, grad, dim}; + // Get custom op descriptor from registry + auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_backward"); + // Actual call for op execution + std::vector output = op_desc.execute(inputs); + // op_desc.execute will always return a vector + return output[0]; +} + +TORCH_LIBRARY(custom_op, m) { + m.def("fusedsoftmax(Tensor self, Tensor mask, Scalar dim) -> Tensor"); + m.def("fusedsoftmax_bias(Tensor self, Tensor mask, Tensor bias, Scalar dim) -> Tensor"); + m.def("fusedsoftmax_backward(Tensor y, Tensor grad, Scalar dim) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(custom_op, HPU, m) { + m.impl("fusedsoftmax", fusedsoftmax_execute); + m.impl("fusedsoftmax_bias", fusedsoftmax_bias_execute); + m.impl("fusedsoftmax_backward", fusedsoftmax_backward_execute); +} + diff --git a/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax_test.py b/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax_test.py new file mode 100644 index 00000000..f13d50a7 --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/hpu_fusedsoftmax_test.py @@ -0,0 +1,113 @@ +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company +############################################################################### + +import torch +from fusedsoftmax import fused_softmax, fused_softmax_bias + +def test_fusedsoftmax_op_function(): + print(torch.ops.custom_op.fusedsoftmax) + print(torch.ops.custom_op.fusedsoftmax_bias) + + # print(torch.ops.custom_op.custom_relu_backward) + input = torch.randn(1, 512, 4, 512, 512) + mask = torch.randn(1, 512, 1, 1, 512) + bias = torch.randn(1, 1, 4, 512, 512) + dim = -1 + + input_hpu = input.to('hpu') + mask_hpu = mask.to('hpu') + + out = input + mask + output_cpu = torch.softmax(out, dim=dim) + + output_hpu = fused_softmax(input_hpu, mask_hpu, dim) + + assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all()) + print("fused_softmax test passed") + + input_hpu = input.to('hpu') + mask_hpu = mask.to('hpu') + bias_hpu = bias.to('hpu') + out = input + mask + out += bias + output_cpu = torch.softmax(out, dim=dim) + + output_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim); + + assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all()) + print("fused_softmax_bias test passed") + +test_fusedsoftmax_op_function() + + +def test_fusedsoftmax_bias_op_backward_function(): + print("fused_softmax_bias_backward") + input = torch.randn(1, 512, 4, 512, 512, requires_grad=True) + mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False) + bias = torch.randn(1, 1, 4, 512, 512, requires_grad=True) + dim = -1 + + # cpu reference + add_mask_cpu = input + mask + add_mask_cpu += bias + softmax_cpu = torch.softmax(add_mask_cpu, dim=dim) + + input_hpu = input.to('hpu').detach() + input_hpu.requires_grad = True + mask_hpu = mask.to('hpu').detach() + mask_hpu.requires_grad = False + bias_hpu = bias.to('hpu').detach() + bias_hpu.requires_grad = True + softmax_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim) + + assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all()) + + grad_cpu = torch.ones_like(softmax_cpu) + softmax_cpu.backward(grad_cpu) + grad_hpu = grad_cpu.to('hpu') + softmax_hpu.backward(grad_hpu) + + input_bwd_cpu = input.grad + input_bwd_hpu = input_hpu.grad + assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 
1e-6).all()) + bias_bwd_cpu = bias.grad + bias_bwd_hpu = bias_hpu.grad + assert((abs(bias_bwd_hpu.detach().cpu() - bias_bwd_cpu.detach()) < 1e-6).all()) + + print("fused_softmax_bias_backward test passed") + + +test_fusedsoftmax_bias_op_backward_function() + +def test_fusedsoftmax_op_backward_function(): + print(torch.ops.custom_op.fusedsoftmax_backward) + input = torch.randn(1, 512, 4, 512, 512, requires_grad=True) + mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False) + dim = -1 + + # cpu reference + add_mask_cpu = input + mask + softmax_cpu = torch.softmax(add_mask_cpu, dim=dim) + + input_hpu = input.to('hpu').detach() + input_hpu.requires_grad = True + mask_hpu = mask.to('hpu').detach() + mask_hpu.requires_grad = False + softmax_hpu = fused_softmax(input_hpu, mask_hpu, dim) + + assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all()) + + grad_cpu = torch.ones_like(softmax_cpu) + softmax_cpu.backward(grad_cpu) + grad_hpu = grad_cpu.to('hpu') + softmax_hpu.backward(grad_hpu) + + input_bwd_cpu = input.grad + input_bwd_hpu = input_hpu.grad + assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 1e-6).all()) + + print("fused_softmax_backward test passed") + + +test_fusedsoftmax_op_backward_function() diff --git a/fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so b/fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so new file mode 100755 index 00000000..6315ba7d Binary files /dev/null and b/fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so differ diff --git a/fastfold/habana/fastnn/custom_op/setup.py b/fastfold/habana/fastnn/custom_op/setup.py new file mode 100644 index 00000000..7acb0a20 --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/setup.py @@ -0,0 +1,26 @@ +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company +############################################################################### + +from setuptools import setup +from torch.utils import cpp_extension +from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir +import os +import pybind11 + +torch_include_dir = get_include_dir() +torch_lib_dir = get_lib_dir() +habana_modules_directory = "/usr/include/habanalabs" +pybind_include_path = pybind11.get_include() + +setup(name='hpu_fusedsoftmax', + ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'], + language='c++', extra_compile_args=["-std=c++17"], + libraries=['habana_pytorch_plugin'], + library_dirs=[torch_lib_dir])], + include_dirs=[torch_include_dir, + habana_modules_directory, + pybind_include_path, + ], + cmdclass={'build_ext': cpp_extension.BuildExtension}) + diff --git a/fastfold/habana/fastnn/custom_op/setup2.py b/fastfold/habana/fastnn/custom_op/setup2.py new file mode 100644 index 00000000..1d3858f8 --- /dev/null +++ b/fastfold/habana/fastnn/custom_op/setup2.py @@ -0,0 +1,26 @@ +############################################################################### +# Copyright (C) 2020-2021 Habana Labs, Ltd. 
an Intel Company +############################################################################### + +from setuptools import setup +from torch.utils import cpp_extension +from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir +import os +import pybind11 + +torch_include_dir = get_include_dir() +torch_lib_dir = get_lib_dir() +habana_modules_directory = "/usr/include/habanalabs" +pybind_include_path = pybind11.get_include() + +setup(name='hpu_fusedsoftmax', + ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'], + language='c++', extra_compile_args=["-std=c++17"], define_macros=[("GAUDI2", None)], + libraries=['habana_pytorch_plugin'], + library_dirs=[torch_lib_dir])], + include_dirs=[torch_include_dir, + habana_modules_directory, + pybind_include_path, + ], + cmdclass={'build_ext': cpp_extension.BuildExtension}) + diff --git a/fastfold/habana/fastnn/ops.py b/fastfold/habana/fastnn/ops.py index 1faa0c00..284c1e35 100755 --- a/fastfold/habana/fastnn/ops.py +++ b/fastfold/habana/fastnn/ops.py @@ -9,6 +9,9 @@ from .initializer import glorot_uniform_af from .kernel import bias_sigmod_ele +from fastfold.habana.distributed import gather, scatter +from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias + CHUNK_SIZE = None DEBUG = False @@ -103,9 +106,17 @@ def forward(self, M, M_mask, Z_raw): for ax in range(0, para_dim, chunk_size): left_act_part = left_act[:, :, ax:ax + chunk_size, :] - O = torch.einsum('sid,sje->ijde', left_act_part.squeeze(0), right_act_all.squeeze(0)) + # O = torch.einsum('sid,sje->ijde', left_act_part.squeeze(0), right_act_all.squeeze(0)) - O = rearrange(O, 'i j d e -> i j (d e)') + # O = rearrange(O, 'i j d e -> i j (d e)') + left_shape = left_act_part.shape + right_shape = right_act_all.shape + left_act_part = left_act_part.reshape(left_shape[0], left_shape[1], left_shape[2]*left_shape[3]) + right_act_all = right_act_all.reshape(right_shape[0], right_shape[1], right_shape[2]*right_shape[3]) + # O = torch.einsum('...ab,...ad->...bd', left_act_part.squeeze(0), right_act_all.squeeze(0)) + O = torch.matmul(left_act_part.squeeze(0).transpose(1, 0), right_act_all.squeeze(0)) + O = O.reshape(left_shape[2], left_shape[3], right_shape[2], right_shape[3]).transpose(-2, -3) + O = O.reshape(O.shape[0], O.shape[1], O.shape[2]*O.shape[3]) O = O.unsqueeze(0) @@ -164,10 +175,10 @@ def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=Fals self.scaling = self.c**(-0.5) - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False) + # self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + # self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) + # self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) if gating: self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) @@ -196,25 +207,31 @@ def forward(self, in_data, mask, nonbatched_bias=None): in_data_part = in_data[:, ax:ax + chunk_size, :, :] mask_part = mask[:, ax:ax + chunk_size, :] - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), 
qkv) + qkv = self.to_qkv(in_data_part).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - q = self.to_q(in_data_part) - k = self.to_k(in_data_part) - v = self.to_v(in_data_part) + # q = self.to_q(in_data_part) + # k = self.to_k(in_data_part) + # v = self.to_v(in_data_part) - q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - [q, k, v]) + # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), + # [q, k, v]) q = q * self.scaling logits = torch.matmul(q, k.transpose(-1, -2)) - logits += (1e9 * (mask_part - 1))[..., :, None, None, :] + # logits += (1e9 * (mask_part - 1))[..., :, None, None, :] + # if nonbatched_bias is not None: + # logits += nonbatched_bias.unsqueeze(1) + # weights = torch.softmax(logits, dim=-1) + + mask00 = (1e9 * (mask_part - 1))[..., :, None, None, :] if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) + weights = fused_softmax_bias(logits, mask00, nonbatched_bias.unsqueeze(1), -1) + else: + weights = fused_softmax(logits, mask00, -1) weighted_avg = torch.matmul(weights, v) weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') diff --git a/fastfold/habana/fastnn/triangle.py b/fastfold/habana/fastnn/triangle.py index 0d138fdc..5360835e 100644 --- a/fastfold/habana/fastnn/triangle.py +++ b/fastfold/habana/fastnn/triangle.py @@ -51,13 +51,13 @@ def forward(self, Z_raw, Z_mask): right_proj_act = gather(right_proj_act.contiguous(), dim=1) g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) + p = torch.matmul( + permute_final_dims(left_proj_act, (2, 0, 1)), + permute_final_dims(right_proj_act, (2, 1, 0)), + ) + ab = permute_final_dims(p, (1, 2, 0)) - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) + # ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) ab = self.output_projection(self.layernorm2(ab)) dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) return bias_ele_dropout_residual(ab, @@ -102,13 +102,13 @@ def forward(self, Z_raw, Z_mask): left_proj_act = gather(left_proj_act.contiguous(), dim=2) g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) + p = torch.matmul( + permute_final_dims(left_proj_act, (2, 1, 0)), + permute_final_dims(right_proj_act, (2, 0, 1)), + ) + ab = permute_final_dims(p, (1, 2, 0)) - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) + # ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) ab = self.output_projection(self.layernorm2(ab)) dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype) return bias_ele_dropout_residual(ab, diff --git a/fastfold/habana/inject_habana.py b/fastfold/habana/inject_habana.py index 2fef55c7..34715808 100644 --- a/fastfold/habana/inject_habana.py +++ b/fastfold/habana/inject_habana.py @@ -49,9 +49,7 @@ def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v): def copy_attention(model_fast, model_ori): - copy_linear(model_fast.to_q, model_ori.linear_q) - copy_linear(model_fast.to_k, model_ori.linear_k) - copy_linear(model_fast.to_v, model_ori.linear_v) + copy_qkv_linear(model_fast.to_qkv, 
model_ori.linear_q, model_ori.linear_k, model_ori.linear_v) copy_linear(model_fast.gating_linear, model_ori.linear_g) copy_linear(model_fast.o_linear, model_ori.linear_o) diff --git a/fastfold/model/hub/alphafold.py b/fastfold/model/hub/alphafold.py index f2c8b82f..78e2265f 100644 --- a/fastfold/model/hub/alphafold.py +++ b/fastfold/model/hub/alphafold.py @@ -41,6 +41,7 @@ tensor_tree_map, ) +import fastfold.habana as habana class AlphaFold(nn.Module): """ @@ -173,6 +174,9 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): # Primary output dictionary outputs = {} + if habana.is_habana(): + from habana.hpuhelper import hpu_perf + perf = hpu_perf("iteration", sync=False) dtype = next(self.parameters()).dtype for k in feats: if(feats[k].dtype == torch.float32): @@ -190,7 +194,8 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): pair_mask = seq_mask[..., None] * seq_mask[..., None, :] msa_mask = feats["msa_mask"] - # Initialize the MSA and pair representations + if habana.is_habana(): + perf.checkahead("1: Initialize the MSA and pair representations") # m: [*, S_c, N, C_m] # z: [*, N, N, C_z] @@ -252,7 +257,8 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): # Possibly prevents memory fragmentation del m_1_prev, z_prev, x_prev - # Embed the templates + merge with MSA/pair embeddings + if habana.is_habana(): + perf.checkahead("2: Embed the templates + merge with MSA/pair embeddings") if self.config.template.enabled: template_feats = { k: v for k, v in feats.items() if k.startswith("template_") @@ -320,7 +326,8 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): ) del template_feats, template_embeds - # Embed extra MSA features + merge with pairwise embeddings + if habana.is_habana(): + perf.checkahead("3: Embed extra MSA features + merge with pairwise embeddings") if self.config.extra_msa.enabled: if(self.globals.is_multimer): extra_msa_fn = data_transforms_multimer.build_extra_msa_feat @@ -354,7 +361,8 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): )[0] del extra_msa_feat, extra_msa_fn - # Run MSA + pair embeddings through the trunk of the network + if habana.is_habana(): + perf.checkahead("4: Run MSA + pair embeddings through the trunk of the network") # m: [*, S, N, C_m] # z: [*, N, N, C_z] # s: [*, N, C_s] @@ -385,7 +393,8 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): outputs["pair"] = z outputs["single"] = s - # Predict 3D structure + if habana.is_habana(): + perf.checkahead("5: Predict 3D structure") outputs["sm"] = self.structure_module( s, z, @@ -409,6 +418,9 @@ def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True): # [*, N, 3] x_prev = outputs["final_atom_positions"] + if habana.is_habana(): + perf.checkahead("6: stop iteration") + return outputs, m_1_prev, z_prev, x_prev def _disable_activation_checkpointing(self): @@ -490,6 +502,9 @@ def forward(self, batch): # Main recycling loop num_iters = batch["aatype"].shape[-1] for cycle_no in range(num_iters): + if habana.is_habana(): + from habana.hpuhelper import hpu_perf + perf = hpu_perf(f"cycle {cycle_no+1}/{num_iters}") # Select the features for the current recycling cycle fetch_cur_batch = lambda t: t[..., cycle_no] feats = tensor_tree_map(fetch_cur_batch, batch) @@ -511,7 +526,8 @@ def forward(self, batch): x_prev, _recycle=(num_iters > 1) ) - + if habana.is_habana(): + perf.checknow("cycle finish") # Run auxiliary heads outputs.update(self.aux_heads(outputs)) diff --git 
a/fastfold/model/hub/loss.py b/fastfold/model/hub/loss.py index f11de43d..d4ec6bc9 100644 --- a/fastfold/model/hub/loss.py +++ b/fastfold/model/hub/loss.py @@ -1612,6 +1612,7 @@ def forward(self, out, batch, _return_breakdown=False): out["sm"]["unnormalized_angles"], **{**batch, **self.config.supervised_chi}, ), + # Habana: TODO comment out below part to WA error in HMP "violation": lambda: violation_loss( out["violation"], **batch, diff --git a/fastfold/model/nn/primitives.py b/fastfold/model/nn/primitives.py index 0c9ffae8..fd46f434 100644 --- a/fastfold/model/nn/primitives.py +++ b/fastfold/model/nn/primitives.py @@ -29,6 +29,7 @@ _chunk_slice, ) +import fastfold.habana as habana def _prod(nums): out = 1 @@ -214,11 +215,17 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, # [*, H, Q, K] a = torch.matmul(query, key) + if habana.is_habana(): + from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias + if len(biases) == 1: + a = fused_softmax(a, biases[0], -1) + else: + a = fused_softmax_bias(a, biases[0], biases[1], -1) + else: + for b in biases: + a += b - for b in biases: - a += b - - a = softmax(a, -1) + a = softmax(a, -1) # [*, H, Q, C_hidden] a = a.to(dtype=value.dtype) @@ -464,8 +471,12 @@ def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] ) bias = (self.inf * (mask - 1))[..., :, None, :] - a += bias - a = softmax(a) + if habana.is_habana(): + from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias + a = fused_softmax(a, bias, -1) + else: + a += bias + a = softmax(a) # [*, N_res, H, C_hidden] a = a.to(dtype=v.dtype) diff --git a/fastfold/model/nn/structure_module.py b/fastfold/model/nn/structure_module.py index ce889afb..e421f807 100644 --- a/fastfold/model/nn/structure_module.py +++ b/fastfold/model/nn/structure_module.py @@ -39,6 +39,7 @@ flatten_final_dims, ) +import fastfold.habana as habana class AngleResnetBlock(nn.Module): def __init__(self, c_hidden): @@ -397,10 +398,20 @@ def forward( pt_att = sum([c**2 for c in pt_att]) else: # [*, N_res, N_res, H, P_q, 3] - pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) - pt_att = pt_att**2 - # [*, N_res, N_res, H, P_q] - pt_att = sum(torch.unbind(pt_att, dim=-1)) + ###################################### + q_pts_t0 = q_pts.unsqueeze(-4) + q_shape = q_pts_t0.shape + q_pts_t0 = q_pts_t0.reshape([q_shape[0], q_shape[1], -1]) + k_pts_t0 = k_pts.unsqueeze(-5) + k_shape = k_pts_t0.shape + k_pts_t0 = k_pts_t0.reshape([k_shape[0], k_shape[1], -1]) + q_k = q_pts_t0 - k_pts_t0 + q_k = q_k ** 2 + q_k_shape = q_k.shape + pt_att = q_k.reshape(q_k_shape[:2] + q_shape[-3:]) + ##################################### + pt_att = pt_att.permute(0, 4, 1, 2, 3) + pt_att = torch.sum(pt_att, 1) head_weights = self.softplus(self.head_weights).view( *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) @@ -408,7 +419,12 @@ def forward( head_weights = head_weights * math.sqrt( 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) ) - pt_att = pt_att * head_weights + ############################## + pt_att_t0 = pt_att.permute(0, 3, 1, 2) + head_weights_t0 = head_weights.permute(0, 3, 1, 2) + pt_att_o = pt_att_t0 * head_weights_t0 + pt_att = pt_att_o.permute(0, 2,3, 1) + ############################## # [*, N_res, N_res, H] pt_att = torch.sum(pt_att, dim=-1) * (-0.5) @@ -448,13 +464,14 @@ def forward( o_pt_norm = o_pt.norm(self.eps) else: # [*, H, 3, N_res, P_v] - o_pt = torch.sum( - ( - a[..., None, :, :, None] - * permute_final_dims(v_pts, 
(1, 3, 0, 2))[..., None, :, :] - ), - dim=-2, - ) + ################################### + a1 = a[..., None, :, :, None] + a1 = a1.permute(0, 1, 2, 4, 3) + b = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + b = b.permute(0, 1, 2, 4, 3) + c = a1 * b + o_pt = torch.sum(c, -1) + ################################### # [*, N_res, H, P_v, 3] o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) @@ -788,6 +805,10 @@ def _forward_monomer( if i < (self.no_blocks - 1): rigids = rigids.stop_rot_gradient() + if habana.is_habana(): + import habana_frameworks.torch.core as htcore + htcore.mark_step() + outputs = dict_multimap(torch.stack, outputs) outputs["single"] = s diff --git a/fastfold/model/nn/template.py b/fastfold/model/nn/template.py index 653cc73c..c9c141e7 100644 --- a/fastfold/model/nn/template.py +++ b/fastfold/model/nn/template.py @@ -40,6 +40,7 @@ flatten_final_dims, ) +import fastfold.habana as habana class TemplatePointwiseAttention(nn.Module): """ @@ -121,10 +122,13 @@ def forward(self, # [*, N_res, N_res, 1, C_z] biases = [bias] - if chunk_size is not None: - z = self._chunk(z, t, biases, chunk_size) - else: + if habana.is_habana(): z = self.mha(q_x=z, kv_x=t, biases=biases) + else: + if chunk_size is not None: + z = self._chunk(z, t, biases, chunk_size) + else: + z = self.mha(q_x=z, kv_x=t, biases=biases) # [*, N_res, N_res, C_z] z = z.squeeze(-2) diff --git a/habana/hpuhelper.py b/habana/hpuhelper.py new file mode 100644 index 00000000..42baa84a --- /dev/null +++ b/habana/hpuhelper.py @@ -0,0 +1,43 @@ +import time +import habana_frameworks.torch as ht + +class hpu_perf: + def __init__(self, module, log=True, mark_step=True, memoryinfo=False, sync=False): + if log: + print(f" {module}: start") + self.module = module + self.stime = time.perf_counter() + self.mark = mark_step + self.mem = memoryinfo + self.sync = sync + self.log = log + if self.mem: + ht.hpu.reset_peak_memory_stats() + self.prelog = None + + def checknow(self, log): + if self.mark: + ht.core.mark_step() + if self.sync: + ht.core.hpu.default_stream().synchronize() + if self.mem: + print(ht.hpu.memory_summary()) + + tmp = time.perf_counter() + if self.log: + print(" {}: {} takes {:.2f} ms".format(self.module, log, (tmp - self.stime)*1000)) + self.stime = tmp + + def checkahead(self, log): + if self.mark: + ht.core.mark_step() + if self.sync: + ht.core.hpu.default_stream().synchronize() + if self.mem: + print(ht.hpu.memory_summary()) + + tmp = time.perf_counter() + if self.prelog is not None and self.log: + print(" {}: {} takes {:.2f} ms".format(self.module, self.prelog, (tmp - self.stime)*1000)) + self.stime = tmp + self.prelog = log diff --git a/habana/inference_test.py b/habana/inference_test.py index 2012cad1..e685cbfd 100644 --- a/habana/inference_test.py +++ b/habana/inference_test.py @@ -22,8 +22,8 @@ def main(): config = model_config(model_name) config.globals.inplace = False - config.globals.chunk_size = None - habana.enable_hmp() + config.globals.chunk_size = 512 + # habana.enable_hmp() model = AlphaFold(config) model = inject_habana(model) model = model.eval() @@ -47,6 +47,7 @@ def main(): t = time.perf_counter() out = model(batch) htcore.mark_step() + htcore.hpu.default_stream().synchronize() print(f"Inference time: {time.perf_counter() - t}") diff --git a/habana/train.py b/habana/train.py index e687cf75..667020e5 100644 --- a/habana/train.py +++ b/habana/train.py @@ -22,6 +22,7 @@ torch.multiprocessing.set_sharing_strategy('file_system') +from habana.hpuhelper import * def main(): parser = 
argparse.ArgumentParser() @@ -187,7 +188,9 @@ def main(): criterion = AlphaFoldLoss(config.loss) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8) + # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8) + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + optimizer = FusedAdamW(model.parameters(), lr=1e-3, eps=1e-8) lr_scheduler = AlphaFoldLRScheduler(optimizer) @@ -202,20 +205,23 @@ def main(): model.train() train_dataloader = tqdm(train_dataloader) for batch in train_dataloader: + perf = hpu_perf("train step") batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()} optimizer.zero_grad() output = model(batch) + perf.checknow("forward") + batch = tensor_tree_map(lambda t: t[..., -1], batch) loss, loss_breakdown = criterion(output, batch, _return_breakdown=True) + perf.checknow("loss") loss.backward() - htcore.mark_step() train_dataloader.set_postfix(loss=float(loss)) + perf.checknow("backward") with hmp.disable_casts(): optimizer.step() - - htcore.mark_step() + perf.checknow("optimizer") lr_scheduler.step()
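A note on the `hpu_perf` helper introduced in `habana/hpuhelper.py`: `checknow(label)` marks a step, measures the time since the previous checkpoint and prints it immediately under `label` (as in the train step above), while `checkahead(label)` prints the elapsed time under the label passed to the *previous* `checkahead` call, which is why `AlphaFold.iteration` names each phase before it starts. A minimal sketch of the difference, with illustrative phase names:

```python
from habana.hpuhelper import hpu_perf

perf = hpu_perf("iteration", sync=False)
perf.checkahead("1: embed inputs")      # first call: only stores the label
# ... embedding work ...
perf.checkahead("2: evoformer trunk")   # prints the time for "1: embed inputs"
# ... trunk work ...
perf.checkahead("3: stop iteration")    # prints the time for "2: evoformer trunk"

step = hpu_perf("train step")
# ... forward pass ...
step.checknow("forward")                # prints the forward time immediately
```

Both methods call `mark_step()` by default (`mark_step=True`); without `sync=True` the queued device work may still be in flight, so pass `sync=True` when wall-clock timings should include device execution.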