# Source: LibContinual — core/model/codaprompt.py (upload by boringKey, commit 5fee096)
# -*- coding: utf-8 -*-
"""
@inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23,
author = {James Seale Smith and
Leonid Karlinsky and
Vyshnavi Gutta and
Paola Cascante{-}Bonilla and
Donghyun Kim and
Assaf Arbelle and
Rameswar Panda and
Rog{\'{e}}rio Feris and
Zsolt Kira},
title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free
Continual Learning},
booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
{CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023},
pages = {11909--11919},
publisher = {{IEEE}},
year = {2023}
}
https://arxiv.org/abs/2211.13218
Adapted from https://github.com/GT-RIPL/CODA-Prompt
"""
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader
class Model(nn.Module):
    """A backbone paired with a linear classification head.

    The backbone is expected to return ``(features, prompt_loss)`` when
    called with ``train=True`` and plain ``features`` when called with
    ``train=False``; the head maps features of size ``feat_dim`` to
    ``num_class`` logits.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x, train=True):
        """Return ``(logits, prompt_loss)`` in train mode, ``logits`` otherwise."""
        if not train:
            return self.classifier(self.backbone(x, train=False))
        features, prompt_loss = self.backbone(x, train=True)
        return self.classifier(features), prompt_loss
class CodaPrompt(Finetune):
    """CODA-Prompt continual learner (Smith et al., CVPR 2023).

    Wraps a prompt-augmented backbone in a :class:`Model` whose linear head
    grows at every task boundary. Rehearsal-free: no buffer is used; the
    backbone's prompt pool carries knowledge across tasks.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        # Fix: the original assigned self.kwargs twice (dead duplicate).
        self.kwargs = kwargs
        # Head starts with only the first task's classes; it is widened in
        # before_task() as new tasks arrive.
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        # Attach the decomposed-attention ('coda') prompt pool to the backbone.
        self.network.backbone.create_prompt(
            'coda',
            n_tasks=kwargs['task_num'],
            prompt_param=[kwargs['pool_size'], kwargs['prompt_length'], kwargs['mu']],
        )
        self.task_idx = 0
        self.last_out_dim = 0  # number of classes seen before the current task

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the classifier to cover all classes up to ``task_idx``.

        Old-class weights/biases are copied into the new, wider head; the
        newly added rows keep their default initialization.
        """
        self.task_idx = task_idx
        self.network.backbone.task_id = task_idx
        in_features = self.network.classifier.in_features
        out_features = self.network.classifier.out_features
        new_out_features = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num']
        new_fc = nn.Linear(in_features, new_out_features)
        new_fc.weight.data[:out_features] = self.network.classifier.weight.data
        new_fc.bias.data[:out_features] = self.network.classifier.bias.data
        self.network.classifier = new_fc
        self.network.to(self.device)
        # Per-sample reduction so the dw_cls weights in observe() can be applied.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.out_dim = new_out_features
        # Per-class loss weights; all ones here (kept from the reference
        # CODA-Prompt implementation, where it supports class reweighting).
        self.dw_k = torch.tensor(np.ones(self.out_dim + 1, dtype=np.float32)).to(self.device)

    def observe(self, data):
        """One training step: return (predictions, batch accuracy, loss).

        The loss is the backbone's prompt orthogonality penalty plus a
        weighted cross-entropy on the current-task classes only.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logit, loss = self.network(x, train=True)
        # Mask logits of classes from previous tasks so training (and the
        # accuracy below) only competes among current-task classes.
        logit[:, :self.last_out_dim] = -float('inf')
        # dw_k[-1] for every sample — a uniform weight of 1.0 (see before_task).
        dw_cls = self.dw_k[-1 * torch.ones(y.size()).long()]
        loss += (self.loss_fn(logit, y) * dw_cls).mean()
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record the head width so observe() can mask old classes next task."""
        self.last_out_dim = self.out_dim

    def inference(self, data):
        """Evaluate one batch: return (predictions, batch accuracy)."""
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logit = self.network(x, train=False)
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def get_parameters(self, config):
        """Only the prompt pool and the classifier head are trainable."""
        return list(self.network.backbone.prompt.parameters()) + list(self.network.classifier.parameters())