
enable all inference and train on Gaudi/Gaudi2 with optimized perf with latest base #139


Merged 4 commits on Jan 16, 2023
4 changes: 1 addition & 3 deletions fastfold/habana/fastnn/__init__.py
@@ -12,6 +12,7 @@
from .ops import Linear, OutProductMean
from .triangle import PairStack

import habana_frameworks.torch.core as htcore

class Evoformer(nn.Module):

@@ -90,7 +91,6 @@ def forward(
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]

import habana_frameworks.torch.core as htcore
htcore.mark_step()

return m, z
@@ -220,7 +220,6 @@ def forward(

s = self.linear(m[..., 0, :, :])

import habana_frameworks.torch.core as htcore
htcore.mark_step()

return m, z, s
@@ -254,7 +253,6 @@ def forward(
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:

import habana_frameworks.torch.core as htcore
htcore.mark_step()

dap_size = dist.get_world_size()
49 changes: 49 additions & 0 deletions fastfold/habana/fastnn/custom_op/README.md
@@ -0,0 +1,49 @@
# CustomOp API Usage in PyTorch

This README provides an example of how to write custom PyTorch Ops using a TPC Kernel supported on an HPU device. For more details, refer to [PyTorch CustomOP API](https://docs.habana.ai/en/latest/PyTorch/PyTorch_CustomOp_API/page_index.html) documentation.

For further information on training deep learning models using Gaudi, refer to [developer.habana.ai](https://developer.habana.ai/resources/).

## Table of Contents

* [Model-References](../../../README.md)
* [Prerequisites](#prerequisites)
* [Content](#content)
* [Build and Run with Custom Kernels](#build-and-run-with-custom-kernels)
* [Important to Know](#important-to-know)
* [Applying CustomOps to a Real Training Model Example](#applying-customops-to-a-real-training-model-example)
* [Known Issues](#known-issues)


## Prerequisites

- A TPC kernel on which the HpuKernel will run. To write a CustomOp, you must first define the TPC kernel that the HpuKernel will run on. This document provides the required steps for using the existing default TPC kernels `relu_fwd_f32` and `relu_bwd_f32`, as well as the custom kernel `custom_op::custom_relu`, to implement a CustomOp. For further information on how to write TPC kernels, refer to the [Habana Custom Kernel GitHub page](https://github.com/HabanaAI/Habana_Custom_Kernel).

- The **habana-torch-plugin** Python package must be installed. Make sure to install it by following the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html).

## Content

- C++ file with the **custom_op::fusedsoftmax** and **custom_op::fusedsoftmax_bias** definitions and their kernel implementations on HPU:
  - `fusedsoftmax` performs a fused softmax on input and mask.
  - `fusedsoftmax_bias` performs a fused softmax on input, mask, and bias.
- `setup.py` file for building the solution:
  - To compile the Op for Gaudi, run `python setup.py build`.
  - To compile the Op for Gaudi2, run `python setup2.py build`.

- Python test to run and validate `fusedsoftmax` and `fusedsoftmax_bias` (a usage sketch of the Python wrappers follows this list):
  - `python hpu_fusedsoftmax_test.py`
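
Below is a minimal, hedged usage sketch of the Python wrappers (`fused_softmax`, `fused_softmax_bias`) exposed by `fastfold/habana/fastnn/custom_op`. The tensor shapes are assumptions chosen only to satisfy the broadcast checks in `fusedsoftmax.py`, and `dim=-1` is assumed to select the softmax axis; `hpu_fusedsoftmax_test.py` remains the authoritative reference.

```python
# Hedged sketch: shapes and the "hpu" device string are assumptions, not FastFold code.
import torch
import habana_frameworks.torch.core  # registers the HPU backend with PyTorch

from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias

device = torch.device("hpu")

# Attention-style logits [batch, heads, queries, keys]; FP32 only (see Known Issues).
logits = torch.randn(4, 8, 128, 128, device=device)
# mask must match logits[..., :, :1, :1, :].shape, i.e. broadcast over the two middle dims.
mask = torch.zeros(4, 1, 1, 128, device=device)
# bias must match logits[..., :1, :, :, :].shape, i.e. broadcast over dim -4.
bias = torch.randn(1, 8, 128, 128, device=device)

probs = fused_softmax(logits, mask, dim=-1)
probs_with_bias = fused_softmax_bias(logits, mask, bias, dim=-1)
```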

## Build and Run with Custom Kernels

To build `fusedsoftmax` and `fusedsoftmax_bias`, run the following (use `setup2.py` on Gaudi2):

```
python setup.py build
```

## Important to Know

This is an example of an Op implementing both the forward and backward passes.
The forward and backward CustomOp is used for training the model by extending the [torch.autograd](https://pytorch.org/docs/stable/notes/extending.html) package.
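
As a quick illustration (a hedged sketch; the shapes, the placeholder loss, and the `"hpu"` device string are assumptions, not FastFold code), gradients flow through the fused op during training because the backward pass is registered via `torch.autograd.Function`:

```python
# Hedged sketch: backward dispatches to torch.ops.custom_op.fusedsoftmax_backward
# because forward/backward are wired up through torch.autograd.Function.
import torch
import habana_frameworks.torch.core  # registers the HPU backend with PyTorch

from fastfold.habana.fastnn.custom_op import fused_softmax

logits = torch.randn(4, 8, 64, 64, device="hpu", requires_grad=True)
mask = torch.zeros(4, 1, 1, 64, device="hpu")

probs = fused_softmax(logits, mask, dim=-1)
loss = probs.sum()   # placeholder loss, not FastFold's
loss.backward()      # gradient w.r.t. logits is now in logits.grad
```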

## Known Issues

BF16 and HMP are not supported yet. To use the CustomOp in a topology, run the FP32 variant only.

8 changes: 8 additions & 0 deletions fastfold/habana/fastnn/custom_op/__init__.py
@@ -0,0 +1,8 @@
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################

from .fusedsoftmax import fused_softmax, fused_softmax_bias

__all__ = ["fused_softmax", "fused_softmax_bias"]

81 changes: 81 additions & 0 deletions fastfold/habana/fastnn/custom_op/fusedsoftmax.py
@@ -0,0 +1,81 @@
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################

import torch
import os
import habana_frameworks.torch.core

# Path to the custom-op shared library produced by `python setup.py build`
# (Python 3.8 / x86_64 build layout), resolved relative to this file.
custom_fusedsoftmax_op_lib_path = "./build/lib.linux-x86_64-3.8/hpu_fusedsoftmax.cpython-38-x86_64-linux-gnu.so"
base_dir = os.path.dirname(os.path.realpath(__file__))
# Register torch.ops.custom_op.fusedsoftmax* with PyTorch.
torch.ops.load_library(os.path.join(base_dir, custom_fusedsoftmax_op_lib_path))

class FusedSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax(input, mask, dim)
ctx.y = tensor
ctx.dim = dim
return tensor

@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)
return grad_input, None, None

class FusedSoftmaxBiasFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax_bias(input, mask, bias, dim)
ctx.y = tensor
ctx.dim = dim
ctx.use_bias = False
if bias is not None:
ctx.use_bias = True
return tensor

@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)

        grad_bias = None
        if ctx.use_bias:
            # Bias was broadcast over dim -4 in the forward pass, so its gradient
            # is the reduction of grad_input over that dim.
            grad_bias = torch.sum(grad_input, dim=-4, keepdim=True)

return grad_input, None, grad_bias, None


# When enabled, dispatch to the fused HPU kernel whenever the mask (and bias)
# shapes match the layout the kernel expects; otherwise fall back to plain PyTorch.
ENABLE_OPT = True

def fused_softmax(input, mask, dim):
    # The fused path requires the mask to broadcast over the two dims before the last.
    if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape:
        return FusedSoftmaxFunction.apply(input, mask, dim)
    else:
        # Fallback: add the mask in place and use the standard softmax.
        input += mask
        return torch.softmax(input, dim=dim)

def fused_softmax_bias(input, mask, bias, dim):
    # The fused path additionally requires the bias to broadcast over dim -4.
    if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape and input[..., :1, :, :, :].shape == bias.shape:
        return FusedSoftmaxBiasFunction.apply(input, mask, bias, dim)
    else:
        # Fallback: add mask and bias in place and use the standard softmax.
        input += mask
        input += bias
        return torch.softmax(input, dim=dim)
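
For reference, a small CPU-only snippet (shapes are illustrative assumptions) showing which mask and bias shapes satisfy the `ENABLE_OPT` gating conditions above:

```python
# Illustrative check of the ENABLE_OPT shape gate; runs on CPU, no HPU needed.
import torch

x = torch.empty(4, 8, 128, 128)
print(x[..., :, :1, :1, :].shape)  # torch.Size([4, 1, 1, 128])  -> required mask shape
print(x[..., :1, :, :, :].shape)   # torch.Size([1, 8, 128, 128]) -> required bias shape
```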