File size: 3,532 Bytes
197d4ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | import torch
#from models import vision_transformer as vit
#from models import vision_transformer_multiBlocks_20221030 as vit
#from methods import vision_transformer_multiBlocks_20221030 as vit
from methods import ViT as vit
#import vision_transformer_multiBlocks_20221030 as vit
#from models.pmf_protonet import ProtoNet
#from methods.pmf_protonet import ProtoNet
from methods.protonet import ProtoNet
#from pmf_protonet import ProtoNet
#from models.cvpr2023_gnnnet_20221102 import GnnNet
#from methods.cvpr2023_gnnnet_20221102 import GnnNet
#from cvpr2023_gnnnet_20221102 import GnnNet
def load_ViTsmall(no_pretrain=False):
    """Build a ViT-Small/16 feature extractor, optionally with DINO weights.

    The model comes from ``vit.__dict__['vit_small']`` with ``patch_size=16``
    and ``num_classes=0`` (presumably "no classification head"; confirm in
    ``methods.ViT``).

    Args:
        no_pretrain: if True, skip the DINO checkpoint and return the model
            exactly as constructed by ``methods.ViT``.

    Returns:
        The ViT-Small/16 model.
    """
    model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
    if not no_pretrain:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        # map_location="cpu" lets the checkpoint load on hosts without the
        # device it was saved from; load_state_dict then copies the tensors
        # into the model's own parameters.
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url,
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model
def load_ViTbase(no_pretrain=False):
    """Build a ViT-Base/16 feature extractor, optionally with DINO weights.

    The model comes from ``vit.__dict__['vit_base']`` with ``patch_size=16``
    and ``num_classes=0`` (presumably "no classification head"; confirm in
    ``methods.ViT``). Progress messages are printed to stdout.

    Args:
        no_pretrain: if True, skip the DINO checkpoint and return the model
            exactly as constructed by ``methods.ViT``.

    Returns:
        The ViT-Base/16 model.
    """
    model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
    if not no_pretrain:
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        # map_location="cpu" lets the checkpoint load on hosts without the
        # device it was saved from; load_state_dict then copies the tensors
        # into the model's own parameters.
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url,
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))
    print('model defined.')
    return model
def load_ResNet50(no_pretrain=False):
    """Build a torchvision ResNet-50 with its final ``fc`` layer removed.

    Args:
        no_pretrain: if False, torchvision's pretrained weights are requested
            via ``pretrained=True``; if True, no pretrained weights are used.

    Returns:
        The ResNet-50 whose ``fc`` is ``torch.nn.Identity``, so the network
        outputs what used to be the input of the classification layer.

    NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
    favour of ``weights=``; kept for compatibility with older versions.
    """
    from torchvision.models.resnet import resnet50

    backbone = resnet50(pretrained=not no_pretrain)
    backbone.fc = torch.nn.Identity()  # drop the classification head
    print('model defined.')
    return backbone
def load_ResNet50_dino(no_pretrain=False):
    """Build a ResNet-50 extractor, optionally loading DINO's ResNet-50 weights.

    The network is created without torchvision weights, its ``fc`` layer is
    replaced by ``torch.nn.Identity``, and (unless ``no_pretrain``) the DINO
    checkpoint is downloaded via ``torch.hub`` onto the CPU and loaded
    non-strictly.

    Args:
        no_pretrain: if True, return the freshly constructed network without
            loading any checkpoint.

    Returns:
        The ResNet-50 feature extractor.
    """
    from torchvision.models.resnet import resnet50

    net = resnet50(pretrained=False)
    net.fc = torch.nn.Identity()  # drop the classification head
    if no_pretrain:
        return net

    ckpt_url = "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=ckpt_url, map_location="cpu")
    # Non-strict: missing/unexpected keys are ignored instead of raising.
    net.load_state_dict(weights, strict=False)
    return net
def load_ResNet50_clip(no_pretrain=False):
    """Load the CLIP ``'RN50'`` model on the CPU via the project's ``models.clip``.

    ``clip.load`` returns a pair; only the first element (the model) is kept.
    The second element is presumably a preprocessing transform; verify in
    ``models.clip``.

    NOTE(review): ``no_pretrain`` is accepted for signature parity with the
    other loaders but is not used here.
    """
    from models import clip

    clip_model, _preprocess = clip.load('RN50', 'cpu')
    return clip_model
def get_model(backbone='vit_small', classifier='protonet', args=None, styleAdv=False):
    """Build a few-shot model from a backbone name and a classifier name.

    Supported combinations:
        ('vit_small', 'protonet'): ViT-S/16 extractor wrapped in
            ``methods.protonet.ProtoNet``, or in
            ``methods.StyleAdv_ViT_protonet.ProtoNet`` when ``styleAdv`` is true.
        ('resnet50', 'protonet'): DINO ResNet-50 extractor wrapped in
            ``methods.protonet.ProtoNet`` (``styleAdv`` is ignored).
        ('vit_small', 'gnnnet') / ('resnet50', 'gnnnet'): extractor wrapped in
            GnnNet with ``n_way=5`` and ``n_support=args.nSupport``.

    Args:
        backbone: 'vit_small' or 'resnet50'.
        classifier: 'protonet' or 'gnnnet'.
        args: object providing ``nSupport``; only read for 'gnnnet'.
        styleAdv: select the StyleAdv ProtoNet for ('vit_small', 'protonet').

    Returns:
        The constructed model.

    Raises:
        ValueError: if the backbone/classifier combination is unsupported, or
            if ``args`` is None for the 'gnnnet' classifier.
    """
    if classifier == 'protonet':
        if backbone == 'vit_small':
            extractor = load_ViTsmall()
            if styleAdv:
                from methods.StyleAdv_ViT_protonet import ProtoNet as StyleAdvProtoNet
                return StyleAdvProtoNet(extractor)
            # Aliased on purpose: a function-level `import ... ProtoNet` would
            # make `ProtoNet` local to the whole function and leave it unbound
            # in the other branches.
            from methods.protonet import ProtoNet as PlainProtoNet
            return PlainProtoNet(extractor)
        if backbone == 'resnet50':
            from methods.protonet import ProtoNet as PlainProtoNet
            return PlainProtoNet(load_ResNet50_dino())
    elif classifier == 'gnnnet' and backbone in ('vit_small', 'resnet50'):
        # Validate before downloading any weights.
        if args is None:
            raise ValueError(
                "classifier='gnnnet' requires `args` with an `nSupport` attribute"
            )
        # NOTE(review): GnnNet is not imported at module level; this path
        # follows the `methods.` package used by the file's other imports —
        # confirm the module name.
        from methods.cvpr2023_gnnnet_20221102 import GnnNet
        extractor = load_ViTsmall() if backbone == 'vit_small' else load_ResNet50_dino()
        return GnnNet(extractor, backbone_flag=backbone, n_way=5, n_support=args.nSupport)
    raise ValueError(
        f"unsupported combination: backbone={backbone!r}, classifier={classifier!r}"
    )
if __name__ == '__main__':
    # Smoke test: push a random batch through the ViT-S/16 extractor and
    # print the input and output shapes.
    images = torch.randn(16, 3, 224, 224)  # renamed: `input` shadowed the builtin
    print('input:', images.size())
    model = load_ViTsmall()
    # Forward pass only; no autograd graph is needed to inspect the shape.
    with torch.no_grad():
        out = model(images)
    print('out:', out.size())
|