ngaggion's picture
Upload 171 files
1c76bf8 verified
import json
from pathlib import Path
import numpy as np
import torch
import scipy.sparse as sp
def scipy_to_torch_sparse(scp_matrix):
values = scp_matrix.data
indices = np.vstack((scp_matrix.row, scp_matrix.col))
i = torch.LongTensor(indices)
v = torch.FloatTensor(values)
shape = scp_matrix.shape
sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
return sparse_tensor
def load_config(dataset_path, hyperparameters = None):
config = {}
config['DATASET'] = dataset_path
with open(Path(dataset_path) / "config.json") as f:
data_config = json.load(f)
if hyperparameters is None:
hyperparameters = {}
hyperparameters['latents'] = 64
hyperparameters['initial_filters'] = 16
# Data augmentation is left to the dataset construction info
hyperparameters['flip_h'] = data_config['flip_h']
hyperparameters['flip_v'] = data_config['flip_v']
hyperparameters['transpose'] = data_config['transpose']
hyperparameters['rotate'] = data_config['rotate']
config.update({
'organs': data_config['organs'],
'organ_names': data_config['organ_names'],
'resolutions': data_config['resolutions'],
'inputsize': data_config['inputsize'],
'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
})
config.update(hyperparameters)
num_pooling = len(config['resolutions'])
config['filters'] = [config['initial_filters'] * (i+1) // 2 for i in range(num_pooling+1)]
config['filters'] = [x for x in config['filters'] for _ in (0, 1)]
config['filters'][0] = 2
if config['naive']:
adj_path = "Naive"
else:
adj_path = "NonNaive"
A_ = []
for res in config['resolutions']:
A = np.load(Path(dataset_path) / adj_path / f"adj_{res}_block_diagonal.npy")
A = sp.csc_matrix(A).tocoo()
A_.extend([A.copy()])
A_.append(A_[-1])
A_t = [scipy_to_torch_sparse(x).to(config['device']) for x in A_]
config['n_nodes'] = [A.shape[0] for A in A_]
D_ = []
for res in ['to_' + x for x in config['resolutions'][1:]]:
D = np.load(Path(dataset_path) / adj_path / f"downsampling_{res}.npy")
D_.append(sp.csc_matrix(D).tocoo())
D_t = [scipy_to_torch_sparse(x).to(config['device']) for x in D_]
U_ = []
for res in ['to_' + x for x in config['resolutions'][:-1]]:
U = np.load(Path(dataset_path) / adj_path / f"upsampling_{res}.npy")
U_.append(sp.csc_matrix(U).tocoo())
U_t = [scipy_to_torch_sparse(x).to(config['device']) for x in U_]
return config, D_t, U_t, A_t