File size: 2,667 Bytes
1c76bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
from pathlib import Path
import numpy as np
import torch
import scipy.sparse as sp

def scipy_to_torch_sparse(scp_matrix):
    """Convert a SciPy sparse matrix into a float32 torch sparse COO tensor.

    Args:
        scp_matrix: Any SciPy sparse matrix. Non-COO formats (CSR, CSC, ...)
            are converted with ``tocoo()`` first; for COO input this is a no-op.

    Returns:
        A ``torch.sparse_coo_tensor`` of the same shape with int64 indices and
        float32 values, allocated on the CPU (callers move it with ``.to``).
        The entries are not coalesced. The tensor does not share memory with
        ``scp_matrix``.
    """
    coo = scp_matrix.tocoo()
    # Row indices in the first row, column indices in the second: the (2, nnz) layout
    # expected by torch.sparse_coo_tensor.
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the deprecated legacy torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def _load_coo(npy_path):
    """Load a dense matrix saved with ``np.save`` and return it as a SciPy COO matrix."""
    return sp.csc_matrix(np.load(npy_path)).tocoo()


def load_config(dataset_path, hyperparameters=None):
    """Build the model configuration and sparse graph operators for a dataset.

    Reads ``<dataset_path>/config.json`` plus the ``.npy`` graph matrices stored in
    the ``Naive`` or ``NonNaive`` sub-directory (selected by ``config['naive']``).

    Args:
        dataset_path: Dataset directory (str or path-like). It must contain
            ``config.json`` and the ``Naive``/``NonNaive`` matrix folders.
        hyperparameters: Optional dict merged into the config last, so its keys
            override same-named dataset entries. It is copied, never modified.
            When ``None``, ``latents=64`` and ``initial_filters=16`` are used.
            If no ``naive`` key ends up in the config, it defaults to ``False``.

    Returns:
        ``(config, D_t, U_t, A_t)`` where
        - ``config``: settings dict, including the derived ``filters`` and ``n_nodes``.
        - ``D_t``: tensors from ``downsampling_to_<res>.npy`` for every resolution
          after the first.
        - ``U_t``: tensors from ``upsampling_to_<res>.npy`` for every resolution
          except the last.
        - ``A_t``: tensors from ``adj_<res>_block_diagonal.npy`` for every
          resolution, followed by a repeat of the last one.
        All tensors are torch sparse COO tensors on ``config['device']``.

    Raises:
        KeyError: if ``config.json`` or the hyperparameters lack a required key.
        ValueError: if the configured resolution list is empty.
    """
    dataset_dir = Path(dataset_path)
    config = {'DATASET': dataset_path}

    with open(dataset_dir / "config.json", encoding="utf-8") as f:
        data_config = json.load(f)

    if hyperparameters is None:
        hyperparameters = {'latents': 64, 'initial_filters': 16}
    else:
        # Work on a copy so the augmentation flags below do not leak into the caller's dict.
        hyperparameters = dict(hyperparameters)

    # Data augmentation is left to the dataset construction info
    for key in ('flip_h', 'flip_v', 'transpose', 'rotate'):
        hyperparameters[key] = data_config[key]

    config.update({
        'organs': data_config['organs'],
        'organ_names': data_config['organ_names'],
        'resolutions': data_config['resolutions'],
        'inputsize': data_config['inputsize'],
        'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    })
    config.update(hyperparameters)

    # The default hyperparameters do not include 'naive'; fall back instead of raising KeyError.
    # NOTE(review): False (use the NonNaive matrices) is an assumed default — confirm.
    naive = config.setdefault('naive', False)

    resolutions = config['resolutions']
    if not resolutions:
        raise ValueError("config.json must list at least one resolution")
    device = config['device']

    # Widths grow linearly with depth; each width appears twice in a row, and the first
    # entry is fixed to 2 (presumably the input node feature size — confirm).
    num_pooling = len(resolutions)
    filters = [(config['initial_filters'] * (i + 1)) // 2 for i in range(num_pooling + 1)]
    filters = [width for width in filters for _ in range(2)]
    filters[0] = 2
    config['filters'] = filters

    matrix_dir = dataset_dir / ("Naive" if naive else "NonNaive")

    A_ = [_load_coo(matrix_dir / f"adj_{res}_block_diagonal.npy") for res in resolutions]
    # There is one more graph level than resolutions; it reuses the coarsest adjacency.
    A_.append(A_[-1])
    A_t = [scipy_to_torch_sparse(A).to(device) for A in A_]
    config['n_nodes'] = [A.shape[0] for A in A_]

    D_t = [
        scipy_to_torch_sparse(_load_coo(matrix_dir / f"downsampling_to_{res}.npy")).to(device)
        for res in resolutions[1:]
    ]
    U_t = [
        scipy_to_torch_sparse(_load_coo(matrix_dir / f"upsampling_to_{res}.npy")).to(device)
        for res in resolutions[:-1]
    ]

    return config, D_t, U_t, A_t