import math
from abc import ABC, abstractmethod


class ParameterSchedule(ABC):
    """Abstract base for schedules that map an input ``t`` to a parameter value.

    Concrete subclasses must override :meth:`compute`; because it is declared
    abstract, a subclass that forgets to do so fails at instantiation time
    (``TypeError``) rather than on the first ``compute`` call.
    """

    @abstractmethod
    def compute(self, t):
        """Return the scheduled value at ``t``.

        Raises:
            NotImplementedError: if reached via ``super().compute(t)``; the
                base class provides no implementation.
        """
        raise NotImplementedError


class ExponentialInterpolation(ParameterSchedule):
    """Interpolate between ``start`` and ``end`` along an exponential curve.

    The value at ``t`` is::

        start + (end - start) * (exp(alpha * t) - 1) / (exp(alpha) - 1)

    so ``t == 0`` yields ``start`` and ``t == 1`` yields ``end``. For
    ``alpha > 0`` the value changes slowly near ``t == 0`` and quickly near
    ``t == 1``; for ``alpha < 0`` the reverse. ``alpha == 0`` degenerates to
    plain linear interpolation (the limit of the formula as alpha -> 0).

    Args:
        start: Value returned at ``t == 0``.
        end: Value returned at ``t == 1``.
        alpha: Curvature of the exponential ramp.
    """

    def __init__(self, start, end, alpha):
        self.start = start
        self.end = end
        self.alpha = alpha

    def compute(self, t):
        """Return the interpolated value at ``t``.

        Raises:
            OverflowError: if ``alpha`` or ``alpha * t`` is large and positive
                (above roughly 709), where ``exp`` exceeds the float range.
        """
        if self.alpha == 0:
            return self.start + (self.end - self.start) * t
        # math.expm1(x) == exp(x) - 1, but stays accurate for tiny |x|.
        # With math.exp, the denominator rounds to exactly 0.0 once |alpha|
        # drops below about 1e-16 (ZeroDivisionError), and it loses most of
        # its significant digits well before that.
        fraction = math.expm1(self.alpha * t) / math.expm1(self.alpha)
        return self.start + (self.end - self.start) * fraction


class PiecewiseStepFunction(ParameterSchedule):
    """Step schedule returning a constant value on each threshold interval.

    ``thresholds`` are scanned in order, and ``compute(t)`` returns
    ``values[i]`` for the first index ``i`` with ``not t > thresholds[i]``.
    If ``t`` exceeds every threshold, it returns ``values[-1]``. For
    ascending thresholds (assumed by callers but not checked here) this
    means:

    * ``t <= thresholds[0]``                      -> ``values[0]``
    * ``thresholds[i-1] < t <= thresholds[i]``    -> ``values[i]``
    * ``t > thresholds[-1]``                      -> ``values[-1]``

    A threshold value itself therefore belongs to the lower interval. An
    empty ``thresholds`` with a single-element ``values`` gives a constant
    schedule.

    Args:
        thresholds: Sequence of breakpoints.
        values: Sequence with exactly ``len(thresholds) + 1`` entries.
    """

    def __init__(self, thresholds, values):
        self.thresholds = thresholds
        self.values = values

    def compute(self, t):
        """Return the value of the step interval containing ``t``.

        Raises:
            ValueError: if ``len(values) != len(thresholds) + 1``.
        """
        thresholds = self.thresholds
        values = self.values
        # Validated on every call (as before), so attributes reassigned after
        # construction are still checked. An explicit raise, unlike `assert`,
        # is not stripped under `python -O`.
        if len(values) != len(thresholds) + 1:
            raise ValueError(
                f"expected {len(thresholds) + 1} values for "
                f"{len(thresholds)} thresholds, got {len(values)}"
            )
        for idx, threshold in enumerate(thresholds):
            # `not t > threshold` rather than `t <= threshold` keeps the exact
            # original comparison, which differs when NaN is involved.
            if not t > threshold:
                return values[idx]
        return values[-1]