
Commit da5fe1a

Shenggan, LeoZhao-Intel, and habanachina authored
add support on habana platform (#131)
* add habana
* add mask
* fix mask in outer_product_mean
* add dap
* add hmp
* merge training code
* add chunk for inference
* fix extra-msa stack for training
* support ddp in training
* fix inference bugs
* code refactoring for habana
* support hmp training
* enable all inference and train on Gaudi/Gaudi2 with optimized perf with latest base (#139)
  * enable all inference and train on Gaudi/Gaudi2 with optimized perf
  * refine code to adapt new base
  * refine code to fix issues in code review

Co-authored-by: habanachina <[email protected]>
Co-authored-by: Leo Zhao <[email protected]>
Co-authored-by: habanachina <[email protected]>
1 parent e9db72d commit da5fe1a

38 files changed, +3505 -82 lines changed

README.md

Lines changed: 23 additions & 1 deletion

@@ -4,10 +4,17 @@
 
 [![](https://img.shields.io/badge/Paper-PDF-green?style=flat&logo=arXiv&logoColor=green)](https://arxiv.org/abs/2203.00854)
 ![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)
+![](https://img.shields.io/badge/Habana-support-blue?style=flat&logo=intel&logoColor=blue)
 ![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
 [![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold)](https://github.com/hpcaitech/FastFold/blob/main/LICENSE)
 
-Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters
+## News :triangular_flag_on_post:
+- [2023/01] Compatible with AlphaFold v2.3
+- [2023/01] Added support for inference and training of AlphaFold on the [Intel Habana](https://habana.ai/) platform. For usage instructions, see [here](#Inference-or-Training-on-Intel-Habana).
+
+<br>
+
+Optimizing Protein Structure Prediction Model Training and Inference on Heterogeneous Clusters
 
 FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
 
@@ -201,6 +208,17 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
     --kalign_binary_path `which kalign`
 ```
 
+### Inference or Training on Intel Habana
+
+To run AlphaFold inference or training on Intel Habana, follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or in an on-premise environment.
+
+Once you have prepared your dataset and installed FastFold, you can use the following scripts:
+
+```shell
+bash habana/inference.sh
+bash habana/train.sh
+```
+
 ## Performance Benchmark
 
 We have included a performance benchmark script in `./benchmark`. You can benchmark the performance of Evoformer using different settings.
@@ -237,3 +255,7 @@ Cite this paper, if you use FastFold in your research publication.
      primaryClass={cs.LG}
 }
 ```
+
+## Acknowledgments
+
+We would like to extend our special thanks to the Intel Habana team for their support and for providing technology and resources on the Habana platform.

fastfold/data/feature_pipeline.py

Lines changed: 14 additions & 14 deletions

@@ -20,6 +20,7 @@
 import numpy as np
 import torch
 
+import fastfold.habana as habana
 from fastfold.data import input_pipeline, input_pipeline_multimer
 
 
@@ -91,19 +92,18 @@ def np_example_to_features(
         np_example=np_example, features=feature_names
     )
 
-    with torch.no_grad():
-        if is_multimer:
-            features = input_pipeline_multimer.process_tensors_from_config(
-                tensor_dict,
-                cfg.common,
-                cfg[mode],
-            )
-        else:
-            features = input_pipeline.process_tensors_from_config(
-                tensor_dict,
-                cfg.common,
-                cfg[mode],
-            )
+    if is_multimer:
+        input_pipeline_fn = input_pipeline_multimer.process_tensors_from_config
+    else:
+        input_pipeline_fn = input_pipeline.process_tensors_from_config
+
+    if habana.is_habana():
+        from habana_frameworks.torch.hpex import hmp
+        with torch.no_grad(), hmp.disable_casts():
+            features = input_pipeline_fn(tensor_dict, cfg.common, cfg[mode])
+    else:
+        with torch.no_grad():
+            features = input_pipeline_fn(tensor_dict, cfg.common, cfg[mode])
 
     return {k: v for k, v in features.items()}
 
@@ -118,7 +118,7 @@ def __init__(
     def process_features(
         self,
         raw_features: FeatureDict,
-        mode: str = "train",
+        mode: str = "train",
         is_multimer: bool = False,
     ) -> FeatureDict:
        return np_example_to_features(
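
The Habana branch above wraps feature processing in `hmp.disable_casts()` so it stays in full precision under Habana Mixed Precision, while the non-Habana path keeps the plain `torch.no_grad()` context. The same branch can also be written without repeating the call by stacking context managers; the sketch below is illustrative only (the `run_feature_pipeline` helper is not part of the commit) and assumes the same `hmp.disable_casts()` API from `habana_frameworks.torch.hpex` used in the diff:

```python
import contextlib

import torch


def run_feature_pipeline(pipeline_fn, tensor_dict, cfg, mode, on_habana=False):
    """Illustrative sketch: run pipeline_fn under torch.no_grad(), adding
    hmp.disable_casts() only when running on Habana, without duplicating the call."""
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())
        if on_habana:
            # Assumes the Habana software stack is installed, as in the diff above.
            from habana_frameworks.torch.hpex import hmp
            stack.enter_context(hmp.disable_casts())
        return pipeline_fn(tensor_dict, cfg.common, cfg[mode])
```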

fastfold/habana/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+ENABLE_HABANA = False
+ENABLE_HMP = False
+
+def enable_habana():
+    global ENABLE_HABANA
+    ENABLE_HABANA = True
+    global ENABLE_LAZY_MODE
+    ENABLE_LAZY_MODE = True
+    import habana_frameworks.torch.core
+
+def is_habana():
+    global ENABLE_HABANA
+    return ENABLE_HABANA
+
+def enable_hmp():
+    global ENABLE_HMP
+    ENABLE_HMP = True
+
+def is_hmp():
+    global ENABLE_HMP
+    return ENABLE_HMP
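
This module only exposes process-wide flags; downstream code such as `feature_pipeline.py` above branches on `habana.is_habana()` and `habana.is_hmp()`. A minimal usage sketch follows (illustrative, not part of the commit; note that `enable_habana()` imports `habana_frameworks.torch.core`, so it only works on a machine with the Habana stack installed):

```python
import fastfold.habana as habana

# Opt in before building the model or moving tensors; enable_habana() also
# imports habana_frameworks.torch.core, which requires the Habana stack.
habana.enable_habana()
habana.enable_hmp()  # optional: turn on Habana Mixed Precision (HMP)

assert habana.is_habana()
assert habana.is_hmp()
```
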
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
+                   gather, reduce, row_to_col, scatter)
+from .core import init_dist
+
+__all__ = [
+    'init_dist', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
+    'col_to_row', 'row_to_col', 'All_to_All'
+]

fastfold/habana/distributed/comm.py

Lines changed: 188 additions & 0 deletions

@@ -0,0 +1,188 @@
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from .core import (ensure_divisibility, get_tensor_model_parallel_group,
+                   get_tensor_model_parallel_rank,
+                   get_tensor_model_parallel_world_size)
+
+
+def divide(numerator, denominator):
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def _reduce(tensor: Tensor) -> Tensor:
+    if dist.get_world_size() == 1:
+        return tensor
+
+    dist.all_reduce(tensor,
+                    op=dist.ReduceOp.SUM,
+                    group=get_tensor_model_parallel_group(),
+                    async_op=False)
+
+    return tensor
+
+
+def _split(tensor: Tensor, dim: int = -1) -> Tensor:
+    if get_tensor_model_parallel_world_size() == 1:
+        return tensor
+
+    split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
+    tensor_list = torch.split(tensor, split_size, dim=dim)
+
+    output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
+
+    return output
+
+
+def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
+    if get_tensor_model_parallel_world_size() == 1:
+        return tensor
+
+    if dim == 1 and list(tensor.shape)[0] == 1:
+        output_shape = list(tensor.shape)
+        output_shape[1] *= get_tensor_model_parallel_world_size()
+        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
+        dist.all_gather(list(tensor_list),
+                        tensor,
+                        group=get_tensor_model_parallel_group(),
+                        async_op=False)
+    else:
+        tensor_list = [
+            torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())
+        ]
+        dist.all_gather(tensor_list,
+                        tensor,
+                        group=get_tensor_model_parallel_group(),
+                        async_op=False)
+        output = torch.cat(tensor_list, dim=dim)
+
+    return output
+
+
+def copy(input: Tensor) -> Tensor:
+    if torch.is_grad_enabled() and input.requires_grad:
+        input = Copy.apply(input)
+    return input
+
+
+class Copy(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: "Copy", input: Tensor) -> Tensor:
+        return input
+
+    @staticmethod
+    def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
+        return _reduce(grad_output)
+
+
+def scatter(input: Tensor, dim: int = -1) -> Tensor:
+    if torch.is_grad_enabled() and input.requires_grad:
+        input = Scatter.apply(input, dim)
+    else:
+        input = _split(input, dim=dim)
+    return input
+
+
+class Scatter(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
+        ctx.save_for_backward(torch.tensor([dim]))
+        return _split(input, dim=dim)
+
+    @staticmethod
+    def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
+        dim, = ctx.saved_tensors
+        return _gather(grad_output, dim=int(dim)), None
+
+
+def reduce(input: Tensor) -> Tensor:
+    if torch.is_grad_enabled() and input.requires_grad:
+        input = Reduce.apply(input)
+    else:
+        input = _reduce(input)
+    return input
+
+
+class Reduce(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: "Reduce", input: Tensor) -> Tensor:
+        return _reduce(input)
+
+    @staticmethod
+    def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
+        return grad_output
+
+
+def gather(input: Tensor, dim: int = -1) -> Tensor:
+    if torch.is_grad_enabled() and input.requires_grad:
+        input = Gather.apply(input, dim)
+    else:
+        input = _gather(input, dim=dim)
+    return input
+
+
+class Gather(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
+        ctx.save_for_backward(torch.tensor([dim]))
+        return _gather(input, dim=dim)
+
+    @staticmethod
+    def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
+        dim, = ctx.saved_tensors
+        return _split(grad_output, dim=int(dim)), None
+
+
+def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
+    if dist.get_world_size() == 1:
+        return tensor
+
+    tensor = tensor.transpose(in_dim, 0).contiguous()
+
+    output = torch.empty_like(tensor)
+    dist.all_to_all_single(output, tensor, group=get_tensor_model_parallel_group())
+
+    output = output.transpose(in_dim, 0).contiguous()
+
+    tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=in_dim)
+
+    return torch.cat(tensor_list, dim=out_dim)
+
+
+def col_to_row(input_: Tensor) -> Tensor:
+    if torch.is_grad_enabled() and input_.requires_grad:
+        input_ = All_to_All.apply(input_, 1, 2)
+    else:
+        input_ = _all_to_all(input_, in_dim=1, out_dim=2)
+    return input_
+
+
+def row_to_col(input_: Tensor) -> Tensor:
+    if torch.is_grad_enabled() and input_.requires_grad:
+        input_ = All_to_All.apply(input_, 2, 1)
+    else:
+        input_ = _all_to_all(input_, in_dim=2, out_dim=1)
+    return input_
+
+
+class All_to_All(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
+        ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
+        return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
+
+    @staticmethod
+    def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
+        saved_tensors = ctx.saved_tensors[0]
+        return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
+                           out_dim=int(saved_tensors[0])), None, None
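
In `comm.py`, each collective is paired with an `autograd.Function` whose backward is the inverse collective (`Scatter` backs onto a gather, `Gather` onto a split), and `col_to_row`/`row_to_col` reshard an activation between row- and column-partitioned layouts with a single all-to-all. The sketch below simulates that resharding in a single process to show the shape bookkeeping; it is illustrative only, uses made-up shapes, and does not call the commit's helpers or `torch.distributed`:

```python
import torch

WORLD = 2  # number of simulated tensor-parallel ranks


def simulated_all_to_all(locals_per_rank, in_dim, out_dim):
    """Emulate _all_to_all: each rank splits its shard into WORLD chunks along in_dim,
    "sends" chunk j to rank j, and concatenates what it receives along out_dim."""
    resharded = []
    for dst in range(WORLD):
        received = [t.chunk(WORLD, dim=in_dim)[dst] for t in locals_per_rank]
        resharded.append(torch.cat(received, dim=out_dim))
    return resharded


# Full (unsharded) activation: [batch=1, rows=4, cols=4, channels=2].
full = torch.arange(1 * 4 * 4 * 2, dtype=torch.float32).reshape(1, 4, 4, 2)

# Start row-sharded: each simulated rank holds rows/WORLD rows and every column.
row_sharded = list(full.chunk(WORLD, dim=1))                           # each [1, 2, 4, 2]

# row_to_col uses in_dim=2, out_dim=1: afterwards each rank holds every row
# and cols/WORLD columns.
col_sharded = simulated_all_to_all(row_sharded, in_dim=2, out_dim=1)   # each [1, 4, 2, 2]

# Reassembling along the column dim recovers the original tensor.
assert torch.equal(torch.cat(col_sharded, dim=2), full)
print([tuple(t.shape) for t in row_sharded], "->", [tuple(t.shape) for t in col_sharded])
```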
