# Boltz2 / vb_potentials_schedules.py
# Uploaded by lhallee using huggingface_hub (commit 827d9ec, verified).
import math
from abc import ABC, abstractmethod
class ParameterSchedule(ABC):
    """Abstract base for a parameter whose value is a function of ``t``.

    Subclasses must override :meth:`compute`. Because ``compute`` is declared
    with ``@abstractmethod``, this base class (or a subclass that forgets to
    override it) cannot be instantiated.
    """

    @abstractmethod
    def compute(self, t):
        """Return the parameter value at ``t``."""
        # Kept so that an explicit ``super().compute(t)`` still fails loudly.
        raise NotImplementedError
class ExponentialInterpolation(ParameterSchedule):
    """Interpolate between ``start`` and ``end`` along an exponential curve.

    The value at ``t`` is::

        start + (end - start) * (exp(alpha * t) - 1) / (exp(alpha) - 1)

    so ``t = 0`` yields ``start`` and ``t = 1`` yields ``end``. A positive
    ``alpha`` makes the curve rise slowly at first and steeply near ``t = 1``;
    a negative ``alpha`` does the opposite; ``alpha = 0`` is the linear limit.
    ``t`` is not clamped, so values outside ``[0, 1]`` extrapolate.
    """

    def __init__(self, start, end, alpha):
        self.start = start
        self.end = end
        self.alpha = alpha

    def compute(self, t):
        """Return the interpolated value at ``t``.

        Raises:
            OverflowError: if ``alpha`` or ``alpha * t`` is too large for
                ``math.expm1`` (roughly above 709).
        """
        if self.alpha == 0:
            # Limit of the exponential formula as alpha -> 0 (avoids 0 / 0).
            return self.start + (self.end - self.start) * t
        # expm1(x) computes exp(x) - 1 without the catastrophic cancellation
        # that the explicit subtraction suffers when alpha * t is near zero.
        fraction = math.expm1(self.alpha * t) / math.expm1(self.alpha)
        return self.start + (self.end - self.start) * fraction
class PiecewiseStepFunction(ParameterSchedule):
    """Piecewise-constant schedule defined by breakpoints and values.

    ``values[0]`` is returned while ``t <= thresholds[0]``, ``values[1]`` while
    ``thresholds[0] < t <= thresholds[1]``, and so on, with ``values[-1]``
    returned once ``t`` exceeds every threshold. The lookup stops at the first
    threshold that ``t`` does not exceed, so it assumes ``thresholds`` is in
    ascending order (this is not checked).
    """

    def __init__(self, thresholds, values):
        self.thresholds = thresholds
        self.values = values

    def compute(self, t):
        """Return the value of the step that contains ``t``.

        Raises:
            ValueError: if ``thresholds`` is empty, or if ``values`` does not
                have exactly one more entry than ``thresholds``.
        """
        # Explicit checks instead of ``assert`` so validation survives
        # ``python -O``.
        if len(self.thresholds) == 0:
            raise ValueError("thresholds must not be empty")
        if len(self.values) != len(self.thresholds) + 1:
            raise ValueError(
                "values must have exactly one more entry than thresholds "
                f"(got {len(self.values)} values for "
                f"{len(self.thresholds)} thresholds)"
            )
        # Index of the first threshold that t does not exceed; falls through
        # to the last value when t is above all of them.
        idx = next(
            (i for i, bound in enumerate(self.thresholds) if not t > bound),
            len(self.thresholds),
        )
        return self.values[idx]