
Commit ceee81d

Gy-Lu and Shenggan authored
support training, fix some inplace func of nn (#118)
* support training, fix some inplace func of nn
* fix some merge issue

Co-authored-by: shenggan <[email protected]>
1 parent 164f677 commit ceee81d


9 files changed: +350, -10 lines changed


fastfold/model/fastnn/template.py

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ def forward(
             for i in range(0, t.shape[0]):
                 t[i] = self.layer_norm(t[i])
         else:
-            t = self.layer_norm(t[i])
+            t = self.layer_norm(t)
         return t

     def inplace(

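A brief aside on the fix above: the else branch is the path where the per-slice loop never runs, so indexing with the loop variable either reuses a stale value or is not bound at all. A minimal standalone sketch (hypothetical names, not from the repo) of the failure mode:

# Illustrative sketch, not from the repo: in the non-chunked branch the loop
# variable `i` is never bound, so the old `self.layer_norm(t[i])` would raise
# NameError (or silently normalize a single stale slice).
import torch

t = torch.randn(4, 8)
layer_norm = torch.nn.LayerNorm(8)
chunked = False          # hypothetical flag standing in for the real branch condition
if chunked:
    for i in range(t.shape[0]):
        t[i] = layer_norm(t[i])
else:
    t = layer_norm(t)    # corrected whole-tensor path; `t[i]` here would fail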
fastfold/model/hub/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
 from .alphafold import AlphaFold
+from .lr_scheduler import AlphaFoldLRScheduler
+from .loss import AlphaFoldLoss

-__all__ = ["AlphaFold"]
+__all__ = ["AlphaFold", "AlphaFoldLRScheduler", "AlphaFoldLoss"]
File renamed without changes.
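With the loss module now living under the hub package and the two exports added above, the training entry points import from a single place, mirroring what the new train.py below does:

# Training entry points now come from one package (as used by train.py in this commit).
from fastfold.model.hub import AlphaFold, AlphaFoldLRScheduler, AlphaFoldLoss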

fastfold/model/hub/lr_scheduler.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay.

        Note that the initial learning rate of the optimizer in question is
        ignored; use this class' base_lr parameter to specify the starting
        point of the warmup.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if(v < 0):
                raise ValueError(f"{k} must be nonnegative")

        if(warmup_no_steps > start_decay_after_n_steps):
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }

        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if(not self._get_lr_called_within_step):
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if(step_no <= self.warmup_no_steps):
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif(step_no > self.start_decay_after_n_steps):
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]

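For orientation, a minimal usage sketch (not part of the commit) that pairs the scheduler with a plain torch optimizer. With the defaults above, the learning rate warms up linearly to max_lr over the first 1,000 steps, plateaus until step 50,000, and is then multiplied by 0.95 every further 50,000 steps:

# Usage sketch (assumes only torch and the class above): the optimizer's own
# initial lr is ignored; base_lr/max_lr on the scheduler control the schedule.
import torch
from fastfold.model.hub import AlphaFoldLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.)
scheduler = AlphaFoldLRScheduler(optimizer, max_lr=0.001, warmup_no_steps=1000)

for step in range(2000):
    optimizer.step()        # the real training step would go here
    scheduler.step()        # lr ~= 0.0005 at step 500, 0.001 from step 1000 onward
print(scheduler.get_last_lr())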
fastfold/model/nn/dropout.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         shape[bd] = 1
         mask = x.new_ones(shape)
         mask = self.dropout(mask)
-        x *= mask
+        x = x * mask
         return x


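This small change is the core of the "fix some inplace func of nn" part of the commit: multiplying an activation in place can clobber a value autograd saved for the backward pass, which only surfaces once training (rather than inference) is run. A standalone sketch (not from the repo) of the failure:

# Sketch, not from the repo: in-place ops on tensors saved for backward fail
# at backward() time; the out-of-place form `x = x * mask` is safe.
import torch

x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)      # sigmoid's backward uses its own output y
mask = torch.ones_like(y)
y *= mask                 # in-place multiply bumps y's version counter
y.sum().backward()        # RuntimeError: a variable needed for gradient computation was modified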
fastfold/model/nn/evoformer.py

Lines changed: 1 addition & 6 deletions
@@ -263,12 +263,7 @@ def __init__(self,
             inf=inf,
             eps=eps,
         )
-
-        self.outer_product_mean = OuterProductMean(
-            c_m,
-            c_z,
-            c_hidden_opm,
-        )
+
         self.is_multimer = is_multimer

     def forward(self,

fastfold/model/nn/heads.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import torch.nn as nn

 from fastfold.model.nn.primitives import Linear, LayerNorm
-from fastfold.model.loss import (
+from fastfold.model.hub.loss import (
     compute_plddt,
     compute_tm,
     compute_predicted_aligned_error,

train.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import random
import torch
import numpy as np
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam

from tqdm import tqdm

from fastfold.config import model_config
from fastfold.model.hub import AlphaFold, AlphaFoldLRScheduler, AlphaFoldLoss
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.utils.tensor_utils import tensor_tree_map

import logging
logging.disable(logging.WARNING)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=False, action='store_true')
    parser.add_argument(
        "--template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "--max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "--train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
             files."""
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--_alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed"
    )

    args = parser.parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    if args.from_torch:
        colossalai.launch_from_torch(config=dict(torch_ddp=dict(static_graph=True)))
    disable_existing_loggers()
    logger = get_dist_logger()

    config = model_config(args.config_preset, train=True)
    config.globals.inplace = False
    model = AlphaFold(config)
    model = inject_fastnn(model)


    train_dataset, test_dataset = SetupTrainDataset(
        config=config.data,
        template_mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        train_data_dir=args.train_data_dir,
        train_alignment_dir=args.train_alignment_dir,
        train_chain_data_cache_path=args.train_chain_data_cache_path,
        distillation_data_dir=args.distillation_data_dir,
        distillation_alignment_dir=args.distillation_alignment_dir,
        distillation_chain_data_cache_path=args.distillation_chain_data_cache_path,
        val_data_dir=args.val_data_dir,
        val_alignment_dir=args.val_alignment_dir,
        kalign_binary_path=args.kalign_binary_path,
        # train_mapping_path=args.train_mapping_path,
        # distillation_mapping_path=args.distillation_mapping_path,
        obsolete_pdbs_file_path=args.obsolete_pdbs_file_path,
        template_release_dates_cache_path=args.template_release_dates_cache_path,
        train_epoch_len=args.train_epoch_len,
        _alignment_index_path=args._alignment_index_path,
    )

    train_dataloader, test_dataloader = TrainDataLoader(
        config=config.data,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        batch_seed=args.seed,
    )


    criterion = AlphaFoldLoss(config.loss)

    optimizer = HybridAdam(model.parameters(), lr=1e-3, eps=1e-8)

    lr_scheduler = AlphaFoldLRScheduler(optimizer)


    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )

    for epoch in range(200):
        engine.train()
        if gpc.get_global_rank() == 0:
            train_dataloader = tqdm(train_dataloader)
        for batch in train_dataloader:
            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
            engine.zero_grad()
            output = engine(batch)
            batch = tensor_tree_map(lambda t: t[..., -1], batch)
            loss, loss_breakdown = engine.criterion(
                output, batch, _return_breakdown=True)
            if gpc.get_global_rank() == 0:
                train_dataloader.set_postfix(loss=float(loss))
            engine.backward(loss)
            engine.step()
            lr_scheduler.step()

        if test_dataloader is not None:
            engine.eval()
            if gpc.get_global_rank() == 0:
                train_dataloader = tqdm(train_dataloader)
            for batch in test_dataloader:
                batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
                with torch.no_grad():
                    output = engine(batch)
                    batch = tensor_tree_map(lambda t: t[..., -1], batch)
                    _, loss_breakdown = engine.criterion(
                        output, batch, _return_breakdown=True)
                    if gpc.get_global_rank() == 0:
                        train_dataloader.set_postfix(loss=float(loss))


if __name__ == "__main__":
    main()

train.sh

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
DATA_DIR=/path/to/data
PROJECT_DIR=/path/to/project

gpus_per_node=2
nnodes=1

max_template_date=2021-10-10

train_data_dir=${DATA_DIR}/mmcif_dir            # specify the dir contains *.cif or *.pdb
train_alignment_dir=${DATA_DIR}/alignment_dir   # a dir to save template and features.pkl of training sequence
mkdir -p ${train_alignment_dir}

# val_data_dir=${PROJECT_DIR}/dataset/val_pdb
# val_alignment_dir=${PROJECT_DIR}/dataset/alignment_val_pdb   # a dir to save template and features.pkl of vld sequence

template_mmcif_dir=${DATA_DIR}/data/pdb_mmcif/mmcif_files
template_release_dates_cache_path=${DATA_DIR}/mmcif_cache.json     # a cache used to pre-filter templates
train_chain_data_cache_path=${DATA_DIR}/chain_data_cache.json      # a separate chain-level cache with data used for training-time data filtering

train_epoch_len=10000   # virtual length of each training epoch, which affects frequency of validation & checkpointing

torchrun --standalone --nproc_per_node ${gpus_per_node} --nnodes ${nnodes} train.py \
    --from_torch \
    --template_mmcif_dir=${template_mmcif_dir} \
    --max_template_date=${max_template_date} \
    --train_data_dir=${train_data_dir} \
    --train_alignment_dir=${train_alignment_dir} \
    --train_chain_data_cache_path=${train_chain_data_cache_path} \
    --template_release_dates_cache_path=${template_release_dates_cache_path} \
    --train_epoch_len=${train_epoch_len} \

0 commit comments
