# Irwiny123 -- commit eb6d243: "Submit initial Uni-Core code"
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from unicore import utils
from unicore.models import BaseUnicoreModel, register_model, register_model_architecture
from unicore.modules import LayerNorm, TransformerEncoder, init_bert_params
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
@register_model("bert")
class BertModel(BaseUnicoreModel):
    """BERT-style Transformer encoder with a tied masked-LM head and
    optional, named sentence-classification heads.

    Relative-position settings of the encoder (``rel_pos``, ``rel_pos_bins``,
    ``max_rel_pos``) are fixed in ``__init__`` rather than exposed as
    command-line arguments.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        import argparse

        def _str2bool(value):
            # ``type=bool`` would map every non-empty string -- including
            # "False" and "0" -- to True, so parse the text explicitly.
            text = str(value).strip().lower()
            if text in ("1", "true", "t", "yes", "y"):
                return True
            if text in ("0", "false", "f", "no", "n"):
                return False
            raise argparse.ArgumentTypeError(
                "expected a boolean value, got {!r}".format(value)
            )

        parser.add_argument(
            "--encoder-layers", type=int, metavar="L", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="H",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="F",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="A",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--emb-dropout", type=float, metavar="D", help="dropout probability for embeddings"
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--max-seq-len", type=int, help="number of positional embeddings to learn"
        )
        parser.add_argument(
            "--post-ln", type=_str2bool, help="use post layernorm or pre layernorm"
        )

    def __init__(self, args, dictionary):
        """Build embeddings, encoder and LM head from ``args``.

        Args:
            args: hyper-parameter namespace; missing fields are filled in
                by ``base_architecture`` before use.
            dictionary: vocabulary object providing ``pad()`` and ``len()``.
        """
        super().__init__()
        # Fill in defaults for any hyper-parameter not already on ``args``.
        base_architecture(args)
        self.args = args
        self.padding_idx = dictionary.pad()
        self.embed_tokens = nn.Embedding(
            len(dictionary), args.encoder_embed_dim, self.padding_idx
        )
        # One learned vector per absolute position, up to ``max_seq_len``.
        self.embed_positions = nn.Embedding(args.max_seq_len, args.encoder_embed_dim)
        self.sentence_encoder = TransformerEncoder(
            encoder_layers=args.encoder_layers,
            embed_dim=args.encoder_embed_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            emb_dropout=args.emb_dropout,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            max_seq_len=args.max_seq_len,
            activation_fn=args.activation_fn,
            rel_pos=True,
            rel_pos_bins=32,
            max_rel_pos=128,
            post_ln=args.post_ln,
        )
        # The vocabulary projection shares its weight with the token embedding.
        self.lm_head = BertLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=self.embed_tokens.weight,
        )
        self.classification_heads = nn.ModuleDict()
        # Re-initialise every submodule (including the tied embedding).
        self.apply(init_bert_params)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        return cls(args, task.dictionary)

    def forward(
        self,
        src_tokens,
        masked_tokens=None,
        features_only=False,
        classification_head_name=None,
        **kwargs
    ):
        """Encode ``src_tokens`` and apply the requested head.

        Args:
            src_tokens: token-index tensor whose dimension 1 is the sequence
                length (presumably ``(batch, seq_len)`` -- confirm with callers).
            masked_tokens: optional mask forwarded to the LM head to select
                which positions are projected to the vocabulary; ``None``
                projects every position.
            features_only: if True, skip the LM head and return encoder features.
            classification_head_name: name of a registered classification head.
                When given, the LM head is skipped and that head's output is
                returned.

        Returns:
            LM logits, encoder features, or classification logits, depending
            on the flags above.
        """
        if classification_head_name is not None:
            features_only = True
        padding_mask = src_tokens.eq(self.padding_idx)
        # Pass None when nothing is padded (presumably lets the encoder skip masking).
        if not padding_mask.any():
            padding_mask = None
        x = self.embed_tokens(src_tokens)
        # Add the learned embeddings for positions 0 .. seq_len-1.
        x += self.embed_positions.weight[:src_tokens.size(1), :]
        x = self.sentence_encoder(x, padding_mask=padding_mask)
        if not features_only:
            x = self.lm_head(x, masked_tokens)
        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register (or replace) a classification head under ``name``.

        ``inner_dim`` defaults to ``args.encoder_embed_dim``. Re-registering an
        existing name replaces the head. A warning is logged when the new size
        differs from the old one.
        """
        # Resolve the default first so re-registering a head with its default
        # inner size is not reported as a size change.
        inner_dim = inner_dim or self.args.encoder_embed_dim
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = BertClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
        )
class BertLMHead(nn.Module):
    """Masked-language-modelling head.

    Applies dense -> activation -> layer norm to the encoder features, then
    projects onto the output vocabulary with ``self.weight`` plus a learned
    bias. ``self.weight`` is either a caller-supplied tensor (e.g. a tied
    embedding matrix) or a freshly created ``(output_dim, embed_dim)`` weight.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)
        # Use the supplied weight when given; otherwise borrow a fresh one
        # from a bias-free Linear layer of the right shape.
        self.weight = (
            weight
            if weight is not None
            else nn.Linear(embed_dim, output_dim, bias=False).weight
        )
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # When a mask is given, only those positions are projected to the
        # vocabulary, which saves memory and compute during training.
        selected = features if masked_tokens is None else features[masked_tokens, :]
        hidden = self.layer_norm(self.activation_fn(self.dense(selected)))
        return F.linear(hidden, self.weight) + self.bias
class BertClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Takes the features at position 0 of each sequence and maps them through
    dropout -> dense -> activation -> dropout -> linear to class logits.
    """

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        """
        Args:
            input_dim: size of the last dimension of the incoming features.
            inner_dim: output width of the hidden dense layer.
            num_classes: number of output logits.
            activation_fn: activation name resolved via
                ``utils.get_activation_fn``, or a callable used as-is.
            pooler_dropout: dropout probability, applied both before and
                after the dense layer.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = (
            activation_fn
            if callable(activation_fn)
            else utils.get_activation_fn(activation_fn)
        )
        self.dropout = nn.Dropout(p=pooler_dropout)
        # Must be a single nn.Linear: forward() calls it and
        # register_classification_head reads its ``out_features``.
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        # assumes features is (batch, seq_len, input_dim) -- TODO confirm;
        # position 0 is taken as the sentence representation (<s> / [CLS]).
        x = features[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
@register_model_architecture("bert", "bert")
def base_architecture(args):
    """Fill BERT-base defaults into ``args``.

    Any attribute already present on ``args`` keeps its current value. Only
    missing attributes receive the default listed below.
    """
    defaults = (
        ("encoder_layers", 12),
        ("encoder_embed_dim", 768),
        ("encoder_ffn_embed_dim", 3072),
        ("encoder_attention_heads", 12),
        ("dropout", 0.1),
        ("emb_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.0),
        ("pooler_dropout", 0.0),
        ("max_seq_len", 512),
        ("activation_fn", "gelu"),
        ("pooler_activation_fn", "tanh"),
        ("post_ln", True),
    )
    for attr_name, default_value in defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))
@register_model_architecture("bert", "bert_base")
def bert_base_architecture(args):
    """BERT-base: exactly the defaults applied by ``base_architecture``."""
    base_architecture(args)
@register_model_architecture("bert", "bert_large")
def bert_large_architecture(args):
    """BERT-large defaults: 24 layers, 1024-d embeddings, 4096-d FFN, 16 heads.

    Only attributes missing from ``args`` receive these values. Everything
    else is then filled in by ``base_architecture``.
    """
    for attr_name, default_value in (
        ("encoder_layers", 24),
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
    ):
        setattr(args, attr_name, getattr(args, attr_name, default_value))
    base_architecture(args)
@register_model_architecture("bert", "xlm")
def xlm_architecture(args):
    """XLM-sized defaults: 16 layers, 1280-d embeddings, 4x FFN, 16 heads.

    Only attributes missing from ``args`` receive these values. Everything
    else is then filled in by ``base_architecture``.
    """
    embed_dim = 1280
    for attr_name, default_value in (
        ("encoder_layers", 16),
        ("encoder_embed_dim", embed_dim),
        ("encoder_ffn_embed_dim", embed_dim * 4),
        ("encoder_attention_heads", 16),
    ):
        setattr(args, attr_name, getattr(args, attr_name, default_value))
    base_architecture(args)