| import abc |
| from omegaconf import DictConfig |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
|
|
def get_schedule_from_config(config: DictConfig):
    """
    Build the schedule selected by `config.type`.

    Recognised types are "geometric" (reads `config.min` and `config.max`),
    "linear", "sin", "cosine" and "polynomial" (reads `config.exp`).
    Any other type raises ValueError.
    """
    kind = config.type

    # Guard-clause dispatch: each branch only touches the config fields and
    # the schedule class that belong to the selected type.
    if kind == "geometric":
        return GeometricSchedule(min_val=config.min, max_val=config.max)
    if kind == "linear":
        return LinearSchedule()
    if kind == "sin":
        return SinSchedule()
    if kind == "cosine":
        return CosineSchedule()
    if kind == "polynomial":
        return PolynomialSchedule(exp=config.exp)

    raise ValueError(f"Invalid schedule type: {config.type}")
|
|
|
|
class Schedule(abc.ABC):
    """
    Generic schedule class for masking or noising.

    This represents a function a : [0, 1] -> [0, 1] satisfying a(0) = 0 and
    a(1) = 1, at least approximately. Subclasses provide `at`,
    `derivative_at` and `inv`; the other methods are derived from those.
    """

    @abc.abstractmethod
    def at(self, t: Tensor):
        """
        Return value a(t)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t)
        """
        raise NotImplementedError

    def rate_scale_factor(self, t: Tensor) -> Tensor:
        """
        Return d/dt a(t) / (1 - a(t)) common in rate matrix calculation

        The quotient is not guarded: wherever a(t) equals 1 the denominator
        is zero.
        """
        return self.derivative_at(t) / (1 - self.at(t))

    def sample(self, shape, device) -> Tensor:
        """
        Sample from the schedule, returns a tensor of shape `shape` with values in [0, 1]

        A uniform draw on `device` is mapped through `inv`.
        """
        uniform = torch.rand(shape, device=device)
        return self.inv(uniform)

    def sample_truncated(self, threshold, shape, device) -> Tensor:
        """
        Sample from a truncated schedule, returns a tensor of shape `shape` with values in [threshold, 1]

        `threshold` is a time value; it may be a tensor or a plain Python
        number. The uniform draw is rescaled to [a(threshold), 1] and then
        mapped back to time through `inv`.
        """
        uniform = torch.rand(shape, device=device)
        if not isinstance(threshold, Tensor):
            # Subclass `at` implementations may call torch functions that
            # reject plain Python numbers, so promote scalars to a tensor
            # matching the uniform draw. Tensor thresholds pass through
            # untouched.
            threshold = torch.as_tensor(
                threshold, dtype=uniform.dtype, device=uniform.device
            )
        alpha_min = self.at(threshold)
        return self.inv(uniform * (1 - alpha_min) + alpha_min)

    @abc.abstractmethod
    def inv(self, alpha: Tensor):
        """
        Given alpha in [0, 1] such that a(t)=alpha, returns the corresponding t.
        """
        raise NotImplementedError
|
|
|
|
class LinearSchedule(Schedule):
    """
    Identity schedule a(t) = t.
    """

    def __init__(self):
        """
        The schedule has no parameters, so there is nothing to set up.
        """

    def at(self, t: Tensor):
        """
        Return a(t) = t, i.e. the input itself.
        """
        return t

    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t): a tensor of ones shaped like `t`.
        """
        # ones_like already places the result on the device of `t`.
        return torch.ones_like(t)

    def inv(self, alpha: Tensor):
        """
        Return the t with a(t) = alpha, which is alpha itself.
        """
        return alpha
|
|
|
|
class GeometricSchedule(Schedule, nn.Module):
    """
    Schedule a(t) = exp(-sigma(t)) with sigma(t) = min_val^(1 - t) * max_val^t.

    NOTE(review): when max_val > min_val this a(t) decreases from
    exp(-min_val) at t = 0 to exp(-max_val) at t = 1, which is the reverse of
    the a(0) = 0, a(1) = 1 convention stated on Schedule - confirm this is
    intended.
    """

    def __init__(self, min_val: float, max_val: float):
        """
        Store both bounds as one-element buffers named "min" and "max".
        """
        super().__init__()
        for buffer_name, bound in (("min", min_val), ("max", max_val)):
            self.register_buffer(buffer_name, Tensor([bound]))

    def at(self, t: Tensor):
        """
        Return a(t) = exp(-min^(1 - t) * max^t), evaluated on the device of `t`.
        """
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        exponent = -(lo ** (1 - t)) * hi**t
        return torch.exp(exponent)

    def derivative_at(self, t):
        """
        Return d/dt a(t) = a(t) * min^(1 - t) * max^t * (log(min) - log(max)).
        """
        lo = self.min.to(t.device)
        hi = self.max.to(t.device)
        log_gap = lo.log() - hi.log()
        # Same left-to-right product as the closed form above.
        scaled = self.at(t) * lo ** (1 - t) * hi**t
        return scaled * log_gap

    def inv(self, alpha: Tensor):
        """
        Return t such that a(t) = alpha by solving log(-log(alpha)) for t.
        """
        log_lo = self.min.to(alpha.device).log()
        log_hi = self.max.to(alpha.device).log()
        log_sigma = torch.log(-torch.log(alpha))
        return (log_sigma - log_lo) / (log_hi - log_lo)
|
|
|
|
class SinSchedule(Schedule, nn.Module):
    """
    Schedule a(t) = sin(pi/2 * t).
    """

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """
        Return a(t) = sin(pi/2 * t).
        """
        angle = torch.pi / 2 * t
        return torch.sin(angle)

    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t) = pi/2 * cos(pi/2 * t).
        """
        angle = torch.pi / 2 * t
        return (torch.pi / 2) * torch.cos(angle)

    def inv(self, alpha: Tensor):
        """
        Return t = 2/pi * asin(alpha), with alpha clamped to [0, 1] first.
        """
        clipped = torch.clamp(alpha, min=0., max=1.)
        return (2 / torch.pi) * torch.asin(clipped)
|
|
|
|
class CosineSchedule(Schedule, nn.Module):
    """
    Schedule a(t) = 1 - cos(pi/2 * t).
    """

    def __init__(self):
        super().__init__()

    def at(self, t: Tensor):
        """
        Return a(t) = 1 - cos(pi/2 * t).
        """
        angle = torch.pi / 2 * t
        return 1 - torch.cos(angle)

    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t) = pi/2 * sin(pi/2 * t).
        """
        angle = torch.pi / 2 * t
        return (torch.pi / 2) * torch.sin(angle)

    def rate_scale_factor(self, t):
        """
        Return d/dt a(t) / (1 - a(t)) in closed form.

        Here 1 - a(t) = cos(pi/2 * t), so the quotient is pi/2 * tan(pi/2 * t).
        """
        angle = torch.pi / 2 * t
        return (torch.pi/2) * torch.tan(angle)

    def inv(self, alpha):
        """
        Return t = 2/pi * acos(1 - alpha), with alpha clamped to [0, 1] first.
        """
        clipped = torch.clamp(alpha, min=0., max=1.)
        return (2 / torch.pi) * torch.acos(1 - clipped)
|
|
class PolynomialSchedule(Schedule, nn.Module):
    """
    Schedule a(t) = t ** exp.
    """

    def __init__(self, exp):
        """
        Keep the exponent `exp` as a plain attribute.
        """
        super().__init__()
        self.exp = exp

    def at(self, t: Tensor):
        """
        Return a(t) = t ** exp.
        """
        power = self.exp
        return t ** power

    def derivative_at(self, t: Tensor):
        """
        Return d/dt a(t) = exp * t ** (exp - 1).
        """
        power = self.exp
        return power * t ** (power - 1)

    def inv(self, alpha: Tensor):
        """
        Return t = alpha ** (1 / exp), the t with a(t) = alpha.
        """
        root = 1 / self.exp
        return alpha ** root