| from dataclasses import dataclass
|
| from typing import Optional, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import AutoConfig, AutoModel, PreTrainedModel
|
| from transformers.modeling_outputs import ModelOutput
|
|
|
| from .configuration_suave_multitask import SuaveMultitaskConfig
|
|
|
|
|
@dataclass

class SuaveMultitaskOutput(ModelOutput):
    """Output container for `SuaveMultitaskModel`.

    Subclasses `transformers.ModelOutput`, so it behaves both as a dataclass
    and as an ordered mapping; field declaration order is significant for
    tuple conversion and must not change.
    """

    # Combined training loss (binary CE + 0.5 * multiclass CE); None when
    # labels were not provided to forward().
    loss: Optional[torch.FloatTensor] = None

    # Unnormalized scores for the 2-way (binary) head, shape (batch, 2).
    logits_binary: Optional[torch.FloatTensor] = None

    # Unnormalized scores for the multiclass head,
    # shape (batch, config.num_ai_classes).
    logits_multiclass: Optional[torch.FloatTensor] = None

    # Encoder hidden states per layer; populated only when
    # output_hidden_states=True was requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    # Encoder attention maps per layer; populated only when
    # output_attentions=True was requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
|
class SuaveMultitaskModel(PreTrainedModel):
    """Dual-head sequence classifier on top of a pretrained encoder.

    Wraps an arbitrary `AutoModel` backbone (chosen via
    `config.base_model_name`) and attaches two linear heads to the CLS
    (first-token) representation:

    * a 2-way binary head (`classifier_binary`), and
    * an N-way head (`classifier_multiclass`, N = `config.num_ai_classes`).
    """

    config_class = SuaveMultitaskConfig
    base_model_prefix = "encoder"

    def __init__(self, config: SuaveMultitaskConfig):
        super().__init__(config)
        # Build the backbone from its *config only* (random init); actual
        # weights arrive via from_pretrained()/load_state_dict on the wrapper.
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_config(base_config)
        hidden_size = self.encoder.config.hidden_size

        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier_binary = nn.Linear(hidden_size, 2)
        self.classifier_multiclass = nn.Linear(hidden_size, config.num_ai_classes)

        # HF hook: initializes the freshly added head weights.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels_binary=None,
        labels_multiclass=None,
        **kwargs,
    ):
        """Run the encoder and both classification heads.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Padding mask, shape (batch, seq_len).
            labels_binary: Optional int labels in {0, 1} for the binary head.
            labels_multiclass: Optional int labels for the multiclass head;
                positions with label -1 are ignored in the loss.
            **kwargs: May carry `output_hidden_states` / `output_attentions`
                flags, forwarded to the encoder.

        Returns:
            SuaveMultitaskOutput with logits for both heads and, when at
            least one label tensor is supplied, the (partial or combined)
            loss: binary CE + 0.5 * multiclass CE.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=kwargs.get("output_hidden_states", False),
            output_attentions=kwargs.get("output_attentions", False),
        )

        # CLS-token pooling: first position of the final hidden layer.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)

        logits_binary = self.classifier_binary(pooled)
        logits_multiclass = self.classifier_multiclass(pooled)

        # Compute each loss term independently so that batches carrying only
        # one kind of label still produce a usable loss. (Previously the loss
        # was silently None unless BOTH label tensors were supplied.) When
        # both are present the result is identical to the original:
        # loss_binary + 0.5 * loss_multiclass.
        loss = None
        if labels_binary is not None:
            loss = nn.CrossEntropyLoss()(logits_binary, labels_binary)
        if labels_multiclass is not None:
            loss_multiclass = 0.5 * nn.CrossEntropyLoss(ignore_index=-1)(
                logits_multiclass, labels_multiclass
            )
            loss = loss_multiclass if loss is None else loss + loss_multiclass

        return SuaveMultitaskOutput(
            loss=loss,
            logits_binary=logits_binary,
            logits_multiclass=logits_multiclass,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|