| | |
| | """ |
| | @inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23, |
| | author = {James Seale Smith and |
| | Leonid Karlinsky and |
| | Vyshnavi Gutta and |
| | Paola Cascante{-}Bonilla and |
| | Donghyun Kim and |
| | Assaf Arbelle and |
| | Rameswar Panda and |
| | Rog{\'{e}}rio Feris and |
| | Zsolt Kira}, |
| | title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free |
| | Continual Learning}, |
| | booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, |
| | {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023}, |
| | pages = {11909--11919}, |
| | publisher = {{IEEE}}, |
| | year = {2023} |
| | } |
| | |
| | https://arxiv.org/abs/2211.13218 |
| | |
| | Adapted from https://github.com/GT-RIPL/CODA-Prompt |
| | """ |
| |
|
| | import math |
| | import copy |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import Parameter |
| | import torch.nn.functional as F |
| | from .finetune import Finetune |
| | from core.model.backbone.resnet import * |
| | import numpy as np |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
class Model(nn.Module):
    """Prompted backbone followed by a linear classification head.

    In training mode the backbone is called with ``train=True`` and must
    return a ``(features, loss)`` pair; the loss is passed through untouched
    alongside the logits. In evaluation mode the backbone is called with
    ``train=False`` and must return only the features.

    Args:
        backbone: callable/module invoked as ``backbone(x, train=...)``.
        feat_dim: size of the feature vector produced by the backbone.
        num_class: number of output logits of the linear head.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        # Registration order (backbone first, classifier last) is kept so the
        # parameter / state_dict ordering matches previous checkpoints.
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x, train=True):
        """Return logits (eval) or ``(logits, backbone_loss)`` (train)."""
        if not train:
            return self.classifier(self.backbone(x, train=False))

        features, aux_loss = self.backbone(x, train=True)
        logits = self.classifier(features)
        return logits, aux_loss
| |
|
| |
|
class CodaPrompt(Finetune):
    """CODA-Prompt continual learner built on top of ``Finetune``.

    Wraps the inherited backbone in a ``Model`` whose linear head is grown at
    the start of every task, and creates a ``'coda'`` prompt module on the
    backbone. ``get_parameters`` exposes only the prompt and classifier
    parameters (presumably what the trainer optimizes — confirm in caller).

    Expected ``kwargs`` keys read here: ``init_cls_num``, ``inc_cls_num``,
    ``task_num``, ``pool_size``, ``prompt_length``, ``mu``.

    NOTE(review): ``self.backbone`` and ``self.device`` are assumed to be set
    by ``Finetune.__init__`` — they are not assigned in this class.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        # The head starts with only the first task's classes; before_task grows it.
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        self.network.backbone.create_prompt(
            'coda',
            n_tasks=kwargs['task_num'],
            prompt_param=[kwargs['pool_size'], kwargs['prompt_length'], kwargs['mu']],
        )
        self.task_idx = 0
        # Number of logits that belong to already-finished tasks; observe()
        # masks these out so training only competes among current classes.
        self.last_out_dim = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for task ``task_idx``: grow the head and reset loss state.

        The classifier is replaced by a wider ``nn.Linear`` covering
        ``init_cls_num + task_idx * inc_cls_num`` classes; the rows learned for
        earlier classes are copied over, new rows keep their fresh init.
        """
        self.task_idx = task_idx
        self.network.backbone.task_id = task_idx

        old_fc = self.network.classifier
        old_out = old_fc.out_features
        new_out = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num']
        new_fc = nn.Linear(old_fc.in_features, new_out)
        # Copy the old head without recording the copy in autograd
        # (recommended over writing through ``.data``); copy_ also handles
        # the old head living on a different device than the new one.
        with torch.no_grad():
            new_fc.weight[:old_out].copy_(old_fc.weight)
            new_fc.bias[:old_out].copy_(old_fc.bias)
        self.network.classifier = new_fc
        self.network.to(self.device)

        # Per-sample losses so they can be re-weighted before averaging.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')

        self.out_dim = new_out
        # Class loss weights (all ones); only the last entry is used in observe.
        self.dw_k = torch.ones(self.out_dim + 1, dtype=torch.float32, device=self.device)

    def observe(self, data):
        """Run one training step's forward pass and compute the loss.

        Args:
            data: dict with ``'image'`` (input batch) and ``'label'`` (targets).

        Returns:
            ``(pred, accuracy, loss)`` where ``pred`` is the argmax over the
            masked logits, ``accuracy`` is the fraction of correct
            predictions in the batch, and ``loss`` is the backbone (prompt)
            loss plus the weighted cross-entropy.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        logit, loss = self.network(x, train=True)

        # Hide logits of previous tasks' classes so the CE loss and the
        # prediction only consider classes introduced in this task or later.
        logit[:, :self.last_out_dim] = -float('inf')

        # Every sample is weighted by dw_k[-1]; broadcasting the scalar is
        # equivalent to indexing with a per-sample index tensor and avoids
        # allocating a CPU index tensor on every step.
        dw_cls = self.dw_k[-1]
        # Out-of-place add: the backbone's loss tensor may be saved by its own
        # autograd graph, and an in-place ``+=`` would corrupt/reject backward.
        loss = loss + (self.loss_fn(logit, y) * dw_cls).mean()

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Mark every class seen so far as "old" for the next task's masking."""
        self.last_out_dim = self.out_dim

    def inference(self, data):
        """Predict over all classes seen so far.

        Returns ``(pred, accuracy)`` with ``accuracy`` as a batch fraction.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        logit = self.network(x, train=False)
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def get_parameters(self, config):
        """Return the prompt and classifier parameters (``config`` is unused)."""
        return list(self.network.backbone.prompt.parameters()) + list(self.network.classifier.parameters())
| |
|