|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| import os |
| import numpy as np |
| import click |
| import pandas as pd |
|
|
| from network import mnist_net_my as mnist_net |
| from network import adaptor_v2 |
| from tools import causalaugment_v3 as causalaugment |
| from main_my_joint_v13_auto import evaluate,evaluate_causal,evaluate_causal_with_entropy,evaluate_mapping,evaluate_causal_with_average |
| import data_loader_joint_v3 as data_loader |
|
|
@click.command()
@click.option('--gpu', type=str, default='0', help='选择GPU编号')
@click.option('--svroot', type=str, default='./saved')
@click.option('--svpath', type=str, default=None, help='保存日志的路径')
@click.option('--channels', type=int, default=3)
@click.option('--factor_num', type=int, default=16)
@click.option('--stride', type=int, default=16)
@click.option('--epoch', type=str, default='best')
@click.option('--eval_mapping', type=bool, default=True, help='是否查看mapping学习效果')
def main(gpu, svroot, svpath, channels, factor_num, stride, epoch, eval_mapping):
    # Command-line entry point: click parses the options declared above and
    # every parsed value is handed, by keyword, straight to evaluate_digit.
    # A docstring is intentionally not used here because click would expose
    # it as the command's --help description.
    evaluate_digit(
        gpu=gpu,
        svroot=svroot,
        svpath=svpath,
        channels=channels,
        factor_num=factor_num,
        stride=stride,
        epoch=epoch,
        eval_mapping=eval_mapping,
    )
| |
# Checkpoint prefixes accepted for the ``epoch`` option (files are saved as
# ``best_<name>.pkl`` / ``last_<name>.pkl``).
_EPOCH_CHOICES = ('best', 'last')


def _load_state(svroot, epoch, name):
    """Load the state dict stored in ``<svroot>/<epoch>_<name>.pkl``.

    Tensors are first mapped to CPU; ``load_state_dict`` later copies them into
    the already-constructed CUDA module, so loading does not depend on which
    CUDA device index the checkpoint was saved from.
    """
    print("loading weight of %s" % (epoch))
    path = os.path.join(svroot, '%s_%s.pkl' % (epoch, name))
    return torch.load(path, map_location='cpu')


def _mean_over(results, keys):
    """Return the element-wise mean of ``results[k]`` over ``k`` in ``keys``.

    Accumulation starts from ``np.zeros`` shaped like the first value, which
    matches the arithmetic (and float64 result) of the original inline loop.
    """
    if not keys:
        raise ValueError('cannot average over an empty list of datasets')
    total = None
    for key in keys:
        value = results[key]
        if total is None:
            total = np.zeros(np.shape(value))
        total = total + value
    return total / float(len(keys))


def evaluate_digit(gpu, svroot, svpath, channels=3, factor_num=16, stride=5, epoch='best', eval_mapping=True):
    """Evaluate a trained digit model on the five digit test sets.

    Loads the classifier (``cls_net``), ``factor_num`` mapping networks
    (``mapping_<i>``) and the effect-to-weight network (``E_to_W``) from the
    ``best_*`` or ``last_*`` checkpoints in ``svroot``. It then evaluates them
    on mnist, svhn, mnist_m, syndigit and usps, and prints two tables:

    * ``df``: ``evaluate`` accuracy, or, with ``eval_mapping``, the per-factor
      ``evaluate_mapping`` results plus a ``<dataset>_FA`` column each;
    * ``df_ours``: ``evaluate_causal`` accuracy.

    Each table gets an 'Avg' column averaged over the four target datasets
    only (mnist is excluded). If ``svpath`` is given, both tables are written
    to that CSV file, the second one appended after the first.

    Args:
        gpu: value written to the ``CUDA_VISIBLE_DEVICES`` environment variable.
        svroot: directory containing the saved ``*.pkl`` state dicts.
        svpath: optional CSV output path (``None`` disables saving).
        channels: number of input image channels, 1 or 3.
        factor_num: number of augmentation factors / mapping networks.
        stride: forwarded to ``MultiCounterfactualAugment``.
        epoch: which checkpoint set to load, 'best' or 'last'.
        eval_mapping: if True, report ``evaluate_mapping`` results per factor;
            otherwise report plain ``evaluate`` accuracy.

    Raises:
        ValueError: if ``channels`` or ``epoch`` is not a supported value.
    """
    settings = locals().copy()
    print(settings)
    # Fail fast: an unsupported value used to leave ``cls_net`` or
    # ``saved_weight`` unbound and crash later with an unrelated error.
    if channels not in (1, 3):
        raise ValueError("channels must be 1 or 3, got %r" % (channels,))
    if epoch not in _EPOCH_CHOICES:
        raise ValueError("epoch must be one of %s, got %r" % (_EPOCH_CHOICES, epoch))
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    # Classifier.
    if channels == 3:
        cls_net = mnist_net.ConvNet().cuda()
    else:
        cls_net = mnist_net.ConvNet(imdim=channels).cuda()
    cls_net.load_state_dict(_load_state(svroot, epoch, 'cls_net'))

    FA = causalaugment.FactualAugment(m=4, factor_num=factor_num)
    CA = causalaugment.MultiCounterfactualAugment(factor_num, stride)

    # One mapping network per factor, loaded from mapping_<i>.pkl.
    AdaptNet = []
    for i in range(factor_num):
        mapping = adaptor_v2.mapping(1024, 512, 1024, 2).cuda()
        mapping.load_state_dict(_load_state(svroot, epoch, 'mapping_' + str(i)))
        AdaptNet.append(mapping)

    E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
    E_to_W.load_state_dict(_load_state(svroot, epoch, 'E_to_W'))

    # Put every network used during evaluation in inference mode, not just
    # the classifier (matters if the mappings contain dropout/batch-norm).
    cls_net.eval()
    for mapping in AdaptNet:
        mapping.eval()
    E_to_W.eval()

    str2fun = {
        'mnist': data_loader.load_mnist,
        'mnist_m': data_loader.load_mnist_m,
        'usps': data_loader.load_usps,
        'svhn': data_loader.load_svhn,
        'syndigit': data_loader.load_syndigit,
    }
    columns = ['mnist', 'svhn', 'mnist_m', 'syndigit', 'usps']
    # Datasets averaged into the 'Avg' column (mnist is left out).
    target = ['svhn', 'mnist_m', 'syndigit', 'usps']

    if eval_mapping:
        # Build a new list: appending to FA.factor_list directly would mutate
        # the augmenter that is passed to evaluate_mapping below.
        index = list(FA.factor_list) + ['w/o do (original x)']
    else:
        index = ['w/o do (original x)']
    index_ours = ['do']

    data_result = {}
    data_result_ours = {}
    for data in columns:
        teset = str2fun[data]('test', channels=channels)
        teloader = DataLoader(teset, batch_size=8, num_workers=0)

        data_result_ours[data] = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)

        if eval_mapping:
            teacc_FA_aftermapping, acc_FA = evaluate_mapping(
                cls_net, teloader, FA, AdaptNet, source=(data == 'mnist'))
            data_result[data] = teacc_FA_aftermapping
            data_result[data + '_FA'] = acc_FA
        else:
            data_result[data] = evaluate(cls_net, teloader)

    data_result['Avg'] = _mean_over(data_result, target)
    data_result_ours['Avg'] = _mean_over(data_result_ours, target)

    df = pd.DataFrame(data_result, index=index)
    df_ours = pd.DataFrame(data_result_ours, index=index_ours)
    print(df)
    print(df_ours)
    if svpath is not None:
        df.to_csv(svpath)
        df_ours.to_csv(svpath, mode='a')
|
|
if __name__ == '__main__':
    # Script entry: let click parse sys.argv and run the evaluation command.
    main()
|
|
|
|