| """ |
| Main model for using CodecLM. This will combine all the required components |
| and provide easy access to the generation API. |
| """ |
|
|
import math
import os
from datetime import datetime

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import lightning as pl
from torchmetrics.classification import MulticlassAccuracy

from codeclm.models import builders

# Silence the HuggingFace tokenizers fork-parallelism warning in dataloader workers.
os.environ['TOKENIZERS_PARALLELISM'] = "false"


class CodecLM_PL(pl.LightningModule):
    def __init__(self, cfg, ckpt_path):
        super().__init__()
        self.cfg = cfg

        # Audio tokenizer(s) are frozen; only the LM is trained.
        self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        if self.audio_tokenizer is not None:
            for param in self.audio_tokenizer.parameters():
                param.requires_grad = False
        if "audio_tokenizer_checkpoint_sep" in self.cfg.keys():
            # Optional second tokenizer (config key `audio_tokenizer_checkpoint_sep`).
            # The misspelled attribute name is kept so existing checkpoints still load.
            self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
            for param in self.seperate_tokenizer.parameters():
                param.requires_grad = False
        else:
            self.seperate_tokenizer = None

        # Language model over audio codec tokens.
        self.audiolm = builders.get_lm_model(self.cfg)
        print(self.audiolm)

        checkpoint = torch.load(ckpt_path, map_location='cpu')
        missing, unexpected = self.load_state_dict(checkpoint, strict=False)
        print("Successfully loaded pretrained model {} ({} missing / {} unexpected keys)".format(
            ckpt_path, len(missing), len(unexpected)))

        self.val_steps = []
        self.train_slide_acc = []
        self.train_steps = []
        # Per-codebook top-1 / top-10 accuracy; index `code_size` is the special
        # end/padding id and is excluded from the metrics.
        self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=1,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])
        self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
            self.audiolm.code_size,
            top_k=10,
            average="micro", multidim_average="global",
            ignore_index=self.cfg.lm.code_size,
        ) for _ in range(self.audiolm.code_depth)])

        self.epoch = 0

    def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
        """Append an end token to every sequence and build a validity mask.

        Args:
            x: Code tensor of shape (batch, code_depth, T).
            sequence_lengths: Valid length of each sample, shape (batch,).
            end_id: Token id written at each sequence end; positions past the
                end are filled with ``end_id + 1``.
        Returns:
            Tuple of the updated codes and a boolean mask of the same shape.
        """
        batch_size = sequence_lengths.size(0)
        max_length = x.size(2)

        # If the longest sample fills the tensor, pad one step so its end token fits.
        if max_length == sequence_lengths.max():
            x = F.pad(x, (0, 1), value=end_id)
            max_length = x.size(2)

        # Clip lengths so every end token lands inside the tensor.
        if max_length <= sequence_lengths.max() + 1:
            sequence_lengths = sequence_lengths - (sequence_lengths.max() + 1 - max_length)

        # Write the end token at each sample's end position (all codebooks at once)
        # and count it as part of the sequence.
        x[torch.arange(batch_size), :, sequence_lengths] = end_id
        sequence_lengths += 1

        # True for positions inside the (now end-token-inclusive) sequence length.
        mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
        mask = mask.to(x.device)
        mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
        x = torch.where(mask_3d, x, end_id + 1)
        return x, mask_3d

    def get_time(self):
        """Return the current wall-clock time as 'YYYY-MM-DD HH:MM:SS.ffffff'."""
        now = datetime.now()
        return now.strftime("%Y-%m-%d %H:%M:%S.%f")

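
# Editorial addition, not in the original source: a minimal, runnable sanity
# check for generate_mask_and_end_token. The method never reads `self`, so it
# can be exercised without constructing the full Lightning module; the shapes
# and end_id below are illustrative only.
def _demo_generate_mask_and_end_token():
    codes = torch.zeros(2, 3, 4, dtype=torch.long)  # (batch, code_depth, T)
    lengths = torch.tensor([2, 4])
    out, mask_3d = CodecLM_PL.generate_mask_and_end_token(None, codes, lengths, end_id=16384)
    # One sample already fills all 4 steps, so the codes are padded to T = 5;
    # 16384 marks each sequence end and 16385 fills everything past it.
    assert out.shape == (2, 3, 5) and mask_3d.shape == (2, 3, 5)
    assert out[0, 0].tolist() == [0, 0, 16384, 16385, 16385]
    assert out[1, 0].tolist() == [0, 0, 0, 0, 16384]
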
class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warmup.

    The learning rate ramps linearly from 0 to the base LR over
    ``warmup_steps``, then decays along a half-cosine from the base LR down to
    ``lr_min_ratio * base_lr`` at ``total_steps``, and is held there after.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of
            the base learning rate.
        cycle_length (float): Length of the cosine cycle relative to the decay
            phase; 1.0 means one half-cosine spanning the full decay.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            # Linear warmup from 0 to the base learning rate.
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            # Cosine decay: the ratio moves from 1 (end of warmup) down to
            # lr_min_ratio (at total_steps) when cycle_length == 1.0.
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            # Past the end of the schedule: hold the minimum learning rate.
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
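
if __name__ == "__main__":
    # Editorial addition: quick self-checks. Runs the mask demo above, then
    # steps CosineLRScheduler with a throwaway one-parameter optimizer to
    # trace the warmup/decay curve; the step counts are illustrative only.
    _demo_generate_mask_and_end_token()

    dummy = nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([dummy], lr=1e-4)
    scheduler = CosineLRScheduler(optimizer, total_steps=100, warmup_steps=10)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # Expect a linear ramp over the first 10 steps, then a half-cosine decay
    # toward lr_min_ratio * base_lr (0.0 here) at step 100.
    print(f"step 1: {lrs[0]:.2e}  step 10: {lrs[9]:.2e}  step 100: {lrs[-1]:.2e}")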