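"""
Test script for the clutter filtering model built by ``clutter_filter_2D``.

Loads the trained weights, runs prediction on the test data, and writes the
errors computed by ``compute_mae`` to a CSV file next to the weights.
"""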
| import os | |
| import argparse | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| from utils import * | |
| from Model_ClutterFilter2D import clutter_filter_2D | |
| from DataGen import DataGen | |
| from Error_analysis import compute_mae | |
def data_generation(in_ids_te, out_ids_te, config):
    """Create the test-phase ``DataGen`` instance.

    Args:
        in_ids_te: Sequence of test input identifiers, forwarded as ``in_dir``
            (presumably file paths given the keyword name -- confirm).
        out_ids_te: Sequence of test target identifiers, forwarded as ``out_dir``.
        config: Configuration dict; reads ``config["network_prm"]["input_dim"]``
            and ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``DataGen`` built with ``tr_phase=False`` whose ``id_list`` covers
        every index ``0 .. len(in_ids_te) - 1``.
    """
    # Keyword arguments are evaluated left to right, matching the order in
    # which the original parameter dict was built.
    return DataGen(
        dim=config["network_prm"]["input_dim"],
        in_dir=in_ids_te,
        out_dir=out_ids_te,
        id_list=np.arange(len(in_ids_te)),
        batch_size=config["learning_prm"]["batch_size"],
        tr_phase=False,
    )
def main(config):
    """Evaluate saved clutter-filter weights on the test data and save errors.

    Steps:
      1. ``id_preparation(config)`` supplies the test input/output ids and the
         test/validation subject identifiers.
      2. ``data_generation`` builds the test-phase data generator.
      3. ``clutter_filter_2D(**config)`` builds the model, and weights are
         loaded from ``<weight_dir>/<config["weight_name"]>.hdf5``, where
         ``weight_dir = create_weight_dir(val_subject, te_subject, config)``.
      4. Predictions are scored with ``compute_mae`` and the resulting
         DataFrame is written to ``<weight_dir>/<config["weight_name"]>.csv``.

    Args:
        config: Parsed configuration dict. Must contain ``"weight_name"`` plus
            whatever keys the helper functions read.

    Returns:
        None. Results are only persisted to disk.
    """
    in_ids_te, out_ids_te, te_subject, val_subject = id_preparation(config)
    te_gen = data_generation(in_ids_te, out_ids_te, config)

    model = clutter_filter_2D(**config)
    weight_dir = create_weight_dir(val_subject, te_subject, config)
    # The weights file and the error report share one base path.
    base_path = os.path.join(weight_dir, config["weight_name"])
    model.load_weights(base_path + ".hdf5")

    # `predict_generator` was deprecated in TF 2.1 and removed in Keras 3,
    # where `predict` accepts generators/Sequences directly. Use the legacy
    # method when present so older Keras versions behave exactly as before.
    predict = getattr(model, "predict_generator", None) or model.predict
    results_te = predict(te_gen, verbose=2)

    df_errors = compute_mae(in_ids_te, results_te)
    df_errors.to_csv(base_path + ".csv")
    return None
if __name__ == '__main__':
    # Command-line entry point: load a JSON config file and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O`; parser.error() prints usage and exits with status 2.
    if not os.path.isfile(args.config):
        parser.error("config file not found: {}".format(args.config))
    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)