import argparse
import os

import torch
import torch.nn as nn

from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer
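# Example launch (assuming this file is saved as perf.py; adjust the script name
# and GPU count to your setup):
#   torchrun --nproc_per_node=2 perf.py --dap-size 2
# The script reads WORLD_SIZE and LOCAL_RANK from the environment, so it must be
# started through a distributed launcher even for a single-GPU run.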


def main():
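    # Benchmark configuration: problem sizes, trial counts, and switches for the
    # OpenFold reference implementation, forward-only runs, and profiling.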
    parser = argparse.ArgumentParser(description='MSA Attention Standalone Perf Benchmark')
    parser.add_argument("--dap-size", default=1, type=int)
    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--msa-length', default=132, type=int, help='Number of Sequences in the MSA')
    parser.add_argument('--res-length',
                        default=256,
                        type=int,
                        help='Number of Residues (Sequence Length)')
    parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
    parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
    parser.add_argument('--layers',
                        default=12,
                        type=int,
                        help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
    parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
    parser.add_argument('--openfold',
                        action='store_true',
                        help='Use the OpenFold EvoformerBlock implementation.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
    parser.add_argument('--prof', action='store_true', help='Profile with torch.profiler.')

    args = parser.parse_args()

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.local_rank = int(os.environ['LOCAL_RANK'])

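    # One process per GPU: bind this process to its local device and join the
    # NCCL process group before any collective calls.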
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    args.global_rank = torch.distributed.get_rank()
    print(
        'Benchmarking in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
        % (args.global_rank, args.world_size))
    init_shadowcore(args.dap_size)

    precision = torch.bfloat16
    if args.dap_size > 1:
        # (PyTorch issue) All-to-All communication does not yet support bfloat16,
        # so fall back to float16 when DAP is enabled.
        precision = torch.float16

    if not torch.cuda.is_available():
        raise NotImplementedError('Running on CPU is not supported')

    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

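    # Optional reference path: wrap OpenFold's EvoformerBlock behind the same
    # (node, pair, node_mask, pair_mask) interface as the FastFold Evoformer.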
    if args.openfold:
        from openfold.model.evoformer import EvoformerBlock

        class OpenFoldEvoformer(nn.Module):

            def __init__(self, d_node, d_pair):
                super(OpenFoldEvoformer, self).__init__()
                self.d_node = d_node
                self.d_pair = d_pair

                self.c_hidden_msa_att = int(d_node / 8)
                self.c_hidden_pair_att = int(d_pair / 8)

                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
                                                     c_z=d_pair,
                                                     c_hidden_msa_att=self.c_hidden_msa_att,
                                                     c_hidden_opm=self.c_hidden_msa_att,
                                                     c_hidden_mul=self.d_pair,
                                                     c_hidden_pair_att=self.c_hidden_pair_att,
                                                     no_heads_msa=8,
                                                     no_heads_pair=4,
                                                     transition_n=4,
                                                     msa_dropout=0.15,
                                                     pair_dropout=0.25,
                                                     inf=1e9,
                                                     eps=1e-10)

            def forward(self, node, pair, node_mask, pair_mask):
                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
                return node, pair

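    # Build a stack of identical Evoformer layers (so kernel launches can overlap
    # with CPU-side work) and move each layer to the GPU in the benchmark precision.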
    attn_layers = []
    for idx in range(0, args.layers):
        if args.openfold:
            attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
        else:
            attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
        attn_layers[idx].cuda()
        attn_layers[idx].to(dtype=precision)

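    # CUDA events with timing enabled bracket the forward and backward passes so
    # each measured trial can be timed on the device.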
    start_evt_fwd = []
    start_evt_bwd = []
    stop_evt_bwd = []
    for recorded_trial in range(0, args.trials):
        start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
        start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
        stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

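    # Random MSA and pair activations. Each rank holds only its shard, so the
    # MSA-depth axis (node) and one residue axis (pair) are divided by the DAP size.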
    inputs_node = torch.randn(args.batch_size,
                              args.msa_length // args.dap_size,
                              args.res_length,
                              args.cm,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    inputs_pair = torch.randn(args.batch_size,
                              args.res_length // args.dap_size,
                              args.res_length,
                              args.cz,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    # Gradient fed into backward(); it matches the pair output, which is what the
    # backward pass below is driven from.
    grads_pair = torch.randn_like(inputs_pair)

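    # Optional torch.profiler session: the schedule waits one step, warms up over
    # the warmup trials, records the measured trials, and writes a TensorBoard
    # trace to ./log/fastfold.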
    if args.prof:
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1,
                                             warmup=args.warmup_trials,
                                             active=args.trials,
                                             repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
            profile_memory=False,
            record_shapes=False,
            with_stack=False)
        prof.start()

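    # Main benchmark loop: warmup trials run first and are excluded from timing;
    # every trial synchronizes all ranks and the GPU before recording events.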
    for trial in range(0, args.trials + args.warmup_trials):
        layer_inputs = inputs_node, inputs_pair
        evt_idx = trial - args.warmup_trials

        torch.distributed.barrier()
        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_fwd[evt_idx].record()

        for lyr_idx in range(0, args.layers):
            layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)

        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            layer_inputs[1].backward(grads_pair)

        if evt_idx >= 0:
            stop_evt_bwd[evt_idx].record()

        if args.prof:
            prof.step()

    if args.prof:
        prof.stop()

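    # After a final barrier and device synchronization, sum the per-trial forward
    # and backward times (in milliseconds) from the recorded CUDA events.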
    torch.distributed.barrier()
    torch.cuda.synchronize()
    elapsed_time_fwd = 0.0
    elapsed_time_bwd = 0.0
    for evt_idx in range(0, args.trials):
        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

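    # Report the average time per layer by dividing the totals by trials * layers.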
    print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
        args.batch_size, args.msa_length, args.res_length,
        args.cm, args.cz,
        elapsed_time_fwd / (args.trials * args.layers),
        elapsed_time_bwd / (args.trials * args.layers)))


if __name__ == '__main__':
    main()