
Commit 95150c3

fix distributed sampler for ddp and add dap (#141)
* fix sampler of ddp and add dap
* add end to end unit test for training
* modify unit test datapath
* lower the bar for cuda kernel
* skip unit test of end-to-end train on 3080
1 parent 19ce840 commit 95150c3

File tree: 4 files changed, +148 −9 lines

  fastfold/data/data_modules.py
  fastfold/utils/test_utils.py
  tests/test_train.py
  train.py

fastfold/data/data_modules.py (13 additions, 5 deletions)

@@ -21,7 +21,7 @@
 
 import ml_collections as mlc
 import torch
-
+from colossalai.utils import is_using_ddp
 from fastfold.data import (
     data_pipeline,
     feature_pipeline,
@@ -384,8 +384,8 @@ def __call__(self, raw_prots):
 
 
 class OpenFoldDataLoader(torch.utils.data.DataLoader):
-    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dataset, config, stage="train", generator=None, **kwargs):
+        super().__init__(dataset, **kwargs)
         self.config = config
         self.stage = stage
 
@@ -604,28 +604,36 @@ def TrainDataLoader(
         generator = generator.manual_seed(batch_seed)
 
     train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_sampler = None
+    if is_using_ddp():
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
     train_dataset.reroll()
     train_dataloader = OpenFoldDataLoader(
-        train_dataset,
+        dataset=train_dataset,
         config=config,
         stage="train",
         generator=generator,
         batch_size=config.data_module.data_loaders.batch_size,
         num_workers=config.data_module.data_loaders.num_workers,
         collate_fn=train_batch_collator,
+        sampler=train_sampler,
     )
 
     test_dataloader = None
     if test_dataset is not None:
+        test_sampler = None
+        if is_using_ddp():
+            test_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
         test_batch_collator = OpenFoldBatchCollator(config, "test")
         test_dataloader = OpenFoldDataLoader(
-            train_dataset,
+            dataset=test_dataset,
             config=config,
             stage="test",
             generator=generator,
             batch_size=config.data_module.data_loaders.batch_size,
             num_workers=config.data_module.data_loaders.num_workers,
             collate_fn=test_batch_collator,
+            sampler=test_sampler,
         )
 
     return train_dataloader, test_dataloader
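The sampler change above follows the standard PyTorch recipe for DDP training: create a DistributedSampler only when a distributed run is active and hand it to the DataLoader, so each rank iterates over its own shard instead of a full copy of the dataset. A minimal sketch of that recipe, not taken from this commit (the toy dataset and the epoch loop are placeholders):

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_loader(dataset, batch_size):
    # Shard the dataset only when a process group is actually running;
    # single-process runs keep PyTorch's default sampler.
    sampler = None
    if dist.is_available() and dist.is_initialized():
        sampler = DistributedSampler(dataset)
    # shuffle must stay unset/False whenever an explicit sampler is supplied.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


if __name__ == "__main__":
    loader = build_loader(TensorDataset(torch.arange(16).float()), batch_size=4)
    for epoch in range(2):
        # With a DistributedSampler, set_epoch() gives each epoch a fresh shuffle.
        if isinstance(loader.sampler, DistributedSampler):
            loader.sampler.set_epoch(epoch)
        for batch in loader:
            pass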

fastfold/utils/test_utils.py (16 additions, 0 deletions)

@@ -1,5 +1,8 @@
 import os
+import random
 
+import torch
+import numpy as np
 
 def get_param_path():
     # develop
@@ -15,3 +18,16 @@ def get_data_path():
         return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
     # test
     return '/data/scratch/fastfold/mono_batch.pkl'
+
+
+def get_train_data_path():
+    return '/data/scratch/fastfold/std_train_batch.pkl'
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
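The set_seed helper added above pins every random number generator the training path touches (Python's random module, the hash seed, NumPy, PyTorch CPU and CUDA) and forces deterministic cuDNN behaviour. A small usage sketch, assuming fastfold is importable; the tensors here are placeholders and not part of this commit:

import torch

from fastfold.utils.test_utils import set_seed

set_seed(42)
a = torch.randn(4)
set_seed(42)
b = torch.randn(4)
# Re-seeding reproduces the exact same draws, which is what lets the new
# end-to-end test run OpenFold and FastFold under identical randomness.
assert torch.equal(a, b)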

tests/test_train.py (112 additions, 0 deletions)

@@ -0,0 +1,112 @@
+import os
+import pytest
+import torch
+import pickle
+import torch.multiprocessing as mp
+from functools import partial
+import colossalai
+from fastfold.model.hub import AlphaFold
+from fastfold.config import model_config
+from fastfold.model.fastnn import set_chunk_size
+from fastfold.utils.inject_fastnn import inject_fastnn
+from fastfold.utils.test_utils import get_train_data_path
+from fastfold.model.hub.loss import AlphaFoldLoss
+from fastfold.utils.tensor_utils import tensor_tree_map
+from fastfold.utils.test_utils import set_seed
+
+
+def get_param_and_grad(model):
+    params = dict()
+    grads = dict()
+    for name, param in model.named_parameters():
+        params[name] = param.detach().clone()
+        grads[name] = param.grad.detach().clone()
+
+    return params, grads
+
+
+@pytest.fixture(scope="module")
+def get_openfold_state():
+    config = model_config('initial_training', train=True)
+    config.globals.inplace = False
+    set_seed(42)
+    model = AlphaFold(config)
+    model.train().cuda()
+    criterion = AlphaFoldLoss(config.loss)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
+    batch = pickle.load(open(get_train_data_path(), 'rb'))
+    set_seed(42)
+    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+    out = model(batch)
+    batch = tensor_tree_map(lambda t: t[..., -1], batch)
+    loss, _ = criterion(out, batch, True)
+    optimizer.zero_grad()
+    set_seed(42)
+    loss.backward()
+    optimizer.step()
+    of_params, of_grads = get_param_and_grad(model)
+    return of_params, of_grads
+
+
+@pytest.mark.skipif(torch.cuda.mem_get_info(0)[1] < 4e10, reason="Not enough cuda memory")
+@pytest.mark.parametrize('world_size', [1])
+def test_state_dict(world_size, get_openfold_state):
+    run_func = partial(run_dist, world_size=world_size, model=get_openfold_state)
+    mp.spawn(run_func, nprocs=world_size)
+
+
+def run_dist(rank, world_size, model):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    colossalai.launch(config=dict(parallel=dict(tensor=dict(size=world_size))), rank=rank, world_size=world_size,
+                      host='localhost', port=10101, backend='nccl')
+    train(world_size, model)
+
+
+def train(world_size, get_openfold_state):
+
+    of_params, of_grads = get_openfold_state
+    config = model_config('initial_training', train=True)
+    config.globals.inplace = False
+    set_seed(42)
+    model = AlphaFold(config)
+    model = inject_fastnn(model)
+    model.train().cuda()
+    criterion = AlphaFoldLoss(config.loss)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
+    set_chunk_size(None)
+    batch = pickle.load(open(get_train_data_path(), 'rb'))
+    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+    set_seed(42)
+    out = model(batch)
+    batch = tensor_tree_map(lambda t: t[..., -1], batch)
+    loss, _ = criterion(out, batch, True)
+    optimizer.zero_grad()
+    set_seed(42)
+    loss.backward()
+    optimizer.step()
+    ff_params, ff_grads = get_param_and_grad(model)
+
+    params_dif = 0
+    grads_dif = 0
+    for name in ff_params.keys():
+        # the modules' names in fastfold and openfold are not equal
+        # it leads some differences on the order of the parameters
+        # it's not a hard problem to solve
+        # but check the params and grads of the same part may be just enough
+        if name not in of_params.keys():
+            continue
+
+        dif = torch.max(torch.abs(ff_params[name] - of_params[name]))
+        if dif > params_dif:
+            params_dif = dif
+        dif = torch.max(torch.abs(ff_grads[name] - of_grads[name]))
+        if dif > grads_dif:
+            grads_dif = dif
+    assert params_dif < 1e-3 and grads_dif < 5e-3, f"Test failed at world size: {world_size}, \
+        the param dif is {params_dif}, the grad diff is {grads_dif}"
+
+
+if __name__ == '__main__':
+    test_state_dict(1, None, None)
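The skipif guard on this test is the "3080" item from the commit message: torch.cuda.mem_get_info returns a (free_bytes, total_bytes) pair, so indexing [1] compares the device's total memory against 4e10 bytes (about 40 GB), and a 10 GB RTX 3080 falls below that bar. A hedged sketch of the same check as a standalone helper; the function name and defaults are illustrative, not part of this commit:

import torch


def has_enough_gpu_memory(device: int = 0, required_bytes: float = 4e10) -> bool:
    # mem_get_info returns (free_bytes, total_bytes) for the given CUDA device.
    _free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    return total_bytes >= required_bytes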

train.py (7 additions, 4 deletions)

@@ -4,7 +4,6 @@
 import numpy as np
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import HybridAdam
 
 from fastfold.config import model_config
@@ -13,7 +12,6 @@
 from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
 from fastfold.utils.tensor_utils import tensor_tree_map
 from fastfold.utils.validation_utils import compute_validation_metrics
-
 #import logging
 #logging.disable(logging.WARNING)
 import torch.multiprocessing
@@ -153,14 +151,19 @@ def main():
         "--save_ckpt_interval", type=int, default=1,
         help="The interval epochs of save checkpoint"
     )
+    parser.add_argument(
+        "--dap_size", type=int, default=1,
+        help="DAP size, recommended as 1 - nproc_per_node"
+    )
 
     args = parser.parse_args()
     random.seed(args.seed)
     np.random.seed(args.seed)
    torch.manual_seed(args.seed)
     torch.cuda.manual_seed_all(args.seed)
     if args.from_torch:
-        colossalai.launch_from_torch(config=dict(torch_ddp=dict(static_graph=True)))
+        colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(size=args.dap_size)),
+                                                 torch_ddp=dict(static_graph=True)))
     disable_existing_loggers()
     logger = get_dist_logger()
     logger.log_to_file(args.log_path)
@@ -227,7 +230,7 @@ def main():
             loss, loss_breakdown = engine.criterion(
                 output, batch, _return_breakdown=True)
             if (i+1) % args.log_interval == 0:
-                logger.info(f'Training, Epoch: {epoch}, Step: {i+1}, Global_Step: {epoch*args.train_epoch_len+i+1},' +
+                logger.info(f'Training, Epoch: {epoch}, Step: {i+1}, Global_Step: {epoch*len(train_dataloader)+i+1},' +
                             f' Loss:{log_loss(loss_breakdown, batch, output)}', ranks=[0])
             engine.zero_grad()
             engine.backward(loss)
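The new --dap_size flag wires DAP into ColossalAI's tensor-parallel group, which is why it lands under parallel.tensor.size in the launch config while torch DDP keeps handling the data-parallel dimension. A sketch of the effective config plus a hypothetical launch command; the GPU count and the torchrun invocation are assumptions, not part of this commit:

# Illustrative only: what colossalai.launch_from_torch receives when train.py
# is started with --from_torch --dap_size 2 on a hypothetical 2-GPU node.
launch_config = dict(
    parallel=dict(tensor=dict(size=2)),   # DAP group size (tensor parallelism)
    torch_ddp=dict(static_graph=True),    # passed through to torch DDP
)

# A typical launch (assumed command line, not from this commit):
#   torchrun --nproc_per_node=2 train.py --from_torch --dap_size 2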
