| from transformers import PretrainedConfig, AutoConfig
|
|
|
class MultiTaskConfig(PretrainedConfig):
    """Configuration carrying a base checkpoint name and two label counts.

    Attributes:
        base_model_name: Identifier handed to ``AutoConfig.from_pretrained``.
        num_labels_type: Number of classes for the "type" task.
        num_labels_priorite: Number of classes for the "priorite" task.
        base_config: Configuration object returned by
            ``AutoConfig.from_pretrained(base_model_name)``.
    """

    model_type = "barthez-multitask"

    def __init__(
        self,
        base_model_name: str = "moussaKam/barthez",
        num_labels_type: int = 3,
        num_labels_priorite: int = 3,
        **kwargs,
    ) -> None:
        """Store the task settings and load the base model configuration.

        Args:
            base_model_name: Checkpoint name or path whose configuration is
                loaded into ``base_config``.
            num_labels_type: Number of classes for the "type" task.
            num_labels_priorite: Number of classes for the "priorite" task.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # The parent constructor runs first, so the assignments below win
        # over any same-named entry that arrived through **kwargs.
        super().__init__(**kwargs)

        self.num_labels_type = num_labels_type
        self.num_labels_priorite = num_labels_priorite
        self.base_model_name = base_model_name

        # NOTE(review): this lookup runs on every instantiation and
        # presumably needs the checkpoint in the local cache or reachable
        # on the Hub - confirm this is acceptable for offline use.
        # NOTE(review): base_config is a nested config object; verify that
        # the installed transformers version can serialize it when this
        # configuration is saved.
        self.base_config = AutoConfig.from_pretrained(base_model_name)
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import BartPreTrainedModel, BartModel
|
| from dataclasses import dataclass
|
| from transformers.modeling_outputs import ModelOutput
|
| from typing import Optional
|
|
|
|
|
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container for the two-task classifier.

    Every field defaults to ``None``, so each one is annotated as optional.
    """

    # Combined training loss, or None when it was not computed.
    loss: Optional[torch.FloatTensor] = None
    # Raw scores produced by the "type" classification head.
    logits_type: Optional[torch.FloatTensor] = None
    # Raw scores produced by the "priorite" classification head.
    logits_priorite: Optional[torch.FloatTensor] = None
|
|
|
|
|
class MultiTaskModel(BartPreTrainedModel):
    """``BartModel`` backbone followed by two linear classification heads.

    The backbone's ``last_hidden_state`` is averaged over dimension 1 and the
    resulting vector is fed to a "type" head and a "priorite" head.
    """

    def __init__(self, config):
        """Create the backbone, both heads and the shared loss function.

        Args:
            config: Configuration object. It is passed to ``BartModel`` and
                must also expose ``d_model``, ``num_labels_type`` and
                ``num_labels_priorite``.
        """
        super().__init__(config)

        self.model = BartModel(config)
        hidden_size = config.d_model

        # One independent linear head per task, both fed the pooled vector.
        self.classifier_type = nn.Linear(hidden_size, config.num_labels_type)
        self.classifier_priorite = nn.Linear(
            hidden_size, config.num_labels_priorite
        )

        self.loss_fct = nn.CrossEntropyLoss()

        # Called last, once all submodules above exist.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels_type: Optional[torch.Tensor] = None,
        labels_priorite: Optional[torch.Tensor] = None,
    ) -> "MultiTaskOutput":
        """Run the backbone and both heads, optionally computing the loss.

        Args:
            input_ids: Forwarded to the backbone as ``input_ids``.
            attention_mask: Forwarded to the backbone as ``attention_mask``.
            labels_type: Targets for the "type" head; optional.
            labels_priorite: Targets for the "priorite" head; optional.

        Returns:
            A ``MultiTaskOutput`` holding both heads' logits and ``loss``.
            ``loss`` is the sum of the cross-entropy losses of every task
            whose labels were supplied, or ``None`` when no labels were given.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # NOTE(review): attention_mask is not applied to this mean, so if
        # dim 1 is the sequence axis, padded positions contribute to the
        # average - confirm this is intended.
        pooled_output = outputs.last_hidden_state.mean(dim=1)

        logits_type = self.classifier_type(pooled_output)
        logits_priorite = self.classifier_priorite(pooled_output)

        # Accumulate the loss from whichever label sets are present, so that
        # supplying a single task's labels no longer silently yields None.
        loss = None
        if labels_type is not None:
            loss = self.loss_fct(logits_type, labels_type)
        if labels_priorite is not None:
            loss_priorite = self.loss_fct(logits_priorite, labels_priorite)
            loss = loss_priorite if loss is None else loss + loss_priorite

        return MultiTaskOutput(
            loss=loss,
            logits_type=logits_type,
            logits_priorite=logits_priorite,
        )