| """ | |
| Module for training the 2D clutter filtering model with L2 loss. | |
| """ | |
| import os | |
| import argparse | |
| import json | |
| import numpy as np | |
| from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau | |
| from utils import * | |
| from Model_ClutterFilter2D import clutter_filter_2D | |
| from DataGen import DataGen | |
def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
    """Build the training and validation ``DataGen`` instances.

    Args:
        in_ids_tr: Input identifiers of the training split (given to
            ``DataGen`` as ``in_dir``).
        in_ids_val: Input identifiers of the validation split.
        out_ids_tr: Target identifiers of the training split (given to
            ``DataGen`` as ``out_dir``).
        out_ids_val: Target identifiers of the validation split.
        config: Parsed configuration dict; reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        ``(tr_gen, val_gen)`` -- the training and validation generators.
    """
    input_dim = config["network_prm"]["input_dim"]
    batch_size = config["learning_prm"]["batch_size"]

    def split_kwargs(in_ids, out_ids):
        # Keyword arguments for one split; id_list simply enumerates the
        # positions 0..len(in_ids)-1 of that split.
        # NOTE(review): tr_phase is True for the validation split as well;
        # presumably it means "yield (input, target) pairs" rather than
        # "enable training-only behavior" -- confirm against DataGen.
        return {
            'dim': input_dim,
            'in_dir': in_ids,
            'out_dir': out_ids,
            'id_list': np.arange(len(in_ids)),
            'batch_size': batch_size,
            'tr_phase': True,
        }

    # Both parameter sets are assembled before either generator is built.
    train_kwargs = split_kwargs(in_ids_tr, out_ids_tr)
    valid_kwargs = split_kwargs(in_ids_val, out_ids_val)
    tr_gen = DataGen(**train_kwargs)
    val_gen = DataGen(**valid_kwargs)
    return tr_gen, val_gen
def model_chkpnt(val_subject, te_subject, weight_dir, config):
    """Create the model-checkpoint and learning-rate-schedule callbacks.

    The checkpoint filename encodes the validation/test subjects and the
    network/learning hyperparameters, followed by the ``{epoch}``, ``{loss}``
    and ``{val_loss}`` placeholders that ``ModelCheckpoint`` fills in when it
    saves a file.

    Args:
        val_subject: Validation subject identifier, embedded in the filename.
        te_subject: Test subject identifier, embedded in the filename.
        weight_dir: Directory in which checkpoint files are written.
        config: Parsed configuration dict. Reads the ``network_prm`` keys
            ``n_levels``, ``in_skip``, ``attention``, ``act`` and
            ``n_init_filters``, plus ``learning_prm["lr"]``. The optional
            ``learning_prm`` keys ``lr_factor``, ``lr_patience`` and
            ``min_lr`` override the ReduceLROnPlateau settings; when absent
            the defaults are 0.1, 4 and 1e-7.

    Returns:
        ``(model_checkpoint, reduce_lr)``: a ``ModelCheckpoint`` monitoring
        ``val_loss`` with ``save_best_only=True``, and a
        ``ReduceLROnPlateau`` monitoring ``val_loss``.
    """
    net_prm = config["network_prm"]
    learn_prm = config["learning_prm"]
    weight_name = (
        f'CF2D_ValTeSbj_{val_subject}_{te_subject}_nLvl{net_prm["n_levels"]}'
        f'_InSkp{net_prm["in_skip"]}_Att{net_prm["attention"]}'
        f'_Act{net_prm["act"]}_nInitFlt{net_prm["n_init_filters"]}_lr{learn_prm["lr"]}')
    # The {epoch}/{loss}/{val_loss} fields are deliberately left unformatted:
    # ModelCheckpoint substitutes them at save time.
    # os.path.join avoids a bogus root path ("/CF2D_...") when weight_dir is
    # empty and a doubled separator when weight_dir ends with "/".
    # NOTE(review): Keras 3 requires a ".keras" suffix for full-model
    # checkpoints; ".hdf5" relies on tf.keras 2.x behavior -- confirm version.
    filepath = os.path.join(
        weight_dir,
        weight_name
        + '_epc{epoch:03d}_trloss{loss:.5f}_valloss{val_loss:.5f}.hdf5')
    model_checkpoint = ModelCheckpoint(filepath=filepath,
                                       monitor="val_loss",
                                       verbose=0,
                                       save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=learn_prm.get("lr_factor", 0.1),
                                  patience=learn_prm.get("lr_patience", 4),
                                  min_lr=learn_prm.get("min_lr", 1e-7))
    return model_checkpoint, reduce_lr
def main(config):
    """Run one training session of the 2D clutter filter.

    Steps, in order:
      1. ``id_preparation(config)`` yields the train/validation input and
         target identifiers plus the validation and test subject ids.
      2. ``create_weight_dir`` returns the directory for checkpoints.
      3. ``data_generation`` builds the two data generators.
      4. ``clutter_filter_2D`` builds the model from the full config.
      5. ``model_chkpnt`` builds the callbacks and ``model.fit`` trains for
         ``config["learning_prm"]["n_epochs"]`` epochs.

    ``id_preparation`` and ``create_weight_dir`` are presumably provided by
    the ``utils`` wildcard import -- verify.

    Args:
        config: Parsed configuration dict; it is passed whole to
            ``clutter_filter_2D`` as keyword arguments.

    Returns:
        None. The ``History`` returned by ``model.fit`` is discarded; the
        weights are written only by the checkpoint callback.
    """
    prepared = id_preparation(config)
    in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject = prepared

    weight_dir = create_weight_dir(val_subject, te_subject, config)
    tr_gen, val_gen = data_generation(in_ids_tr, in_ids_val,
                                      out_ids_tr, out_ids_val, config)

    model = clutter_filter_2D(**config)
    checkpoint_cb, lr_schedule_cb = model_chkpnt(val_subject, te_subject,
                                                 weight_dir, config)
    n_epochs = config["learning_prm"]["n_epochs"]
    model.fit(tr_gen,
              validation_data=val_gen,
              epochs=n_epochs,
              verbose=1,
              callbacks=[checkpoint_cb, lr_schedule_cb])
if __name__ == '__main__':
    # Command-line entry point: load the JSON config file and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O`; parser.error prints the usage line plus a clear message
    # and exits with status 2.
    if not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config}")
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)