| |
|
| |
|
| | import tensorflow as tf
|
| | from tensorflow.keras import layers, Model
|
| | from transformers import TFAutoModel
|
| |
|
class CrossEncoderTF(Model):
    """Cross-encoder: a pretrained transformer encoder followed by an MLP head.

    The encoder (loaded with ``TFAutoModel.from_pretrained``) encodes the
    tokenized input, and a small feed-forward classifier maps the encoder's
    ``pooler_output`` to a single sigmoid score in [0, 1] per example.

    Args:
        model_name: Hugging Face model identifier or local path passed to
            ``TFAutoModel.from_pretrained``.
        max_token_len: Stored and serialized via ``get_config``; it is not read
            anywhere in this class. NOTE(review): presumably the caller uses it
            to pad/truncate when tokenizing — confirm.
        **kwargs: Forwarded to ``tf.keras.Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, model_name="dbmdz/bert-base-turkish-cased", max_token_len=32, **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_token_len = max_token_len

        # Pretrained encoder, loaded at construction time.
        self.bert = TFAutoModel.from_pretrained(model_name)

        # MLP head on top of the encoder's pooled output; the final Dense(1)
        # with sigmoid yields one score in [0, 1] per example.
        self.classifier = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(1, activation='sigmoid'),
        ], name="classifier")

    def call(self, inputs, training=None):
        """Score a batch of tokenized inputs.

        Args:
            inputs: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors and, optionally, ``'token_type_ids'`` (as produced by a
                Hugging Face tokenizer for text pairs). Presumably each tensor
                is ``(batch, seq_len)`` — confirm against the data pipeline.
            training: Keras training flag. It is forwarded to both the encoder
                and the classifier so that their Dropout/BatchNormalization
                layers switch modes together; ``None`` lets Keras infer the
                mode from the enclosing call (e.g. ``fit`` vs ``predict``).

        Returns:
            Sigmoid scores from the classifier's final ``Dense(1)`` layer,
            one per example.
        """
        encoder_inputs = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask'],
        }
        # Segment ids tell a BERT-style encoder which tokens belong to the
        # first vs. second text of a pair; without them every token is treated
        # as segment 0. Forward them only when supplied, because some encoders
        # reachable through TFAutoModel do not accept this keyword at all.
        # NOTE(review): checkpoints trained while this key was ignored may
        # score differently once it is forwarded — retrain or strip the key.
        if 'token_type_ids' in inputs:
            encoder_inputs['token_type_ids'] = inputs['token_type_ids']

        bert_output = self.bert(**encoder_inputs, training=training)
        text_features = bert_output.pooler_output

        prediction_score = self.classifier(text_features, training=training)
        return prediction_score

    def get_config(self):
        """Return the Keras config, extended with this model's constructor args.

        ``model_name`` and ``max_token_len`` are added so the model can be
        re-created with the same arguments on deserialization.
        """
        config = super().get_config()
        config.update({
            "model_name": self.model_name,
            "max_token_len": self.max_token_len,
        })
        return config
|
| |
|
| |
|