import math
from abc import ABC


class ParameterSchedule(ABC):
    """Base class for schedules mapping a progress value ``t`` to a parameter value.

    Subclasses override :meth:`compute`.
    """

    def compute(self, t):
        """Return the scheduled value at progress ``t``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class ExponentialInterpolation(ParameterSchedule):
    """Interpolate from ``start`` to ``end`` along an exponential curve.

    The value at ``t`` is::

        start + (end - start) * (exp(alpha * t) - 1) / (exp(alpha) - 1)

    which is ``start`` at ``t == 0`` and ``end`` at ``t == 1``. ``alpha == 0``
    uses plain linear interpolation (the limit of the formula as alpha -> 0).

    Args:
        start: value at ``t == 0``.
        end: value at ``t == 1``.
        alpha: curvature of the exponential; 0 means linear.
    """

    def __init__(self, start, end, alpha):
        self.start = start
        self.end = end
        self.alpha = alpha

    def compute(self, t):
        """Return the interpolated value at progress ``t``."""
        alpha = self.alpha
        if alpha == 0:
            fraction = t
        elif alpha > 0:
            # Algebraically equal to expm1(alpha*t) / expm1(alpha), multiplied
            # through by exp(-alpha) so that exp() never overflows for t <= 1
            # even when alpha is large (math.exp overflows past ~709).
            fraction = math.exp(alpha * (t - 1)) * (
                math.expm1(-alpha * t) / math.expm1(-alpha)
            )
        else:
            # expm1 avoids the cancellation in exp(x) - 1 for small |x|; with
            # plain exp, a tiny alpha made the denominator exactly 0.0.
            fraction = math.expm1(alpha * t) / math.expm1(alpha)
        return self.start + (self.end - self.start) * fraction


class PiecewiseStepFunction(ParameterSchedule):
    """Step function over ``len(thresholds) + 1`` constant pieces.

    ``compute(t)`` scans ``thresholds`` in order and returns ``values[i]``,
    where ``i`` is the number of leading thresholds that ``t`` strictly
    exceeds. For ascending thresholds ``[b0, b1, ...]`` this yields
    ``values[0]`` for ``t <= b0``, ``values[1]`` for ``b0 < t <= b1``, ...,
    and ``values[-1]`` once ``t`` exceeds every threshold. Thresholds are
    presumably meant to be ascending; nothing here enforces it.

    Args:
        thresholds: non-empty sequence of breakpoints.
        values: sequence with exactly one more element than ``thresholds``.
    """

    def __init__(self, thresholds, values):
        self.thresholds = thresholds
        self.values = values

    def compute(self, t):
        """Return the value of the piece that ``t`` falls into.

        Raises:
            AssertionError: if ``thresholds`` is empty or ``values`` does not
                have ``len(thresholds) + 1`` elements.
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # so callers expecting AssertionError on a malformed schedule still work.
        assert len(self.thresholds) > 0
        assert len(self.values) == len(self.thresholds) + 1
        for idx, threshold in enumerate(self.thresholds):
            # Stop at the first threshold that t does NOT strictly exceed.
            # Written as ``not t > threshold`` rather than ``t <= threshold``
            # so a NaN ``t`` still selects values[0].
            if not t > threshold:
                return self.values[idx]
        return self.values[len(self.thresholds)]