Add files using upload-large-folder tool
- .gitattributes +45 -0
- results/versatile_diffusion/subj01/951.png +3 -0
- results/versatile_diffusion/subj01/952.png +3 -0
- results/versatile_diffusion/subj01/953.png +3 -0
- results/versatile_diffusion/subj01/955.png +3 -0
- results/versatile_diffusion/subj01/957.png +3 -0
- results/versatile_diffusion/subj01/958.png +3 -0
- results/versatile_diffusion/subj01/959.png +3 -0
- results/versatile_diffusion/subj01/96.png +3 -0
- results/versatile_diffusion/subj01/960.png +3 -0
- results/versatile_diffusion/subj01/961.png +3 -0
- results/versatile_diffusion/subj01/962.png +3 -0
- results/versatile_diffusion/subj01/963.png +3 -0
- results/versatile_diffusion/subj01/964.png +3 -0
- results/versatile_diffusion/subj01/965.png +3 -0
- results/versatile_diffusion/subj01/966.png +3 -0
- results/versatile_diffusion/subj01/967.png +3 -0
- results/versatile_diffusion/subj01/968.png +3 -0
- results/versatile_diffusion/subj01/969.png +3 -0
- results/versatile_diffusion/subj01/97.png +3 -0
- results/versatile_diffusion/subj01/970.png +3 -0
- results/versatile_diffusion/subj01/971.png +3 -0
- results/versatile_diffusion/subj01/972.png +3 -0
- results/versatile_diffusion/subj01/973.png +3 -0
- results/versatile_diffusion/subj01/974.png +3 -0
- results/versatile_diffusion/subj01/975.png +3 -0
- results/versatile_diffusion/subj01/976.png +3 -0
- results/versatile_diffusion/subj01/977.png +3 -0
- results/versatile_diffusion/subj01/978.png +3 -0
- results/versatile_diffusion/subj01/979.png +3 -0
- results/versatile_diffusion/subj01/98.png +3 -0
- results/versatile_diffusion/subj01/980.png +3 -0
- results/versatile_diffusion/subj01/981.png +3 -0
- results/versatile_diffusion/subj01/99.png +3 -0
- results/versatile_diffusion/subj01/roi/0.png +3 -0
- results/versatile_diffusion/subj01/roi/1.png +3 -0
- results/versatile_diffusion/subj01/roi/10.png +3 -0
- results/versatile_diffusion/subj01/roi/11.png +3 -0
- results/versatile_diffusion/subj01/roi/12.png +3 -0
- results/versatile_diffusion/subj01/roi/2.png +3 -0
- results/versatile_diffusion/subj01/roi/3.png +3 -0
- results/versatile_diffusion/subj01/roi/5.png +3 -0
- results/versatile_diffusion/subj01/roi/6.png +3 -0
- results/versatile_diffusion/subj01/roi/7.png +3 -0
- results/versatile_diffusion/subj01/roi/8.png +3 -0
- results/versatile_diffusion/subj01/roi/9.png +3 -0
- scripts/clipvision_extract_features.py +88 -0
- scripts/clipvision_regression.py +71 -0
- scripts/eval_extract_features.py +147 -0
- scripts/evaluate_reconstruction.py +93 -0
.gitattributes
CHANGED
@@ -2939,3 +2939,48 @@ results/versatile_diffusion/subj01/956.png filter=lfs diff=lfs merge=lfs -text
 results/versatile_diffusion/subj01/948.png filter=lfs diff=lfs merge=lfs -text
 results/versatile_diffusion/subj01/954.png filter=lfs diff=lfs merge=lfs -text
 results/versatile_diffusion/subj01/946.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/952.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/955.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/951.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/962.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/959.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/973.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/957.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/970.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/977.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/976.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/972.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/969.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/974.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/975.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/978.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/961.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/953.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/966.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/965.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/968.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/964.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/963.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/960.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/971.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/981.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/979.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/967.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/958.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/96.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/99.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/980.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/0.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/10.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/1.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/5.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/7.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/6.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/11.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/98.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/9.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/8.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/97.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
+results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
results/versatile_diffusion/subj01/*.png, results/versatile_diffusion/subj01/roi/*.png
ADDED
[The 45 PNG files listed above. Each is stored as a Git LFS pointer, so the diff shows only Git LFS Details rather than image content.]
scripts/clipvision_extract_features.py
ADDED
@@ -0,0 +1,88 @@
import sys
sys.path.append('versatile_diffusion')
import os
from PIL import Image
import numpy as np

import torch
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from lib.experiments.sd_default import color_adjust, auto_merge_imlist
from torch.utils.data import DataLoader, Dataset

from lib.model_zoo.vd import VD
from lib.cfg_holder import cfg_unique_holder as cfguh
from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml
import torchvision.transforms as T

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", help="Subject Number", default=1)
args = parser.parse_args()
sub = int(args.sub)
assert sub in [1, 2, 5, 7]

cfgm_name = 'vd_noema'

# Load the Versatile Diffusion checkpoint; only its CLIP branch is used here
pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
cfgm = model_cfg_bank()(cfgm_name)
net = get_model()(cfgm)
sd = torch.load(pth, map_location='cpu')
net.load_state_dict(sd, strict=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.clip = net.clip.to(device)

class batch_generator_external_images(Dataset):
    """Serves NSD stimulus images from a .npy array, resized to 512x512 and scaled to [-1, 1]."""

    def __init__(self, data_path):
        self.data_path = data_path
        self.im = np.load(data_path).astype(np.uint8)

    def __getitem__(self, idx):
        img = Image.fromarray(self.im[idx])
        img = T.functional.resize(img, (512, 512))
        img = T.functional.to_tensor(img).float()
        img = img * 2 - 1  # map [0, 1] to the [-1, 1] range expected by the VD CLIP-vision encoder
        return img

    def __len__(self):
        return len(self.im)

batch_size = 1
image_path = 'data/processed_data/subj{:02d}/nsd_train_stim_sub{}.npy'.format(sub, sub)
train_images = batch_generator_external_images(data_path=image_path)

image_path = 'data/processed_data/subj{:02d}/nsd_test_stim_sub{}.npy'.format(sub, sub)
test_images = batch_generator_external_images(data_path=image_path)

trainloader = DataLoader(train_images, batch_size, shuffle=False)
testloader = DataLoader(test_images, batch_size, shuffle=False)

# 257 CLIP ViT-L/14 tokens (1 class token + 256 patch tokens), 768 dimensions each
num_embed, num_features, num_test, num_train = 257, 768, len(test_images), len(train_images)

train_clip = np.zeros((num_train, num_embed, num_features))
test_clip = np.zeros((num_test, num_embed, num_features))

os.makedirs('data/extracted_features/subj{:02d}'.format(sub), exist_ok=True)  # ensure the output directory exists

with torch.no_grad():
    for i, cin in enumerate(testloader):
        print(i)
        c = net.clip_encode_vision(cin)
        test_clip[i] = c[0].cpu().numpy()

    np.save('data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(sub), test_clip)

    for i, cin in enumerate(trainloader):
        print(i)
        c = net.clip_encode_vision(cin)
        train_clip[i] = c[0].cpu().numpy()
    np.save('data/extracted_features/subj{:02d}/nsd_clipvision_train.npy'.format(sub), train_clip)
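The script above is run per subject, e.g. python scripts/clipvision_extract_features.py -sub 1. As a minimal sanity check of its output (a sketch, not part of this commit, assuming the paths and the 257 x 768 token-by-dimension layout defined in the script):

import numpy as np

test_clip = np.load('data/extracted_features/subj01/nsd_clipvision_test.npy')
print(test_clip.shape)      # expected (num_test, 257, 768): 257 CLIP-vision tokens of 768 dims per image
print(test_clip[0, 0, :5])  # first few dims of token 0 (conventionally the ViT class token) for image 0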
scripts/clipvision_regression.py
ADDED
@@ -0,0 +1,71 @@
import os
import sys
import numpy as np
import sklearn.linear_model as skl
import pickle
import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", help="Subject Number", default=1)
args = parser.parse_args()
sub = int(args.sub)
assert sub in [1, 2, 5, 7]

train_path = 'data/processed_data/subj{:02d}/nsd_train_fmriavg_nsdgeneral_sub{}.npy'.format(sub, sub)
train_fmri = np.load(train_path)
test_path = 'data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(sub, sub)
test_fmri = np.load(test_path)

## Preprocess fMRI: rescale, then z-score with training-set statistics only (no test leakage)
train_fmri = train_fmri / 300
test_fmri = test_fmri / 300

norm_mean_train = np.mean(train_fmri, axis=0)
norm_scale_train = np.std(train_fmri, axis=0, ddof=1)
train_fmri = (train_fmri - norm_mean_train) / norm_scale_train
test_fmri = (test_fmri - norm_mean_train) / norm_scale_train

print(np.mean(train_fmri), np.std(train_fmri))
print(np.mean(test_fmri), np.std(test_fmri))
print(np.max(train_fmri), np.min(train_fmri))
print(np.max(test_fmri), np.min(test_fmri))

num_voxels, num_train, num_test = train_fmri.shape[1], len(train_fmri), len(test_fmri)

train_clip = np.load('data/extracted_features/subj{:02d}/nsd_clipvision_train.npy'.format(sub))
test_clip = np.load('data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(sub))

num_samples, num_embed, num_dim = train_clip.shape

print("Training Regression")
reg_w = np.zeros((num_embed, num_dim, num_voxels)).astype(np.float32)
reg_b = np.zeros((num_embed, num_dim)).astype(np.float32)
pred_clip = np.zeros_like(test_clip)
for i in range(num_embed):
    # One ridge regression per CLIP token, mapping voxels to all 768 feature dimensions at once
    reg = skl.Ridge(alpha=60000, max_iter=50000, fit_intercept=True)
    reg.fit(train_fmri, train_clip[:, i])
    reg_w[i] = reg.coef_
    reg_b[i] = reg.intercept_

    # Standardize the raw predictions, then rescale them to the training-feature statistics
    pred_test_latent = reg.predict(test_fmri)
    std_norm_test_latent = (pred_test_latent - np.mean(pred_test_latent, axis=0)) / np.std(pred_test_latent, axis=0)
    pred_clip[:, i] = std_norm_test_latent * np.std(train_clip[:, i], axis=0) + np.mean(train_clip[:, i], axis=0)

    print(i, reg.score(test_fmri, test_clip[:, i]))

os.makedirs('data/predicted_features/subj{:02d}'.format(sub), exist_ok=True)  # ensure output directories exist
os.makedirs('data/regression_weights/subj{:02d}'.format(sub), exist_ok=True)

np.save('data/predicted_features/subj{:02d}/nsd_clipvision_predtest_nsdgeneral.npy'.format(sub), pred_clip)

datadict = {
    'weight': reg_w,
    'bias': reg_b,
}

with open('data/regression_weights/subj{:02d}/clipvision_regression_weights.pkl'.format(sub), "wb") as f:
    pickle.dump(datadict, f)
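Run after feature extraction with python scripts/clipvision_regression.py -sub 1. Because each token's predictor is just an affine map, the saved weights can be reused without scikit-learn; a minimal sketch assuming the pickle layout written above (not part of this commit):

import numpy as np
import pickle

with open('data/regression_weights/subj01/clipvision_regression_weights.pkl', 'rb') as f:
    d = pickle.load(f)
w, b = d['weight'], d['bias']  # shapes (257, 768, num_voxels) and (257, 768)

def predict_token(fmri, i):
    # fmri: (n_samples, num_voxels), standardized with the training-set statistics above;
    # equivalent to reg.predict(fmri) for embedding index i
    return fmri @ w[i].T + b[i]  # (n_samples, 768)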
scripts/eval_extract_features.py
ADDED
@@ -0,0 +1,147 @@
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib

import torch
import torchvision
import torchvision.models as tvmodels
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import clip

import skimage.io as sio
from skimage import data, img_as_float
from skimage.transform import resize as imresize
from skimage.metrics import structural_similarity as ssim
import scipy as sp

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", help="Subject Number", default=1)
args = parser.parse_args()
sub = int(args.sub)
assert sub in [0, 1, 2, 5, 7]  # 0 extracts features for the ground-truth test images

# Default: ground-truth test images; subjects 1/2/5/7: their reconstructed images
images_dir = 'data/nsddata_stimuli/test_images'
feats_dir = 'data/eval_features/test_images'

if sub in [1, 2, 5, 7]:
    feats_dir = f'data/eval_features/subj{sub:02d}'
    images_dir = f'results/versatile_diffusion/subj{sub:02d}'

if not os.path.exists(feats_dir):
    os.makedirs(feats_dir)

class batch_generator_external_images(Dataset):
    """Loads the 982 test PNGs and applies the normalization expected by each network."""

    def __init__(self, data_path='', prefix='', net_name='clip'):
        self.data_path = data_path
        self.prefix = prefix
        self.net_name = net_name

        if self.net_name == 'clip':
            self.normalize = transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        else:
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        self.num_test = 982

    def __getitem__(self, idx):
        img = Image.open(f'{self.data_path}/{self.prefix}{idx}.png')
        img = T.functional.resize(img, (224, 224))
        img = T.functional.to_tensor(img).float()
        img = self.normalize(img)
        return img

    def __len__(self):
        return self.num_test

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

feat_list = []

def fn(module, inputs, outputs):
    # Forward hook: stash the tapped layer's activations for the current batch
    feat_list.append(outputs.cpu().numpy())

net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool')
]

batchsize = 64

for (net_name, layer) in net_list:
    feat_list = []
    print(net_name, layer)

    dataset = batch_generator_external_images(data_path=images_dir, net_name=net_name, prefix='')
    loader = DataLoader(dataset, batchsize, shuffle=False)

    if net_name == 'inceptionv3':
        net = tvmodels.inception_v3(pretrained=True)
        if layer == 'avgpool':
            net.avgpool.register_forward_hook(fn)
        elif layer == 'lastconv':
            net.Mixed_7c.register_forward_hook(fn)

    elif net_name == 'alexnet':
        net = tvmodels.alexnet(pretrained=True)
        if layer == 2:
            net.features[4].register_forward_hook(fn)
        elif layer == 5:
            net.features[11].register_forward_hook(fn)
        elif layer == 7:
            net.classifier[5].register_forward_hook(fn)

    elif net_name == 'clip':
        model, _ = clip.load("ViT-L/14", device=device)
        net = model.visual.to(torch.float32)
        if layer == 7:
            net.transformer.resblocks[7].register_forward_hook(fn)
        elif layer == 12:
            net.transformer.resblocks[12].register_forward_hook(fn)
        elif layer == 'final':
            net.register_forward_hook(fn)

    elif net_name == 'efficientnet':
        net = tvmodels.efficientnet_b1(weights='IMAGENET1K_V1')
        net.avgpool.register_forward_hook(fn)

    elif net_name == 'swav':
        net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        net.avgpool.register_forward_hook(fn)

    net.eval()
    net = net.to(device)

    with torch.no_grad():
        for i, x in enumerate(loader):
            print(i * batchsize)
            x = x.to(device)
            _ = net(x)

    # CLIP transformer blocks emit (tokens, batch, dim); put batch first before saving
    if net_name == 'clip':
        if layer == 7 or layer == 12:
            feat_list = np.concatenate(feat_list, axis=1).transpose((1, 0, 2))
        else:
            feat_list = np.concatenate(feat_list)
    else:
        feat_list = np.concatenate(feat_list)

    file_name = f'{feats_dir}/{net_name}_{layer}.npy'
    np.save(file_name, feat_list)
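All six feature dumps above use the same forward-hook mechanism; a self-contained illustration on one of the listed backbones (a sketch, not part of this commit; it uses untrained weights and the newer torchvision weights= argument, whereas the script above uses the older pretrained= flag):

import torch
import torchvision.models as tvmodels

feat_list = []
def fn(module, inputs, outputs):
    feat_list.append(outputs.detach().cpu().numpy())

net = tvmodels.alexnet(weights=None).eval()  # untrained weights suffice to show the mechanics
net.features[4].register_forward_hook(fn)    # the layer tapped for ('alexnet', 2) above
with torch.no_grad():
    net(torch.zeros(1, 3, 224, 224))
print(feat_list[0].shape)  # the activation tensor captured by the hook at features[4]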
scripts/evaluate_reconstruction.py
ADDED
@@ -0,0 +1,93 @@
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib
import scipy as sp
import scipy.spatial  # explicit import so sp.spatial.distance is available below
from PIL import Image

import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", help="Subject Number", default=1)
args = parser.parse_args()
sub = int(args.sub)
assert sub in [1, 2, 5, 7]

from scipy.stats import pearsonr, binom, linregress

def pairwise_corr_all(ground_truth, predictions):
    """Two-way identification: how often the congruent (same-image) correlation beats incongruent ones."""
    r = np.corrcoef(ground_truth, predictions)
    r = r[:len(ground_truth), len(ground_truth):]  # rows: ground truth, columns: predictions

    # congruent pairs sit on the diagonal
    congruents = np.diag(r)

    # for each column (prediction), count the rows (ground truth) whose correlation
    # is lower than the congruent one (a success)
    success = r < congruents
    success_cnt = np.sum(success, 0)

    # the diagonal of 'success' is always zero, so discard it; hence divide by len-1
    perf = np.mean(success_cnt) / (len(ground_truth) - 1)
    p = 1 - binom.cdf(perf * len(ground_truth) * (len(ground_truth) - 1),
                      len(ground_truth) * (len(ground_truth) - 1), 0.5)

    return perf, p

net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool')
]

feats_dir = 'data/eval_features/subj{:02d}'.format(sub)
test_dir = 'data/eval_features/test_images'
num_test = 982
distance_fn = sp.spatial.distance.correlation
pairwise_corrs = []
for (net_name, layer) in net_list:
    file_name = '{}/{}_{}.npy'.format(test_dir, net_name, layer)
    gt_feat = np.load(file_name)

    file_name = '{}/{}_{}.npy'.format(feats_dir, net_name, layer)
    eval_feat = np.load(file_name)

    gt_feat = gt_feat.reshape((len(gt_feat), -1))
    eval_feat = eval_feat.reshape((len(eval_feat), -1))

    print(net_name, layer)
    # EfficientNet and SwAV features are reported as correlation distances; the rest as two-way identification
    if net_name in ['efficientnet', 'swav']:
        print('distance: ', np.array([distance_fn(gt_feat[i], eval_feat[i]) for i in range(num_test)]).mean())
    else:
        pairwise_corrs.append(pairwise_corr_all(gt_feat[:num_test], eval_feat[:num_test])[0])
        print('pairwise corr: ', pairwise_corrs[-1])

from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim

# Low-level metrics: pixel correlation on RGB, SSIM on grayscale
ssim_list = []
pixcorr_list = []
for i in range(982):
    gen_image = Image.open('results/versatile_diffusion/subj{:02d}/{}.png'.format(sub, i)).resize((425, 425))
    gt_image = Image.open('data/nsddata_stimuli/test_images/{}.png'.format(i))
    gen_image = np.array(gen_image) / 255.0
    gt_image = np.array(gt_image) / 255.0
    pixcorr_res = np.corrcoef(gt_image.reshape(1, -1), gen_image.reshape(1, -1))[0, 1]
    pixcorr_list.append(pixcorr_res)
    gen_image = rgb2gray(gen_image)
    gt_image = rgb2gray(gt_image)
    # inputs are 2-D grayscale here, so no multichannel flag is needed
    ssim_res = ssim(gen_image, gt_image, gaussian_weights=True, sigma=1.5,
                    use_sample_covariance=False, data_range=1.0)
    ssim_list.append(ssim_res)

ssim_list = np.array(ssim_list)
pixcorr_list = np.array(pixcorr_list)
print('PixCorr: {}'.format(pixcorr_list.mean()))
print('SSIM: {}'.format(ssim_list.mean()))
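The evaluation is run per subject, e.g. python scripts/evaluate_reconstruction.py -sub 1. As a quick check of the two-way identification metric, a toy call to pairwise_corr_all from the script above (run in the same session, since it reuses that function): when predictions equal the ground truth, every congruent correlation beats the incongruent ones, so performance should be exactly 1.0.

import numpy as np

rng = np.random.default_rng(0)
gt = rng.standard_normal((10, 64))   # 10 hypothetical images with 64-dim features
perf, p = pairwise_corr_all(gt, gt.copy())
print(perf)                          # 1.0: perfect pairwise identification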