| |
|
| | from cnnClassfier.config.configuration import PrepareCallbacksConfig |
| | import time |
| | import os |
| | import tensorflow as tf |
| |
|
| |
|
class PrepareCallback:
    """Build the Keras training callbacks described by a ``PrepareCallbacksConfig``.

    Two callbacks are produced:

    * a TensorBoard logger writing into a timestamped sub-directory of
      ``config.tensorboard_root_log_dir``;
    * a ModelCheckpoint writing to ``config.checkpoint_model_filepath``.
    """

    def __init__(self, config: PrepareCallbacksConfig):
        # Kept as-is; the callback factories read their path fields on demand.
        self.config = config

    @property
    def _create_tb_callbacks(self):
        """Return a new TensorBoard callback logging to a per-run directory.

        The directory is
        ``<tensorboard_root_log_dir>/tb_logs_at_<YYYY-mm-dd-HH-MM-SS>``. The
        stamp comes from ``time.strftime`` (local time, one-second
        resolution), so two runs started within the same second share a
        directory.

        Because this is a property, every attribute access builds a brand-new
        callback with a fresh timestamp.
        """
        stamp = time.strftime('%Y-%m-%d-%H-%M-%S')
        run_dir_name = f"tb_logs_at_{stamp}"
        root_dir = str(self.config.tensorboard_root_log_dir)
        return tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(root_dir, run_dir_name)
        )

    @property
    def _create_ckpt_callbacks(self):
        """Return a new ModelCheckpoint saving to ``checkpoint_model_filepath``.

        ``save_best_only=True`` is set and no ``monitor`` is passed, so Keras'
        default monitored quantity decides which epoch counts as "best".

        Like the TensorBoard factory, each access constructs a new callback.
        """
        checkpoint_path = str(self.config.checkpoint_model_filepath)
        return tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            save_best_only=True,
        )

    def get_tb_callbacks(self):
        """Return ``[tensorboard_callback, checkpoint_callback]``.

        Despite the name, the list contains the ModelCheckpoint callback as
        well as the TensorBoard one. Each call builds fresh callback objects.
        """
        tensorboard_cb = self._create_tb_callbacks
        checkpoint_cb = self._create_ckpt_callbacks
        return [tensorboard_cb, checkpoint_cb]