from enum import Enum
from dataclasses import dataclass
from functools import partial
from typing import Union, List

import numpy as np
import torch

# Key prefix under which all weights are stored in the DeepMind NPZ checkpoints
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"


class ParamType(Enum):
    # Each member carries the transformation that maps a JAX/Haiku weight
    # array onto the layout of the corresponding PyTorch parameter. The
    # partial() wrappers prevent the lambdas from being bound as methods
    # of the Enum class.
    LinearWeight = partial(
        lambda w: w.transpose(-1, -2)
    )
    LinearWeightMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
    )
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn


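# Why the transpose: Haiku's Linear stores its kernel as (in_features,
# out_features), whereas torch.nn.Linear.weight is (out_features,
# in_features). A minimal sketch (shapes are arbitrary):
#
#   >>> w_jax = torch.randn(32, 64)  # Haiku kernel: (in, out)
#   >>> ParamType.LinearWeight.transformation(w_jax).shape
#   torch.Size([64, 32])

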
@dataclass
class Param:
    """A reference to one model tensor (or, if stacked, one tensor per
    block) plus the transformation needed to fill it from an NPZ array."""

    param: Union[torch.Tensor, List[torch.Tensor]]
    param_type: ParamType = ParamType.Other
    stacked: bool = False


def _process_translations_dict(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if type(v) is dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            # Leaves get a leading "/" so that joined keys contain the
            # double slash separating a module path from a parameter name,
            # matching the NPZ key format.
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat


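# Illustrative flattening of a toy translations dict:
#
#   >>> _process_translations_dict({"evoformer": {"preprocess_1d": {"weights": 0}}})
#   {'alphafold/alphafold_iteration/evoformer/preprocess_1d//weights': 0}

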
def stacked(param_dict_list, out=None):
    """
    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )

            out[k] = stacked_param

    return out


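# The checkpoints store repeated blocks (e.g. the Evoformer layers) as single
# arrays with a leading block dimension; stacked() merges the per-block Param
# dicts to match, and assign() later unbinds that dimension. A toy example:
#
#   >>> blocks = [{"w": Param(torch.zeros(2))}, {"w": Param(torch.ones(2))}]
#   >>> stacked(blocks)["w"].stacked
#   True

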
def assign(translation_dict, orig_weights):
    for k, param in translation_dict.items():
        with torch.no_grad():
            weights = torch.as_tensor(orig_weights[k])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                # Stacked params carry a leading block dimension; split it
                # into one tensor per block
                weights = torch.unbind(weights, 0)
            else:
                weights = [weights]
                ref = [ref]

            try:
                weights = list(map(param_type.transformation, weights))
                for p, w in zip(ref, weights):
                    p.copy_(w)
            except Exception:
                # Report the offending key and shapes before re-raising
                print(k)
                print(ref[0].shape)
                print(weights[0].shape)
                raise


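# Toy example of the translation/assign mechanism (hypothetical key "toy"):
#
#   >>> lin = torch.nn.Linear(4, 8, bias=False)
#   >>> translation = {"toy//weights": Param(lin.weight, ParamType.LinearWeight)}
#   >>> assign(translation, {"toy//weights": np.zeros((4, 8), dtype=np.float32)})
#   >>> bool((lin.weight == 0).all())
#   True

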
def import_jax_weights_(model, npz_path, version="model_1"):
    """Load DeepMind's AlphaFold weights from the NPZ at npz_path into
    model, in place."""
    data = np.load(npz_path)

    # Translation builders for common parameter groups
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
        "bias": LinearBias(l.bias),
    }

    LayerNormParams = lambda l: {
        "scale": Param(l.weight),
        "offset": Param(l.bias),
    }

    AttentionParams = lambda att: {
        "query_w": LinearWeightMHA(att.linear_q.weight),
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

    AttentionGatedParams = lambda att: dict(
        **AttentionParams(att),
        **{
            "gating_w": LinearWeightMHA(att.linear_g.weight),
            "gating_b": LinearBiasMHA(att.linear_g.bias),
        },
    )

    GlobalAttentionParams = lambda att: dict(
        AttentionGatedParams(att),
        key_w=LinearWeight(att.linear_k.weight),
        value_w=LinearWeight(att.linear_v.weight),
    )

    TriAttParams = lambda tri_att: {
        "query_norm": LayerNormParams(tri_att.layer_norm),
        "feat_2d_weights": LinearWeight(tri_att.linear.weight),
        "attention": AttentionGatedParams(tri_att.mha),
    }

    TriMulOutParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_a_p),
        "right_projection": LinearParams(tri_mul.linear_b_p),
        "left_gate": LinearParams(tri_mul.linear_a_g),
        "right_gate": LinearParams(tri_mul.linear_b_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    # Note the left/right reversal relative to TriMulOutParams above:
    # linear_b_* and linear_a_* swap roles for the incoming update,
    # matching how the incoming-multiplication weights are keyed
    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_b_p),
        "right_projection": LinearParams(tri_mul.linear_a_p),
        "left_gate": LinearParams(tri_mul.linear_b_g),
        "right_gate": LinearParams(tri_mul.linear_a_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    PairTransitionParams = lambda pt: {
        "input_layer_norm": LayerNormParams(pt.layer_norm),
        "transition1": LinearParams(pt.linear_1),
        "transition2": LinearParams(pt.linear_2),
    }

    MSAAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": AttentionGatedParams(matt.mha),
    }

    MSAColAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
        "attention": AttentionGatedParams(matt._msa_att.mha),
    }

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
        **MSAAttParams(matt),
        **{
            "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
            "feat_2d_weights": LinearWeight(matt.linear_z.weight),
        },
    )

    IPAParams = lambda ipa: {
        "q_scalar": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    TemplatePairBlockParams = lambda b: {
        "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
        "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
        "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
        "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
        "pair_transition": PairTransitionParams(b.pair_transition),
    }

    MSATransitionParams = lambda m: {
        "input_layer_norm": LayerNormParams(m.layer_norm),
        "transition1": LinearParams(m.linear_1),
        "transition2": LinearParams(m.linear_2),
    }

    OuterProductMeanParams = lambda o: {
        "layer_norm_input": LayerNormParams(o.layer_norm),
        "left_projection": LinearParams(o.linear_1),
        "right_projection": LinearParams(o.linear_2),
        "output_w": LinearWeightOPM(o.linear_out.weight),
        "output_b": LinearBias(o.linear_out.bias),
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
            col_att_name = "msa_column_attention"
            msa_col_att_params = MSAColAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(
                b.msa_att_row
            ),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.core.msa_transition),
            "outer_product_mean": OuterProductMeanParams(
                b.core.outer_product_mean
            ),
            "triangle_multiplication_outgoing": TriMulOutParams(
                b.core.tri_mul_out
            ),
            "triangle_multiplication_incoming": TriMulInParams(
                b.core.tri_mul_in
            ),
            "triangle_attention_starting_node": TriAttParams(
                b.core.tri_att_start
            ),
            "triangle_attention_ending_node": TriAttParams(
                b.core.tri_att_end
            ),
            "pair_transition": PairTransitionParams(b.core.pair_transition),
        }

        return d

    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

    FoldIterationParams = lambda sm: {
        "invariant_point_attention": IPAParams(sm.ipa),
        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
        "transition": LinearParams(sm.transition.layers[0].linear_1),
        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
        "affine_update": LinearParams(sm.bb_update.linear),
        "rigid_sidechain": {
            "input_projection": LinearParams(sm.angle_resnet.linear_in),
            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }

    # Stack the per-block Param dicts of each repeated stack so they line
    # up with the checkpoint's stacked block arrays
    tps_blocks = model.template_pair_stack.blocks
    tps_blocks_params = stacked(
        [TemplatePairBlockParams(b) for b in tps_blocks]
    )

    ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_z
            ),
            # sic: "activiations" matches the misspelled key in the
            # original AlphaFold checkpoints
            "pair_activiations": LinearParams(
                model.input_embedder.linear_relpos
            ),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(
                        model.template_pair_embedder.linear
                    ),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(
                model.extra_msa_embedder.linear
            ),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(
                model.template_angle_embedder.linear_2
            ),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_s
            ),
            "initial_projection": LinearParams(
                model.structure_module.linear_in
            ),
            "pair_layer_norm": LayerNormParams(
                model.structure_module.layer_norm_z
            ),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(
                model.aux_heads.plddt.layer_norm
            ),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(
                model.aux_heads.experimentally_resolved.linear
            ),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    # Model versions that were trained without the template stack ship no
    # template weights, so drop those entries from the translations
    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    # pTM models additionally include the predicted aligned error head
    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    # Flatten the nested translations dict into NPZ-style keys
    flat = _process_translations_dict(translations)

    # Sanity check: every translated key must exist in the checkpoint.
    # Keys present in the checkpoint but absent from the translations
    # (e.g. heads this model does not use) are tolerated.
    keys = list(data.keys())
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]

    assert len(incorrect) == 0, f"Unmapped keys: {incorrect}"

    # Copy the weights into the model
    assign(flat, data)
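
# Example usage (illustrative; assumes the OpenFold package layout and a
# downloaded AlphaFold parameter file params_model_1.npz):
#
#   from openfold.config import model_config
#   from openfold.model.model import AlphaFold
#
#   config = model_config("model_1")
#   model = AlphaFold(config)
#   import_jax_weights_(model, "params_model_1.npz", version="model_1")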