diff --git a/fastfold/model/fastnn/embedders.py b/fastfold/model/fastnn/embedders.py index db91f04b..c31888a4 100644 --- a/fastfold/model/fastnn/embedders.py +++ b/fastfold/model/fastnn/embedders.py @@ -185,7 +185,8 @@ def forward(self, template_angle_feat = build_template_angle_feat( single_template_feats, ) - + template_angle_feat = template_angle_feat.to(dtype=z.dtype) + # [*, S_t, N, C_m] a = self.template_angle_embedder(template_angle_feat)