import os
import pytest
import torch
import pickle
import torch.multiprocessing as mp
from functools import partial

import colossalai

from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.test_utils import get_train_data_path
from fastfold.model.hub.loss import AlphaFoldLoss
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import set_seed

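# Snapshot a detached copy of every parameter and its gradient so two training runs
# can be compared element-wise afterwards.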
def get_param_and_grad(model):
    params = dict()
    grads = dict()
    for name, param in model.named_parameters():
        params[name] = param.detach().clone()
        grads[name] = param.grad.detach().clone()

    return params, grads

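# Reference run: one forward/backward/optimizer step with the unmodified (OpenFold-style)
# AlphaFold model, returning its parameters and gradients as the baseline.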
@pytest.fixture(scope="module")
def get_openfold_state():
    config = model_config('initial_training', train=True)
    config.globals.inplace = False
    set_seed(42)
    model = AlphaFold(config)
    model.train().cuda()
    criterion = AlphaFoldLoss(config.loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
    batch = pickle.load(open(get_train_data_path(), 'rb'))
    set_seed(42)
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    out = model(batch)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss, _ = criterion(out, batch, True)
    optimizer.zero_grad()
    set_seed(42)
    loss.backward()
    optimizer.step()
    of_params, of_grads = get_param_and_grad(model)
    return of_params, of_grads

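# Spawn `world_size` processes and check that the FastFold model reproduces the
# reference parameters and gradients after a single training step.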
@pytest.mark.skipif(torch.cuda.mem_get_info(0)[1] < 4e10, reason="Not enough cuda memory")
@pytest.mark.parametrize('world_size', [1])
def test_state_dict(world_size, get_openfold_state):
    run_func = partial(run_dist, world_size=world_size, model=get_openfold_state)
    mp.spawn(run_func, nprocs=world_size)

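# Per-process entry point: set up the distributed environment variables, launch
# colossalai, then run the FastFold training step and compare it against the reference.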
def run_dist(rank, world_size, model):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    colossalai.launch(config=dict(parallel=dict(tensor=dict(size=world_size))),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=10101,
                      backend='nccl')
    train(world_size, model)

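# Repeat the same training step with the FastFold (fastnn-injected) model and assert
# that its parameters and gradients stay close to the OpenFold reference.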
def train(world_size, get_openfold_state):

    of_params, of_grads = get_openfold_state
    config = model_config('initial_training', train=True)
    config.globals.inplace = False
    set_seed(42)
    model = AlphaFold(config)
    model = inject_fastnn(model)
    model.train().cuda()
    criterion = AlphaFoldLoss(config.loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
    # disable chunking so the FastFold forward pass matches the reference computation
    set_chunk_size(None)
    batch = pickle.load(open(get_train_data_path(), 'rb'))
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    set_seed(42)
    out = model(batch)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss, _ = criterion(out, batch, True)
    optimizer.zero_grad()
    set_seed(42)
    loss.backward()
    optimizer.step()
    ff_params, ff_grads = get_param_and_grad(model)

    params_dif = 0
    grads_dif = 0
    for name in ff_params.keys():
        # The module names in FastFold and OpenFold do not match one-to-one, which
        # changes the naming/order of some parameters. That could be reconciled, but
        # comparing only the parameters and gradients present in both models is enough
        # for this check.
        if name not in of_params.keys():
            continue

        dif = torch.max(torch.abs(ff_params[name] - of_params[name]))
        if dif > params_dif:
            params_dif = dif
        dif = torch.max(torch.abs(ff_grads[name] - of_grads[name]))
        if dif > grads_dif:
            grads_dif = dif
    assert params_dif < 1e-3 and grads_dif < 5e-3, \
        f"Test failed at world size {world_size}: the param dif is {params_dif}, the grad dif is {grads_dif}"

if __name__ == '__main__':
    # Note: the reference state normally comes from the pytest fixture, which is not
    # available when this file is run directly.
    test_state_dict(1, None)