from typing import List, Optional

from transformers import PretrainedConfig

from src.regression.PL import EncoderPL, DecoderPL
|
|
|
|
class FullModelConfigHF(PretrainedConfig):
    """Hugging Face configuration object for the ``"full_model"`` model type.

    The constructor stores its explicit arguments as plain attributes on the
    config and forwards every remaining keyword argument to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only -- confirm against the code that consumes this config.

    Args:
        tokenizer_ckpt: Presumably the tokenizer checkpoint name or path.
        bert_ckpt: Presumably the BERT encoder checkpoint name or path.
        decoder_ckpt: Presumably the decoder checkpoint name or path.
        layer_norm: Boolean flag stored on the config; presumably toggles
            layer normalization in the model.
        nontext_features: Names of the non-text features. ``None`` (the
            default) is replaced by a fresh ``["aov"]`` list.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Class-level identifier for this configuration type.
    model_type = "full_model"

    def __init__(
        self,
        tokenizer_ckpt: str = "",
        bert_ckpt: str = "",
        decoder_ckpt: str = "",
        layer_norm: bool = True,
        nontext_features: Optional[List[str]] = None,
        **kwargs,
    ):
        # Build the default list per call. A literal ``["aov"]`` in the
        # signature would be one shared object, so an in-place mutation on
        # one config would leak into every other default-constructed config.
        if nontext_features is None:
            nontext_features = ["aov"]

        self.tokenizer_ckpt = tokenizer_ckpt
        self.bert_ckpt = bert_ckpt
        self.decoder_ckpt = decoder_ckpt
        self.nontext_features = nontext_features
        self.layer_norm = layer_norm
        # Base-class initialisation runs last, after the custom attributes
        # are set (same order as the original code).
        super().__init__(**kwargs)
|
|