dineshsai07 commited on
Commit
46a8d8a
·
verified ·
1 Parent(s): 01f47a8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +45 -0
  2. results/versatile_diffusion/subj01/951.png +3 -0
  3. results/versatile_diffusion/subj01/952.png +3 -0
  4. results/versatile_diffusion/subj01/953.png +3 -0
  5. results/versatile_diffusion/subj01/955.png +3 -0
  6. results/versatile_diffusion/subj01/957.png +3 -0
  7. results/versatile_diffusion/subj01/958.png +3 -0
  8. results/versatile_diffusion/subj01/959.png +3 -0
  9. results/versatile_diffusion/subj01/96.png +3 -0
  10. results/versatile_diffusion/subj01/960.png +3 -0
  11. results/versatile_diffusion/subj01/961.png +3 -0
  12. results/versatile_diffusion/subj01/962.png +3 -0
  13. results/versatile_diffusion/subj01/963.png +3 -0
  14. results/versatile_diffusion/subj01/964.png +3 -0
  15. results/versatile_diffusion/subj01/965.png +3 -0
  16. results/versatile_diffusion/subj01/966.png +3 -0
  17. results/versatile_diffusion/subj01/967.png +3 -0
  18. results/versatile_diffusion/subj01/968.png +3 -0
  19. results/versatile_diffusion/subj01/969.png +3 -0
  20. results/versatile_diffusion/subj01/97.png +3 -0
  21. results/versatile_diffusion/subj01/970.png +3 -0
  22. results/versatile_diffusion/subj01/971.png +3 -0
  23. results/versatile_diffusion/subj01/972.png +3 -0
  24. results/versatile_diffusion/subj01/973.png +3 -0
  25. results/versatile_diffusion/subj01/974.png +3 -0
  26. results/versatile_diffusion/subj01/975.png +3 -0
  27. results/versatile_diffusion/subj01/976.png +3 -0
  28. results/versatile_diffusion/subj01/977.png +3 -0
  29. results/versatile_diffusion/subj01/978.png +3 -0
  30. results/versatile_diffusion/subj01/979.png +3 -0
  31. results/versatile_diffusion/subj01/98.png +3 -0
  32. results/versatile_diffusion/subj01/980.png +3 -0
  33. results/versatile_diffusion/subj01/981.png +3 -0
  34. results/versatile_diffusion/subj01/99.png +3 -0
  35. results/versatile_diffusion/subj01/roi/0.png +3 -0
  36. results/versatile_diffusion/subj01/roi/1.png +3 -0
  37. results/versatile_diffusion/subj01/roi/10.png +3 -0
  38. results/versatile_diffusion/subj01/roi/11.png +3 -0
  39. results/versatile_diffusion/subj01/roi/12.png +3 -0
  40. results/versatile_diffusion/subj01/roi/2.png +3 -0
  41. results/versatile_diffusion/subj01/roi/3.png +3 -0
  42. results/versatile_diffusion/subj01/roi/5.png +3 -0
  43. results/versatile_diffusion/subj01/roi/6.png +3 -0
  44. results/versatile_diffusion/subj01/roi/7.png +3 -0
  45. results/versatile_diffusion/subj01/roi/8.png +3 -0
  46. results/versatile_diffusion/subj01/roi/9.png +3 -0
  47. scripts/clipvision_extract_features.py +88 -0
  48. scripts/clipvision_regression.py +71 -0
  49. scripts/eval_extract_features.py +147 -0
  50. scripts/evaluate_reconstruction.py +93 -0
.gitattributes CHANGED
@@ -2939,3 +2939,48 @@ results/versatile_diffusion/subj01/956.png filter=lfs diff=lfs merge=lfs -text
2939
  results/versatile_diffusion/subj01/948.png filter=lfs diff=lfs merge=lfs -text
2940
  results/versatile_diffusion/subj01/954.png filter=lfs diff=lfs merge=lfs -text
2941
  results/versatile_diffusion/subj01/946.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2939
  results/versatile_diffusion/subj01/948.png filter=lfs diff=lfs merge=lfs -text
2940
  results/versatile_diffusion/subj01/954.png filter=lfs diff=lfs merge=lfs -text
2941
  results/versatile_diffusion/subj01/946.png filter=lfs diff=lfs merge=lfs -text
2942
+ results/versatile_diffusion/subj01/952.png filter=lfs diff=lfs merge=lfs -text
2943
+ results/versatile_diffusion/subj01/955.png filter=lfs diff=lfs merge=lfs -text
2944
+ results/versatile_diffusion/subj01/951.png filter=lfs diff=lfs merge=lfs -text
2945
+ results/versatile_diffusion/subj01/962.png filter=lfs diff=lfs merge=lfs -text
2946
+ results/versatile_diffusion/subj01/959.png filter=lfs diff=lfs merge=lfs -text
2947
+ results/versatile_diffusion/subj01/973.png filter=lfs diff=lfs merge=lfs -text
2948
+ results/versatile_diffusion/subj01/957.png filter=lfs diff=lfs merge=lfs -text
2949
+ results/versatile_diffusion/subj01/970.png filter=lfs diff=lfs merge=lfs -text
2950
+ results/versatile_diffusion/subj01/977.png filter=lfs diff=lfs merge=lfs -text
2951
+ results/versatile_diffusion/subj01/976.png filter=lfs diff=lfs merge=lfs -text
2952
+ results/versatile_diffusion/subj01/972.png filter=lfs diff=lfs merge=lfs -text
2953
+ results/versatile_diffusion/subj01/969.png filter=lfs diff=lfs merge=lfs -text
2954
+ results/versatile_diffusion/subj01/974.png filter=lfs diff=lfs merge=lfs -text
2955
+ results/versatile_diffusion/subj01/975.png filter=lfs diff=lfs merge=lfs -text
2956
+ results/versatile_diffusion/subj01/978.png filter=lfs diff=lfs merge=lfs -text
2957
+ results/versatile_diffusion/subj01/961.png filter=lfs diff=lfs merge=lfs -text
2958
+ results/versatile_diffusion/subj01/953.png filter=lfs diff=lfs merge=lfs -text
2959
+ results/versatile_diffusion/subj01/966.png filter=lfs diff=lfs merge=lfs -text
2960
+ results/versatile_diffusion/subj01/965.png filter=lfs diff=lfs merge=lfs -text
2961
+ results/versatile_diffusion/subj01/968.png filter=lfs diff=lfs merge=lfs -text
2962
+ results/versatile_diffusion/subj01/964.png filter=lfs diff=lfs merge=lfs -text
2963
+ results/versatile_diffusion/subj01/963.png filter=lfs diff=lfs merge=lfs -text
2964
+ results/versatile_diffusion/subj01/960.png filter=lfs diff=lfs merge=lfs -text
2965
+ results/versatile_diffusion/subj01/971.png filter=lfs diff=lfs merge=lfs -text
2966
+ results/versatile_diffusion/subj01/981.png filter=lfs diff=lfs merge=lfs -text
2967
+ results/versatile_diffusion/subj01/979.png filter=lfs diff=lfs merge=lfs -text
2968
+ results/versatile_diffusion/subj01/967.png filter=lfs diff=lfs merge=lfs -text
2969
+ results/versatile_diffusion/subj01/958.png filter=lfs diff=lfs merge=lfs -text
2970
+ results/versatile_diffusion/subj01/96.png filter=lfs diff=lfs merge=lfs -text
2971
+ results/versatile_diffusion/subj01/99.png filter=lfs diff=lfs merge=lfs -text
2972
+ results/versatile_diffusion/subj01/980.png filter=lfs diff=lfs merge=lfs -text
2973
+ results/versatile_diffusion/subj01/roi/0.png filter=lfs diff=lfs merge=lfs -text
2974
+ results/versatile_diffusion/subj01/roi/10.png filter=lfs diff=lfs merge=lfs -text
2975
+ results/versatile_diffusion/subj01/roi/1.png filter=lfs diff=lfs merge=lfs -text
2976
+ results/versatile_diffusion/subj01/roi/5.png filter=lfs diff=lfs merge=lfs -text
2977
+ results/versatile_diffusion/subj01/roi/7.png filter=lfs diff=lfs merge=lfs -text
2978
+ results/versatile_diffusion/subj01/roi/6.png filter=lfs diff=lfs merge=lfs -text
2979
+ results/versatile_diffusion/subj01/roi/11.png filter=lfs diff=lfs merge=lfs -text
2980
+ results/versatile_diffusion/subj01/98.png filter=lfs diff=lfs merge=lfs -text
2981
+ results/versatile_diffusion/subj01/roi/9.png filter=lfs diff=lfs merge=lfs -text
2982
+ results/versatile_diffusion/subj01/roi/8.png filter=lfs diff=lfs merge=lfs -text
2983
+ results/versatile_diffusion/subj01/97.png filter=lfs diff=lfs merge=lfs -text
2984
+ results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
2985
+ results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
2986
+ results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
results/versatile_diffusion/subj01/951.png ADDED

Git LFS Details

  • SHA256: a276bff91f94f939c33f1608e51d3bd06f02747117ee107990187eef76429558
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
results/versatile_diffusion/subj01/952.png ADDED

Git LFS Details

  • SHA256: 66e9f74cfe01cec8c7c9ece3a8784f34d27b22e8248a4a478bc5bcc74fed7b74
  • Pointer size: 131 Bytes
  • Size of remote file: 185 kB
results/versatile_diffusion/subj01/953.png ADDED

Git LFS Details

  • SHA256: f3b38cc8b20770b656fe37178811ba00dffe09ebe665ea29a91694f803cc2ebb
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
results/versatile_diffusion/subj01/955.png ADDED

Git LFS Details

  • SHA256: 1ea235d76e79fb734ee8b862393c048d1554ae8d053f3a9cbf0362f4518545c3
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
results/versatile_diffusion/subj01/957.png ADDED

Git LFS Details

  • SHA256: b7e922cff64da16de13e21e64f3c25941278d35588d9be4b8a81731f6aa71079
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
results/versatile_diffusion/subj01/958.png ADDED

Git LFS Details

  • SHA256: 6cd66e3e796fb3d89076cf5871d3683dcb0cef334709b5ada16bb31188309f01
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB
results/versatile_diffusion/subj01/959.png ADDED

Git LFS Details

  • SHA256: 39e42df5009a12127c1806f1776622ae4d058e902d6e0b3a3ad7053db62a5669
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
results/versatile_diffusion/subj01/96.png ADDED

Git LFS Details

  • SHA256: 06ec138e007b81043823e29d5bc6f02c8d75640087f72898b38bc8bf957be2d1
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
results/versatile_diffusion/subj01/960.png ADDED

Git LFS Details

  • SHA256: ecc44614c3ff43722cf66caed6442383a764125408a1943b87e611534cb6bd31
  • Pointer size: 131 Bytes
  • Size of remote file: 279 kB
results/versatile_diffusion/subj01/961.png ADDED

Git LFS Details

  • SHA256: e752df40cdcd7e7b14f865cf89f6dd48209665fe7e2880389cfbcf17f6deeb77
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
results/versatile_diffusion/subj01/962.png ADDED

Git LFS Details

  • SHA256: 544ef8d7e98eac3b2a0e0500168345e4927774bd8fb4ab3f71092ffc02dde850
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
results/versatile_diffusion/subj01/963.png ADDED

Git LFS Details

  • SHA256: f5099a79dbd878dab61ae9ab059b5789b3e3f6a5ac7d75f27e82e6cd0d1b8b1c
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
results/versatile_diffusion/subj01/964.png ADDED

Git LFS Details

  • SHA256: 77ca6a4fbcfc0aaf281180760046a612ed5b07b9dad75c69102c0d5047532687
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
results/versatile_diffusion/subj01/965.png ADDED

Git LFS Details

  • SHA256: e4492a6d242a432ff5ac00b0bcc1e79fc2c5e9391f563870a71a129d17bf72d4
  • Pointer size: 131 Bytes
  • Size of remote file: 265 kB
results/versatile_diffusion/subj01/966.png ADDED

Git LFS Details

  • SHA256: 63cbe22f74767bdd5a1ef52d415960a404c6a3f7a7b551451eed382e7b1d5812
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
results/versatile_diffusion/subj01/967.png ADDED

Git LFS Details

  • SHA256: 9bb95ed959d2b1010e67be2265623abb779d327b5ed5815d2dcd727215ccd5dc
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
results/versatile_diffusion/subj01/968.png ADDED

Git LFS Details

  • SHA256: e0f73300c482c0a48bf2fde120485b7715bd94d9e91ecef08ff9183f0ea9a15e
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
results/versatile_diffusion/subj01/969.png ADDED

Git LFS Details

  • SHA256: d7b56b9e0872df9ec48055b7330db8382dd4bc54dd7cc64d098f3925415598ee
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
results/versatile_diffusion/subj01/97.png ADDED

Git LFS Details

  • SHA256: f2d124b709f7c62075af45321a321b95bbeb6bd257632a18bdd4640b424d6ce1
  • Pointer size: 131 Bytes
  • Size of remote file: 226 kB
results/versatile_diffusion/subj01/970.png ADDED

Git LFS Details

  • SHA256: 27b9a49344624450f1c185f3914228d1f0486035ae913114a1627e7a118bee22
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
results/versatile_diffusion/subj01/971.png ADDED

Git LFS Details

  • SHA256: 0a224f7cd8dd19cb49529a4b9fb53400eb822806e748bfb153d037f7cf2eea46
  • Pointer size: 131 Bytes
  • Size of remote file: 285 kB
results/versatile_diffusion/subj01/972.png ADDED

Git LFS Details

  • SHA256: 94b19a719c14fa005c8d744d7a112dbbcb3f38d273afcedcc793d7fdac6bf802
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB
results/versatile_diffusion/subj01/973.png ADDED

Git LFS Details

  • SHA256: 9cfba086b2a4041f7bbafd06f1bbd66cf6de5a5607dd274ecf667f635f495b1d
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
results/versatile_diffusion/subj01/974.png ADDED

Git LFS Details

  • SHA256: 830f19987d783b804d0df861370b75ffe607596069ec1ae877beb14c3ae72157
  • Pointer size: 131 Bytes
  • Size of remote file: 242 kB
results/versatile_diffusion/subj01/975.png ADDED

Git LFS Details

  • SHA256: 06ae0b1cce0b56f7303a0540c564bee9df35e381ec77409a0b2b40ace1eb821f
  • Pointer size: 131 Bytes
  • Size of remote file: 226 kB
results/versatile_diffusion/subj01/976.png ADDED

Git LFS Details

  • SHA256: ba24b7aa30bc5f9ff1e7b44acba307d30f0e4f165e2aa729ae374049deb0c22c
  • Pointer size: 131 Bytes
  • Size of remote file: 229 kB
results/versatile_diffusion/subj01/977.png ADDED

Git LFS Details

  • SHA256: fde29bcff9cc1d38cc55ea7f38ed9f44d55bfd07c036e5a6bb133b46aaceef4a
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
results/versatile_diffusion/subj01/978.png ADDED

Git LFS Details

  • SHA256: ef050df1901835631d643f4d4a964c0d50a24cccffbd732a950e724431b8539d
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB
results/versatile_diffusion/subj01/979.png ADDED

Git LFS Details

  • SHA256: 255bbbb60eba158583607528b1336f6ba20b7dff24944f395078b42b243657c9
  • Pointer size: 131 Bytes
  • Size of remote file: 242 kB
results/versatile_diffusion/subj01/98.png ADDED

Git LFS Details

  • SHA256: bde1724856d5f7f7399f4d2da7187fcc11e72dad1fb3759b7e7d4623b2fc1594
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
results/versatile_diffusion/subj01/980.png ADDED

Git LFS Details

  • SHA256: 4337b52af8f7f7e48ef3e909680c0841a47432c4194b0326f468fe14eabe6387
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB
results/versatile_diffusion/subj01/981.png ADDED

Git LFS Details

  • SHA256: b5f36dd6bc61be4e3072394e1ff58b8a7a559bc6853dd5f378d6b5f5209d5295
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
results/versatile_diffusion/subj01/99.png ADDED

Git LFS Details

  • SHA256: 68beb2c299114e00365d06cce2c5a02b640e4d77731d14ed0796660bad574dcf
  • Pointer size: 131 Bytes
  • Size of remote file: 205 kB
results/versatile_diffusion/subj01/roi/0.png ADDED

Git LFS Details

  • SHA256: 24392e1b50eccbbe1e16c9e2d4f715ff58bf87b087ea72938d91d669587f2748
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
results/versatile_diffusion/subj01/roi/1.png ADDED

Git LFS Details

  • SHA256: 45d77380be8c7ce03a7180b6f0a069daaa639f210fa9741558abc1bb1bbd03de
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB
results/versatile_diffusion/subj01/roi/10.png ADDED

Git LFS Details

  • SHA256: eeefacce4476ae9835582b9539bb9b1954a5eaa16c4e5106f3d32f0a8fd892a9
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
results/versatile_diffusion/subj01/roi/11.png ADDED

Git LFS Details

  • SHA256: c4e951a75cd75bdd015555708e8a3bbb7f500d1441affa3512730f3bbe275fa3
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
results/versatile_diffusion/subj01/roi/12.png ADDED

Git LFS Details

  • SHA256: ac09d06bc56767682a31c3c8706c2d3b3536d31f88546f04c2769eb77f8337c8
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
results/versatile_diffusion/subj01/roi/2.png ADDED

Git LFS Details

  • SHA256: f0d27ab8a45d9a8719ce0c338cd011c036649757c91a1db3a7a1cbabe7d27862
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
results/versatile_diffusion/subj01/roi/3.png ADDED

Git LFS Details

  • SHA256: e6fdadef7cdb35c9c4eb82e88a1fd2f0c346917d09d0c77c3100b95042c6101f
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
results/versatile_diffusion/subj01/roi/5.png ADDED

Git LFS Details

  • SHA256: 09e652ac2bee6b6a7891cc7e9cc50272656d18f6349fc93527657d4f0135020d
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
results/versatile_diffusion/subj01/roi/6.png ADDED

Git LFS Details

  • SHA256: e65a11ba344d030f48dd07462cbac3b98fc0257044a3bd4211248ff33cce2bd0
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
results/versatile_diffusion/subj01/roi/7.png ADDED

Git LFS Details

  • SHA256: 9b547991c126290aaf36a45ebdd4cc9d1e576be9b86955e017675718a96e6206
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
results/versatile_diffusion/subj01/roi/8.png ADDED

Git LFS Details

  • SHA256: 5c17f3bcc7e97d9ef5b2733c06850e6b9f4f2dba1e987774880f4b2f32431a99
  • Pointer size: 131 Bytes
  • Size of remote file: 277 kB
results/versatile_diffusion/subj01/roi/9.png ADDED

Git LFS Details

  • SHA256: 05b9bb072d328059b954a2bc79243783811d8b88859ec8f54700aa625dbc88c1
  • Pointer size: 131 Bytes
  • Size of remote file: 279 kB
scripts/clipvision_extract_features.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('versatile_diffusion')
3
+ import os
4
+ import PIL
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ import torch
9
+ from lib.cfg_helper import model_cfg_bank
10
+ from lib.model_zoo import get_model
11
+ from lib.experiments.sd_default import color_adjust, auto_merge_imlist
12
+ from torch.utils.data import DataLoader, Dataset
13
+
14
+ from lib.model_zoo.vd import VD
15
+ from lib.cfg_holder import cfg_unique_holder as cfguh
16
+ from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml
17
+ import torchvision.transforms as T
18
+
19
+ import argparse
20
+ parser = argparse.ArgumentParser(description='Argument Parser')
21
+ parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
22
+ args = parser.parse_args()
23
+ sub=int(args.sub)
24
+ assert sub in [1,2,5,7]
25
+
26
+ cfgm_name = 'vd_noema'
27
+
28
+ pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
29
+ cfgm = model_cfg_bank()(cfgm_name)
30
+ net = get_model()(cfgm)
31
+ sd = torch.load(pth, map_location='cpu')
32
+ net.load_state_dict(sd, strict=False)
33
+
34
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ net.clip = net.clip.to(device)
36
+
37
+ class batch_generator_external_images(Dataset):
38
+
39
+ def __init__(self, data_path):
40
+ self.data_path = data_path
41
+ self.im = np.load(data_path).astype(np.uint8)
42
+
43
+
44
+ def __getitem__(self,idx):
45
+ img = Image.fromarray(self.im[idx])
46
+ img = T.functional.resize(img,(512,512))
47
+ img = T.functional.to_tensor(img).float()
48
+ #img = img/255
49
+ img = img*2 - 1
50
+ return img
51
+
52
+ def __len__(self):
53
+ return len(self.im)
54
+
55
+ batch_size=1
56
+ image_path = 'data/processed_data/subj{:02d}/nsd_train_stim_sub{}.npy'.format(sub,sub)
57
+ train_images = batch_generator_external_images(data_path = image_path)
58
+
59
+ image_path = 'data/processed_data/subj{:02d}/nsd_test_stim_sub{}.npy'.format(sub,sub)
60
+ test_images = batch_generator_external_images(data_path = image_path)
61
+
62
+ trainloader = DataLoader(train_images,batch_size,shuffle=False)
63
+ testloader = DataLoader(test_images,batch_size,shuffle=False)
64
+
65
+ num_embed, num_features, num_test, num_train = 257, 768, len(test_images), len(train_images)
66
+
67
+ train_clip = np.zeros((num_train,num_embed,num_features))
68
+ test_clip = np.zeros((num_test,num_embed,num_features))
69
+
70
+ with torch.no_grad():
71
+ for i,cin in enumerate(testloader):
72
+ print(i)
73
+ #ctemp = cin*2 - 1
74
+ c = net.clip_encode_vision(cin)
75
+ test_clip[i] = c[0].cpu().numpy()
76
+
77
+ np.save('data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(sub),test_clip)
78
+
79
+ for i,cin in enumerate(trainloader):
80
+ print(i)
81
+ #ctemp = cin*2 - 1
82
+ c = net.clip_encode_vision(cin)
83
+ train_clip[i] = c[0].cpu().numpy()
84
+ np.save('data/extracted_features/subj{:02d}/nsd_clipvision_train.npy'.format(sub),train_clip)
85
+
86
+
87
+
88
+
scripts/clipvision_regression.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import sklearn.linear_model as skl
4
+ import pickle
5
+ import argparse
6
+ parser = argparse.ArgumentParser(description='Argument Parser')
7
+ parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
8
+ args = parser.parse_args()
9
+ sub=int(args.sub)
10
+ assert sub in [1,2,5,7]
11
+
12
+ train_path = 'data/processed_data/subj{:02d}/nsd_train_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
13
+ train_fmri = np.load(train_path)
14
+ test_path = 'data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
15
+ test_fmri = np.load(test_path)
16
+
17
+ ## Preprocessing fMRI
18
+
19
+ train_fmri = train_fmri/300
20
+ test_fmri = test_fmri/300
21
+
22
+
23
+ norm_mean_train = np.mean(train_fmri, axis=0)
24
+ norm_scale_train = np.std(train_fmri, axis=0, ddof=1)
25
+ train_fmri = (train_fmri - norm_mean_train) / norm_scale_train
26
+ test_fmri = (test_fmri - norm_mean_train) / norm_scale_train
27
+
28
+ print(np.mean(train_fmri),np.std(train_fmri))
29
+ print(np.mean(test_fmri),np.std(test_fmri))
30
+
31
+ print(np.max(train_fmri),np.min(train_fmri))
32
+ print(np.max(test_fmri),np.min(test_fmri))
33
+
34
+ num_voxels, num_train, num_test = train_fmri.shape[1], len(train_fmri), len(test_fmri)
35
+
36
+
37
+ train_clip = np.load('data/extracted_features/subj{:02d}/nsd_clipvision_train.npy'.format(sub))
38
+ test_clip = np.load('data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(sub))
39
+
40
+ #train_clip = train_clip[:,1:,:]
41
+ num_samples,num_embed,num_dim = train_clip.shape
42
+
43
+ print("Training Regression")
44
+ reg_w = np.zeros((num_embed,num_dim,num_voxels)).astype(np.float32)
45
+ reg_b = np.zeros((num_embed,num_dim)).astype(np.float32)
46
+ pred_clip = np.zeros_like(test_clip)
47
+ for i in range(num_embed):
48
+
49
+
50
+ reg = skl.Ridge(alpha=60000, max_iter=50000, fit_intercept=True)
51
+ reg.fit(train_fmri, train_clip[:,i])
52
+ reg_w[i] = reg.coef_
53
+ reg_b[i] = reg.intercept_
54
+
55
+ pred_test_latent = reg.predict(test_fmri)
56
+ std_norm_test_latent = (pred_test_latent - np.mean(pred_test_latent,axis=0)) / np.std(pred_test_latent,axis=0)
57
+ pred_clip[:,i] = std_norm_test_latent * np.std(train_clip[:,i],axis=0) + np.mean(train_clip[:,i],axis=0)
58
+
59
+ print(i,reg.score(test_fmri,test_clip[:,i]))
60
+
61
+
62
+ np.save('data/predicted_features/subj{:02d}/nsd_clipvision_predtest_nsdgeneral.npy'.format(sub),pred_clip)
63
+
64
+ datadict = {
65
+ 'weight' : reg_w,
66
+ 'bias' : reg_b,
67
+
68
+ }
69
+
70
+ with open('data/regression_weights/subj{:02d}/clipvision_regression_weights.pkl'.format(sub),"wb") as f:
71
+ pickle.dump(datadict,f)
scripts/eval_extract_features.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import h5py
5
+ import scipy.io as spio
6
+ import nibabel as nib
7
+
8
+ import torch
9
+ import torchvision
10
+ import torchvision.models as tvmodels
11
+ import torchvision.transforms as transforms
12
+ from torch.utils.data import DataLoader, Dataset
13
+ import torchvision.transforms as T
14
+ from PIL import Image
15
+ import clip
16
+
17
+ import skimage.io as sio
18
+ from skimage import data, img_as_float
19
+ from skimage.transform import resize as imresize
20
+ from skimage.metrics import structural_similarity as ssim
21
+ import scipy as sp
22
+
23
+ import argparse
24
+ parser = argparse.ArgumentParser(description='Argument Parser')
25
+ parser.add_argument("-sub", "--sub", help="Subject Number", default=1)
26
+ args = parser.parse_args()
27
+ sub = int(args.sub)
28
+ assert sub in [0, 1, 2, 5, 7]
29
+
30
+ images_dir = 'data/nsddata_stimuli/test_images'
31
+ feats_dir = 'data/eval_features/test_images'
32
+
33
+ if sub in [1, 2, 5, 7]:
34
+ feats_dir = f'data/eval_features/subj{sub:02d}'
35
+ images_dir = f'results/versatile_diffusion/subj{sub:02d}'
36
+
37
+ if not os.path.exists(feats_dir):
38
+ os.makedirs(feats_dir)
39
+
40
+ class batch_generator_external_images(Dataset):
41
+ def __init__(self, data_path='', prefix='', net_name='clip'):
42
+ self.data_path = data_path
43
+ self.prefix = prefix
44
+ self.net_name = net_name
45
+
46
+ if self.net_name == 'clip':
47
+ self.normalize = transforms.Normalize(
48
+ mean=[0.48145466, 0.4578275, 0.40821073],
49
+ std=[0.26862954, 0.26130258, 0.27577711]
50
+ )
51
+ else:
52
+ self.normalize = transforms.Normalize(
53
+ mean=[0.485, 0.456, 0.406],
54
+ std=[0.229, 0.224, 0.225]
55
+ )
56
+ self.num_test = 982
57
+
58
+ def __getitem__(self, idx):
59
+ img = Image.open(f'{self.data_path}/{self.prefix}{idx}.png')
60
+ img = T.functional.resize(img, (224, 224))
61
+ img = T.functional.to_tensor(img).float()
62
+ img = self.normalize(img)
63
+ return img
64
+
65
+ def __len__(self):
66
+ return self.num_test
67
+
68
+ # Set device
69
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
70
+
71
+ global feat_list
72
+ feat_list = []
73
+
74
+ def fn(module, inputs, outputs):
75
+ feat_list.append(outputs.cpu().numpy())
76
+
77
+ net_list = [
78
+ ('inceptionv3', 'avgpool'),
79
+ ('clip', 'final'),
80
+ ('alexnet', 2),
81
+ ('alexnet', 5),
82
+ ('efficientnet', 'avgpool'),
83
+ ('swav', 'avgpool')
84
+ ]
85
+
86
+ batchsize = 64
87
+
88
+ for (net_name, layer) in net_list:
89
+ feat_list = []
90
+ print(net_name, layer)
91
+
92
+ dataset = batch_generator_external_images(data_path=images_dir, net_name=net_name, prefix='')
93
+ loader = DataLoader(dataset, batchsize, shuffle=False)
94
+
95
+ if net_name == 'inceptionv3':
96
+ net = tvmodels.inception_v3(pretrained=True)
97
+ if layer == 'avgpool':
98
+ net.avgpool.register_forward_hook(fn)
99
+ elif layer == 'lastconv':
100
+ net.Mixed_7c.register_forward_hook(fn)
101
+
102
+ elif net_name == 'alexnet':
103
+ net = tvmodels.alexnet(pretrained=True)
104
+ if layer == 2:
105
+ net.features[4].register_forward_hook(fn)
106
+ elif layer == 5:
107
+ net.features[11].register_forward_hook(fn)
108
+ elif layer == 7:
109
+ net.classifier[5].register_forward_hook(fn)
110
+
111
+ elif net_name == 'clip':
112
+ model, _ = clip.load("ViT-L/14", device=device)
113
+ net = model.visual.to(torch.float32)
114
+ if layer == 7:
115
+ net.transformer.resblocks[7].register_forward_hook(fn)
116
+ elif layer == 12:
117
+ net.transformer.resblocks[12].register_forward_hook(fn)
118
+ elif layer == 'final':
119
+ net.register_forward_hook(fn)
120
+
121
+ elif net_name == 'efficientnet':
122
+ net = tvmodels.efficientnet_b1(weights='IMAGENET1K_V1')
123
+ net.avgpool.register_forward_hook(fn)
124
+
125
+ elif net_name == 'swav':
126
+ net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
127
+ net.avgpool.register_forward_hook(fn)
128
+
129
+ net.eval()
130
+ net = net.to(device)
131
+
132
+ with torch.no_grad():
133
+ for i, x in enumerate(loader):
134
+ print(i * batchsize)
135
+ x = x.to(device)
136
+ _ = net(x)
137
+
138
+ if net_name == 'clip':
139
+ if layer == 7 or layer == 12:
140
+ feat_list = np.concatenate(feat_list, axis=1).transpose((1, 0, 2))
141
+ else:
142
+ feat_list = np.concatenate(feat_list)
143
+ else:
144
+ feat_list = np.concatenate(feat_list)
145
+
146
+ file_name = f'{feats_dir}/{net_name}_{layer}.npy'
147
+ np.save(file_name, feat_list)
scripts/evaluate_reconstruction.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import h5py
5
+ import scipy.io as spio
6
+ import nibabel as nib
7
+ import scipy as sp
8
+ from PIL import Image
9
+
10
+
11
+
12
+ import argparse
13
+ parser = argparse.ArgumentParser(description='Argument Parser')
14
+ parser.add_argument("-sub", "--sub",help="Subject Number",default=1)
15
+ args = parser.parse_args()
16
+ sub=int(args.sub)
17
+ assert sub in [1,2,5,7]
18
+
19
+
20
+ from scipy.stats import pearsonr,binom,linregress
21
+ import numpy as np
22
+ def pairwise_corr_all(ground_truth, predictions):
23
+ r = np.corrcoef(ground_truth, predictions)#cosine_similarity(ground_truth, predictions)#
24
+ r = r[:len(ground_truth), len(ground_truth):] # rows: groundtruth, columns: predicitons
25
+ #print(r.shape)
26
+ # congruent pairs are on diagonal
27
+ congruents = np.diag(r)
28
+ #print(congruents)
29
+
30
+ # for each column (predicition) we should count the number of rows (groundtruth) that the value is lower than the congruent (e.g. success).
31
+ success = r < congruents
32
+ success_cnt = np.sum(success, 0)
33
+
34
+ # note: diagonal of 'success' is always zero so we can discard it. That's why we divide by len-1
35
+ perf = np.mean(success_cnt) / (len(ground_truth)-1)
36
+ p = 1 - binom.cdf(perf*len(ground_truth)*(len(ground_truth)-1), len(ground_truth)*(len(ground_truth)-1), 0.5)
37
+
38
+ return perf, p
39
+
40
+
41
+ net_list = [
42
+ ('inceptionv3','avgpool'),
43
+ ('clip','final'),
44
+ ('alexnet',2),
45
+ ('alexnet',5),
46
+ ('efficientnet','avgpool'),
47
+ ('swav','avgpool')
48
+ ]
49
+
50
+ feats_dir = 'data/eval_features/subj{:02d}'.format(sub)
51
+ test_dir = 'data/eval_features/test_images'
52
+ num_test = 982
53
+ distance_fn = sp.spatial.distance.correlation
54
+ pairwise_corrs = []
55
+ for (net_name,layer) in net_list:
56
+ file_name = '{}/{}_{}.npy'.format(test_dir,net_name,layer)
57
+ gt_feat = np.load(file_name)
58
+
59
+ file_name = '{}/{}_{}.npy'.format(feats_dir,net_name,layer)
60
+ eval_feat = np.load(file_name)
61
+
62
+ gt_feat = gt_feat.reshape((len(gt_feat),-1))
63
+ eval_feat = eval_feat.reshape((len(eval_feat),-1))
64
+
65
+ print(net_name,layer)
66
+ if net_name in ['efficientnet','swav']:
67
+ print('distance: ',np.array([distance_fn(gt_feat[i],eval_feat[i]) for i in range(num_test)]).mean())
68
+ else:
69
+ pairwise_corrs.append(pairwise_corr_all(gt_feat[:num_test],eval_feat[:num_test])[0])
70
+ print('pairwise corr: ',pairwise_corrs[-1])
71
+
72
+ from skimage.color import rgb2gray
73
+ from skimage.metrics import structural_similarity as ssim
74
+
75
+ ssim_list = []
76
+ pixcorr_list = []
77
+ for i in range(982):
78
+ gen_image = Image.open('results/versatile_diffusion/subj{:02d}/{}.png'.format(sub,i)).resize((425,425))
79
+ gt_image = Image.open('data/nsddata_stimuli/test_images/{}.png'.format(i))
80
+ gen_image = np.array(gen_image)/255.0
81
+ gt_image = np.array(gt_image)/255.0
82
+ pixcorr_res = np.corrcoef(gt_image.reshape(1,-1), gen_image.reshape(1,-1))[0,1]
83
+ pixcorr_list.append(pixcorr_res)
84
+ gen_image = rgb2gray(gen_image)
85
+ gt_image = rgb2gray(gt_image)
86
+ ssim_res = ssim(gen_image, gt_image, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0)
87
+ ssim_list.append(ssim_res)
88
+
89
+ ssim_list = np.array(ssim_list)
90
+ pixcorr_list = np.array(pixcorr_list)
91
+ print('PixCorr: {}'.format(pixcorr_list.mean()))
92
+ print('SSIM: {}'.format(ssim_list.mean()))
93
+