LibContinual / docs /tutorials /zh /add_a_new_method.md
boringKey's picture
Upload 236 files
5fee096 verified

Add a new method

下面以LUCIR方法为例,描述如何添加一种新的方 法。

首先,所有方法都继承同一父类Finetune

class Finetune(nn.Module):
    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        ...
        self.kwargs = kwargs
    
    def observe(self, data):
        ...
        return pred, acc / x.size(0), loss

    def inference(self, data):
        ...
        return pred, acc / x.size(0)

    def forward(self, x):
        ...

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        pass

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        pass
    
    def get_parameters(self, config):
        ...
        return train_parameters

Finetune类包含了一个方法需要具备的几个重要接口:

  • __init__:初始化函数,用于初始化各方法需要的参数。
  • observe:用于训练阶段调用,输入一个batch的训练样本,返回预测、准确率以及前向损失。
  • inference:用于推理阶段调用,输入一个batch的样本,返回分类输出、准确率。
  • forward:重写pytorchModule中的forward函数,返回backbone的输出。
  • before_task:在每个任务开始训练前调用,用于对模型结构、训练参数等进行调整,需要用户自定义。
  • after_task:在每个任务开始训练后调用,用于对模型结构、训练参数等进行调整,需要用户自定义。
  • get_parameters:在每个任务开始训练前调用,返回当前任务的训练参数。

LUCIR

建立模型

首先在core/model/replay下添加lucir.py文件:(此处省略部分源码)

class LUCIR(Finetune):
    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.K = kwargs['K']
        self.lw_mr = kwargs['lw_mr']
        self.ref_model = None


    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        self.task_idx = task_idx

        self.ref_model = copy.deepcopy(self.backbone)
        ...
        new_fc = SplitCosineLinear(in_features, out_features, self.kwargs['inc_cls_num'])

        self.loss_fn1 = nn.CosineEmbeddingLoss()
        self.loss_fn2 = nn.CrossEntropyLoss()
        self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist'])
        ...

        self.backbone = self.backbone.to(self.device)
        if self.ref_model is not None:
            self.ref_model = self.ref_model.to(self.device)


    def _init_new_fc(self, task_idx, buffer, train_loader):
        if task_idx == 0:
            return
        ...
        self.backbone.fc.fc2.weight.data = novel_embedding.to(self.device)

    def _compute_feature(self, feature_model, loader, num_samples, num_features):
        ...


    def observe(self, data):
        x, y = data['image'], data['label']
        logit = self.backbone(x)

        ...
        ref_outputs = self.ref_model(x)
        loss = self.loss_fn1(...) * self.cur_lamda
        loss += self.loss_fn2(...)
        if  hard_num > 0:
            ...
            loss += self.loss_fn3(...) * self.lw_mr

        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        if self.task_idx > 0:
            self.handle_ref_features.remove()
            ...


    def inference(self, data):
        pass


    def _init_optim(self, config, task_idx):
        ...
        tg_params =[{'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4}, \
                        {'params': self.backbone.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0}]
        return tg_params
  • __init__中,对LUCIR所需要的参数K, lw_mr, ref_model进行初始化。
  • before_task中,根据LUCIR的需要,我们在任务开始前对新旧分类头进行更新,并根据task_idx设置不同的损失函数 。
  • observe中,我们实现了训练阶段中LUCIR的训练算法,根据task_idx采用不同的训练方法对模型进行训练。
  • after_task中,根据LUCIR算法需要移除一些hook操作。
  • _init_optim中,我们完成了对于训练参数的选择。

以上几个接口的实现是LUCIR算法与其他算法的不同点,其他接口无特殊处理可以不实现交由Finetune实现
注意,由于持续学习算法对于第一个任务和其他任务有不同的操作,在before_task会传入task_idx来标识当前是第几个任务。

新增lucir.yaml文件

各参数含义请参考'config.md'

数据划分相关参数

data_root: /data/fanzhichen/continual/cifar100
image_size: 32
save_path: ./
init_cls_num: 50
inc_cls_num: 10
task_num: 6

训练优化器相关参数

optimizer:
  name: SGD
  kwargs:
    lr: 0.1
    momentum: 0.9
    weight_decay: 0.0005

lr_scheduler:
  name: MultiStepLR
  kwargs:
    gamma: 0.1
    milestones: [80, 120]

backbone相关参数

backbone:
  name: resnet32
  kwargs:
    num_classes: 100
    args: 
      dataset: cifar100
      cosine_fc: True

buffer相关参数

name: 选择LinearBuffer, 会将数据在任务开始前与当前任务数据合并在一起。
strategy:选择herding更新策略,目前可支持random,equal_random,reservoir,herding,None

buffer:
  name: LinearBuffer
  kwargs:
    buffer_size: 2000
    batch_size: 128
    strategy: herding     # random, equal_random, reservoir, herding

算法相关参数

name:此处标识所采用何种算法

classifier:
  name: LUCIR
  kwargs:
    num_class: 100
    feat_dim: 512
    init_cls_num: 50
    inc_cls_num: 10
    dist: 0.5
    lamda: 5
    K: 2
    lw_mr: 1