| | """Script to download the pre-trained tensorflow weights and convert them to pytorch weights.""" |
| | import os |
| | import argparse |
| | import torch |
| | import numpy as np |
| | from tensorflow.python.training import py_checkpoint_reader |
| |
|
| | from repnet import utils |
| | from repnet.model import RepNet |
| |
|
| |
|
| | |
# Local paths and the remote location of the pre-trained TensorFlow checkpoint.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'  # official RepNet release bucket
TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
| |
|
| | |
# Axis permutations applied to TF `kernel` tensors to obtain the PyTorch
# layout, keyed by tensor rank:
#   2 -> dense:  TF (in, out)          -> PT (out, in)
#   4 -> conv2d: TF (H, W, in, out)    -> PT (out, in, H, W)
#   5 -> conv3d: TF (D, H, W, in, out) -> PT (out, in, D, H, W)
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2),
}
| |
|
| | |
# Maps TF variable attribute names to their PyTorch ``state_dict`` equivalents.
# ``kernel`` additionally needs the axis permutation from WEIGHTS_PERMUTATION.
ATTR_MAPPING = {
    'kernel': 'weight',              # conv / dense weights
    'bias': 'bias',                  # conv / dense bias
    'beta': 'bias',                  # batch-norm shift
    'gamma': 'weight',               # batch-norm scale
    'moving_mean': 'running_mean',   # batch-norm running statistics
    'moving_variance': 'running_var',
}
| |
|
| | |
# (tf_checkpoint_path, keras_layer_name, pytorch_module_path) triples.
# The middle element is informational only (the layer's name in the original
# Keras model, None where unknown); the conversion loop below unpacks it into
# `_` and uses only the first and last elements.
# NOTE(review): the TF paths mix `layer-N` and `layer_with_weights-N` prefixes;
# presumably both resolve to the same variables in this checkpoint — confirm
# against the checkpoint's variable list before changing any entry.
WEIGHTS_MAPPING = [
    # ResNet50-v2 encoder: stem + first three stages.
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),

    # Temporal 3D conv + batch-norm, and the 3x3 conv over the self-similarity matrix.
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),

    # Period-length prediction head: input projection + transformer layer + FC stack.
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    # `w_weight`/`w_bias` are synthesized below by stacking the TF q/k/v projections.
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),

    # Periodicity (within-period) prediction head: same structure as above.
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]
| |
|
| | |
| | parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.') |
| |
|
| |
|
if __name__ == '__main__':
    parser.parse_args()  # No options defined, but still validates argv and handles `-h`.

    # Download the TF checkpoint files, skipping any that are already on disk.
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for filename in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, filename)
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{filename}', dst)

    # Read every model variable from the TF checkpoint into a nested dict that
    # mirrors the `/`-separated variable paths.
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map.keys()):
        # Filter before reading: skip non-model variables and optimizer state
        # so we don't pay for checkpoint reads we would immediately discard.
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue
        var_tensor = checkpoint_reader.get_tensor(var_name)
        # Drop the leading 'model' component and the serialization markers.
        var_path = var_name.split('/')[1:]
        var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
        # Insert the tensor at its nested position, creating levels on demand.
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor

    # TF stores multi-head attention as separate q/k/v projections; PyTorch's
    # MultiheadAttention expects them stacked into single in_proj tensors.
    for k in ['transformer_layers', 'transformer_layers2']:
        v = tf_state_dict[k]['0']['mha']
        v['w_weight'] = np.concatenate([v['wq']['kernel'].T, v['wk']['kernel'].T, v['wv']['kernel'].T], axis=0)
        v['w_bias'] = np.concatenate([v['wq']['bias'].T, v['wk']['bias'].T, v['wv']['bias'].T], axis=0)
        del v['wk'], v['wq'], v['wv']
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Normalize bare leaf tensors to `{None: tensor}` so every entry below is a
    # dict of attributes (the None key is dropped again by `skip_none`).
    for k, v in tf_state_dict.items():
        if not isinstance(v, dict):
            tf_state_dict[k] = {None: v}

    # Rename each variable and permute conv/dense kernels to PyTorch layout.
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        assert k_pt not in pt_state_dict, f'Duplicate PyTorch key in WEIGHTS_MAPPING: {k_pt}'
        pt_state_dict[k_pt] = {}
        for attr in tf_state_dict[k_tf]:
            new_attr = ATTR_MAPPING.get(attr, attr)
            pt_state_dict[k_pt][new_attr] = torch.from_numpy(tf_state_dict[k_tf][attr])
            if attr == 'kernel':
                # TF kernels are channels-last; pick the permutation by rank.
                weights_permutation = WEIGHTS_PERMUTATION[pt_state_dict[k_pt][new_attr].ndim]
                pt_state_dict[k_pt][new_attr] = pt_state_dict[k_pt][new_attr].permute(weights_permutation)
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)

    # Sanity check: the converted weights must load into the PyTorch model
    # without missing or unexpected keys (load_state_dict raises otherwise).
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path, map_location='cpu')
    model.load_state_dict(pt_state_dict)

    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')
| |
|