import torch
import torch.nn as nn
import copy

from .vit_inflora import (VisionTransformer, PatchEmbed, Block,
                          resolve_pretrained_cfg, build_model_with_cfg,
                          checkpoint_filter_fn)


class ViT_lora_co(VisionTransformer):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 global_pool='token', embed_dim=768, depth=12, num_heads=12,
                 mlp_ratio=4., qkv_bias=True, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 weight_init='', init_values=None, embed_layer=PatchEmbed,
                 norm_layer=None, act_layer=None, block_fn=Block,
                 n_tasks=10, rank=64):
        super().__init__(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            num_classes=num_classes, global_pool=global_pool, embed_dim=embed_dim,
            depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, representation_size=representation_size,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate, weight_init=weight_init,
            init_values=init_values, embed_layer=embed_layer,
            norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn,
            n_tasks=n_tasks, rank=rank)

    def forward(self, x, task_id, register_blk=-1, get_feat=False, get_cur_feat=False):
        # Patchify, prepend the [CLS] token, and add positional embeddings.
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        # Kept for interface compatibility with prompt-based baselines; this
        # backbone does not produce an auxiliary prompt loss.
        prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device)

        for i, blk in enumerate(self.blocks):
            x = blk(x, task_id, register_blk == i,
                    get_feat=get_feat, get_cur_feat=get_cur_feat)
        x = self.norm(x)

        return x, prompt_loss


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE: extra code to support handling of repr size for in21k pretrained models.
    # pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
    pretrained_cfg = resolve_pretrained_cfg(variant)
    default_num_classes = pretrained_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # The pretrained pre-logits layer only matches the original head size.
        repr_size = None

    model = build_model_with_cfg(
        ViT_lora_co, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in pretrained_cfg['url'],
        **kwargs)
    return model


class SiNet_vit(nn.Module):

    def __init__(self, **args):
        '''
        args is a dictionary with the required arguments.
        The image encoder is defined in vit_inflora;
        class_num is the number of classes in the initial session.
        '''
        super(SiNet_vit, self).__init__()

        model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
                            n_tasks=args["total_sessions"], rank=args["rank"])
        self.image_encoder = _create_vision_transformer(
            'vit_base_patch16_224_in21k', pretrained=True, **model_kwargs)

        self.class_num = args["init_cls"]

        # One linear head per session, plus a backup copy of each head.
        self.classifier_pool = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for i in range(args["total_sessions"])
        ])
        self.classifier_pool_backup = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for i in range(args["total_sessions"])
        ])

        self.numtask = 0

    @property
    def feature_dim(self):
        return self.image_encoder.out_dim

    def extract_vector(self, image, task=None):
        if task is None:
            image_features, _ = self.image_encoder(image, self.numtask - 1)
        else:
            image_features, _ = self.image_encoder(image, task)
        # Use the [CLS] token as the image representation.
        image_features = image_features[:, 0, :]
        return image_features

    def forward(self, image, get_feat=False, get_cur_feat=False, fc_only=False):
        """ Return the output of the fully connected layer(s). """
        if fc_only:
            # `image` is already a feature vector; run it through every head
            # learned so far and concatenate the logits.
            fc_outs = []
            for ti in range(self.numtask):
                fc_out = self.classifier_pool[ti](image)
                fc_outs.append(fc_out)
            return torch.cat(fc_outs, dim=1)

        logits = []
        image_features, prompt_loss = self.image_encoder(
            image, task_id=self.numtask - 1,
            get_feat=get_feat, get_cur_feat=get_cur_feat)
        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)
        # During training, only the current task's head produces logits.
        for prompts in [self.classifier_pool[self.numtask - 1]]:
            logits.append(prompts(image_features))

        return {
            'logits': torch.cat(logits, dim=1),
            'features': image_features,
            'prompt_loss': prompt_loss
        }

    def interface(self, image):
        # Inference path: concatenate the logits of all heads learned so far.
        image_features, _ = self.image_encoder(image, task_id=self.numtask - 1)
        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)
        logits = []
        for prompt in self.classifier_pool[:self.numtask]:
            logits.append(prompt(image_features))
        logits = torch.cat(logits, 1)
        return logits

    def update_fc(self, nb_classes):
        """ Advance to the next task. `nb_classes` is unused because all
        heads are pre-allocated in __init__. """
        self.numtask += 1

    def classifier_backup(self, task_id):
        self.classifier_pool_backup[task_id].load_state_dict(
            self.classifier_pool[task_id].state_dict())

    def classifier_recall(self):
        # Restores the classifier pool from `self.old_state_dict`, which the
        # caller is expected to have saved beforehand.
        self.classifier_pool.load_state_dict(self.old_state_dict)

    def copy(self):
        return copy.deepcopy(self)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self
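

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline). It
# assumes the timm weights for 'vit_base_patch16_224_in21k' are available
# locally or downloadable, and the argument values below (10 sessions, 10
# initial classes, rank 8) are placeholder assumptions, not the repo's
# configured hyperparameters.
if __name__ == "__main__":
    args = {
        "total_sessions": 10,  # number of incremental sessions / heads
        "init_cls": 10,        # classes handled by each session head
        "embd_dim": 768,       # ViT-B/16 embedding dimension
        "rank": 8,             # LoRA rank (assumed value)
    }
    net = SiNet_vit(**args)
    net.update_fc(args["init_cls"])  # enter task 0

    images = torch.randn(2, 3, 224, 224)  # dummy batch
    with torch.no_grad():
        out = net(images)
        print(out["logits"].shape)          # (2, init_cls): current head only
        print(net.interface(images).shape)  # (2, init_cls * numtask): all heads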