| import torch |
| from torch import nn |
|
|
| from uniperceiver.config import configurable |
| from ..layers.transformer_encoder_layer import TransformerEncoderLayer |
| from ..layers.transformer_encoder_moe_layer import MoETransformerEncoderLayer |
| from .build import ENCODER_REGISTRY |
| import uniperceiver.utils.comm as comm |
|
|
|
|
|
|
| __all__ = ["UnifiedBertEncoder"] |
|
|
| def _construct_attention_masks( data, sample_info, task_info): |
| mask_type = torch.bool |
| device = data.device |
|
|
| attn_mask = None |
| if isinstance(sample_info, list): |
| sample_info = sample_info[0] |
| if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False): |
|
|
| |
| spe_length, img_length, text_total_length = sample_info['data_length'] |
| text_length = text_total_length//2 |
|
|
| attn_mask = torch.ones((spe_length + img_length + text_total_length, |
| spe_length + img_length + text_total_length), dtype=mask_type, device=device) |
|
|
| attn_mask[:spe_length + img_length + text_total_length, :spe_length+img_length] = False |
| attn_mask[spe_length + img_length:spe_length + img_length + text_length, spe_length + img_length:spe_length + img_length + text_length] = torch.ones( |
| (text_length, text_length), dtype=mask_type, device=device).triu_(diagonal=1) |
| attn_mask[spe_length + img_length + text_length:, spe_length + img_length:spe_length + img_length + text_length] = torch.ones( |
| (text_length, text_length), |
| dtype=mask_type, |
| device=device).triu_(diagonal=0) |
| attn_mask[spe_length + img_length + text_length:, |
| spe_length + img_length + text_length:] = ~torch.ones( |
| (text_length), dtype=mask_type, |
| device=device).diag() |
|
|
| return attn_mask |
|
|
|
|
@ENCODER_REGISTRY.register()
class UnifiedBertEncoder(nn.Module):
    """Stack of (optionally MoE) transformer encoder layers shared across tasks.

    Layers are built in :meth:`from_config`; :meth:`forward` threads the token
    sequence through them with an optional task-specific attention mask.
    """

    @configurable
    def __init__(self, *, num_hidden_layers: int, bert_layers,
                 skip_target_encode, word_balance_losses,
                 bookswiki_word_alone, cfg):
        super(UnifiedBertEncoder, self).__init__()
        # Number of stacked transformer layers.
        self.num_hidden_layers = num_hidden_layers
        # nn.ModuleList of encoder layers (plain or MoE variants).
        self.layers = bert_layers
        # If True, 'target' sequences are returned without being encoded.
        self.skip_target_encode = skip_target_encode
        # Stored flags; not used inside this class's visible code — presumably
        # consumed elsewhere via attribute access. TODO confirm.
        self.word_balance_losses = word_balance_losses
        self.bookswiki_word_alone = bookswiki_word_alone
        self.cfg = cfg

    @staticmethod
    def _moe_flags(cfg, layer_idx):
        """Decide whether layer ``layer_idx`` gets MoE attention / FFN experts.

        Returns ``(attention_moe, ffn_moe)``.  For ``MOE_EXPERT_LOCATION ==
        'none'`` both are ``None`` (not ``False``) — preserved from the
        original code, as the MoE layer may treat the two differently.
        Raises ``NotImplementedError`` for an unknown location.
        """
        location = cfg.MOE.MOE_EXPERT_LOCATION
        if location == 'none':
            return None, None

        # NOTE(review): 'odd' historically selects layers where
        # layer_idx % 2 == 0 (i.e. even indices); behavior kept as-is.
        stride = {'odd': 2, 'four': 4, 'all': 1}.get(location)
        if stride is None:
            raise NotImplementedError('cfg.MOE.MOE_EXPERT_LOCATION')

        start = cfg.MOE.MOE_LAYER_START_IDX
        end = cfg.MOE.MOE_LAYER_END_IDX
        if layer_idx % stride == 0 and start <= layer_idx < end:
            expert_types = cfg.MOE.MOE_EXPERT_TYPE.split(',')
            return 'SA' in expert_types, 'FFN' in expert_types
        return False, False

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor kwargs (layer stack included) from ``cfg``."""
        num_layers = cfg.MODEL.BERT.NUM_HIDDEN_LAYERS

        # Per-layer drop-path ratios: either constant or linearly increasing.
        if cfg.MODEL.BERT.DROP_PATH_PROB_FIXED:
            dpr = [cfg.MODEL.BERT.DROP_PATH_PROB for _ in range(num_layers)]
        else:
            dpr = [x.item() for x in torch.linspace(0, cfg.MODEL.BERT.DROP_PATH_PROB, num_layers)]

        layers = []
        for layer_idx in range(num_layers):
            common = dict(
                d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
                nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
                dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
                dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
                drop_path_ratio=dpr[layer_idx],
                activation=cfg.MODEL.BERT.HIDDEN_ACT,
                layer_scale=cfg.MODEL.LAYER_SCALE,
                ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
                norm_first=True,
                cfg=cfg,
            )
            if not cfg.MOE.MOE:
                layers.append(TransformerEncoderLayer(batch_first=True, **common))
            else:
                attention_moe, ffn_moe = cls._moe_flags(cfg, layer_idx)
                # NOTE(review): MoE layers use batch_first=False while plain
                # layers use True — kept as in the original configuration.
                layers.append(
                    MoETransformerEncoderLayer(
                        batch_first=False,
                        ffn_moe=ffn_moe,
                        attn_moe=attention_moe,
                        **common,
                    ))

        bert_layers = nn.ModuleList(layers)
        return {
            "num_hidden_layers": num_layers,
            "skip_target_encode": cfg.MODEL.BERT.SKIP_TARGET_ENCODE,
            "bert_layers": bert_layers,
            # 'LOSSESS' spelling is the actual config key (sic).
            "word_balance_losses": cfg.SOLVER.WORD_BALANCE_LOSSESS,
            "bookswiki_word_alone": cfg.MODEL.BW_WORD_ALONE,
            "cfg": cfg
        }

    @classmethod
    def add_config(cls, cfg):
        # No encoder-specific config keys to register.
        pass

    def forward(self, data, invalid_mask, sample_info, task_info, history_states=None, return_all=False, **kwargs):
        """Encode ``data`` through every layer.

        Args:
            data: token embedding sequence (layout depends on the layer's
                ``batch_first`` setting — not fixed here).
            invalid_mask: key-padding mask forwarded to each layer.
            sample_info / task_info: drive attention-mask construction and are
                forwarded to each layer.
            history_states: optional per-layer cached states; indexed by layer.
            return_all: if True, return the input plus every layer's output as
                a list instead of only the final output.
        """
        attn_mask = _construct_attention_masks(data, sample_info, task_info)
        kwargs.update({'sample_info': sample_info})

        data_type = kwargs.get('data_type', 'input')
        if data_type == 'target' and self.skip_target_encode:
            # Target sequences can bypass encoding entirely when configured.
            return data

        if return_all:
            data_all = [data]
        for l, layer_module in enumerate(self.layers):
            if history_states is None:
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, task_info=task_info, **kwargs)
            else:
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, history_states=history_states[l], task_info=task_info, **kwargs)
            if return_all:
                data_all.append(data)

        return data if not return_all else data_all
|
|