File size: 3,532 Bytes
197d4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
#from models import vision_transformer as vit
#from models import vision_transformer_multiBlocks_20221030 as vit
#from methods import vision_transformer_multiBlocks_20221030 as vit
from methods import ViT as vit
#import vision_transformer_multiBlocks_20221030 as vit
#from models.pmf_protonet import ProtoNet
#from methods.pmf_protonet import ProtoNet
from methods.protonet import ProtoNet
#from pmf_protonet import ProtoNet
#from models.cvpr2023_gnnnet_20221102 import GnnNet
#from methods.cvpr2023_gnnnet_20221102 import GnnNet
#from cvpr2023_gnnnet_20221102 import GnnNet

def load_ViTsmall(no_pretrain=False):
  """Build a ViT-small (patch 16, headless) feature extractor.

  Args:
    no_pretrain: when False (default), DINO self-supervised weights are
      downloaded via torch.hub and loaded strictly.

  Returns:
    The ViT-small backbone (num_classes=0, i.e. feature output only).
  """
  model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
  if not no_pretrain:
    ckpt_subpath = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + ckpt_subpath)
    # strict=True: the checkpoint must match the headless ViT exactly
    model.load_state_dict(weights, strict=True)
  return model

def load_ViTbase(no_pretrain=False):
  """Build a ViT-base (patch 16, headless) feature extractor.

  Args:
    no_pretrain: when False (default), DINO self-supervised weights are
      downloaded via torch.hub and loaded strictly.

  Returns:
    The ViT-base backbone (num_classes=0, i.e. feature output only).
  """
  model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
  if not no_pretrain:
    ckpt_subpath = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + ckpt_subpath)
    # strict=True: the checkpoint must match the headless ViT exactly
    model.load_state_dict(weights, strict=True)
    print('Pretrained weights found at {}'.format(ckpt_subpath))
  print('model defined.')
  return model


def load_ResNet50(no_pretrain=False):
  """Build a torchvision ResNet-50 backbone with the classifier head removed.

  Args:
    no_pretrain: when False (default), ImageNet-supervised weights are loaded.

  Returns:
    The ResNet-50 model with `fc` replaced by Identity (pooled features out).
  """
  from torchvision.models.resnet import resnet50
  # NOTE(review): the `pretrained=` flag is deprecated in newer torchvision
  # (replaced by `weights=`) — confirm the pinned torchvision still accepts it.
  model = resnet50(pretrained=not no_pretrain)
  model.fc = torch.nn.Identity()  # expose features instead of class logits
  print('model defined.')
  return model

def load_ResNet50_dino(no_pretrain=False):
  """Build a ResNet-50 backbone initialized from the DINO checkpoint.

  Args:
    no_pretrain: when False (default), the official DINO self-supervised
      ResNet-50 weights are downloaded via torch.hub and loaded.

  Returns:
    The ResNet-50 model with `fc` replaced by Identity (pooled features out).
  """
  from torchvision.models.resnet import resnet50
  model = resnet50(pretrained=False)
  model.fc = torch.nn.Identity()  # strip the classifier head
  if not no_pretrain:
    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
        map_location="cpu",
    )
    # strict=False tolerates key mismatches (presumably the removed fc head —
    # TODO confirm against the checkpoint's key set)
    model.load_state_dict(checkpoint, strict=False)
  return model

def load_ResNet50_clip(no_pretrain=False):
  # Loads the CLIP 'RN50' model on CPU via the project's clip module.
  # NOTE(review): `no_pretrain` is accepted for signature parity with the other
  # loaders but is ignored — CLIP weights are always loaded here. Confirm this
  # is intentional.
  from models import clip
  model, _ = clip.load('RN50', 'cpu')
  return model


def get_model(backbone='vit_small', classifier='protonet', args=None, styleAdv=False):
  """Assemble a few-shot model: a feature extractor plus a classifier head.

  Args:
    backbone: 'vit_small' (DINO ViT-small) or 'resnet50' (DINO ResNet-50).
    classifier: 'protonet' or 'gnnnet'.
    args: namespace providing `nSupport`; required for the gnnnet heads.
    styleAdv: with vit_small/protonet, selects the StyleAdv ProtoNet variant.

  Returns:
    The assembled model.

  Raises:
    ValueError: for an unsupported backbone/classifier combination.
      (Previously such a combination fell through every `if` and crashed
      with UnboundLocalError on `model`.)
  """
  # The combinations are mutually exclusive, so use an elif chain and fail
  # loudly on anything unrecognized.
  if backbone == 'vit_small' and classifier == 'protonet':
    extractor = load_ViTsmall()
    if not styleAdv:
      from methods.protonet import ProtoNet
      model = ProtoNet(extractor)
    else:
      from methods.StyleAdv_ViT_protonet import ProtoNet
      model = ProtoNet(extractor)
  elif backbone == 'resnet50' and classifier == 'protonet':
    # Uses the module-level ProtoNet import (methods.protonet).
    extractor = load_ResNet50_dino()
    model = ProtoNet(extractor)
  elif backbone == 'vit_small' and classifier == 'gnnnet':
    # NOTE(review): GnnNet is never imported in this file (all candidate
    # imports at the top are commented out), so this branch raises NameError
    # at runtime — restore the intended import before using gnnnet.
    extractor = load_ViTsmall()
    model = GnnNet(extractor, backbone_flag='vit_small', n_way=5, n_support=args.nSupport)
  elif backbone == 'resnet50' and classifier == 'gnnnet':
    # NOTE(review): same missing GnnNet import as above.
    extractor = load_ResNet50_dino()
    model = GnnNet(extractor, backbone_flag='resnet50', n_way=5, n_support=args.nSupport)
  else:
    raise ValueError('unsupported combination: backbone={}, classifier={}'.format(backbone, classifier))
  return model




if __name__ == '__main__':
  # Smoke test: push one random batch through the ViT-small extractor and
  # report the input/output tensor shapes.
  batch = torch.randn(16, 3, 224, 224)
  print('input:', batch.size())
  extractor = load_ViTsmall()
  features = extractor(batch)
  print('out:', features.size())