# adapted from https://www.kaggle.com/code/snnclsr/learning-rate-schedulers
# adapted from https://gist.github.com/davidgilbertson/2a6ac54ad6629a37e8f4d0539f7ef7bc
import timeit
import math
from typing import Sequence, Mapping, Literal, Callable
import torch
import transformers
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from transformers import AutoModel
class KeyframeLR(LRScheduler):
def __init__(
self,
optimizer: Optimizer,
frames,
end: float,
units: Literal["percent", "steps", "time"] = "percent",
):
"""
Define a PyTorch LR scheduler with keyframes
Parameters
----------
optimizer
torch.optim optimizer
frames
A sequence of mappings (e.g. list of dicts), each one either specifying a
position/lr or transition.
Positions should be defined like `{"position": 0.2, "lr": 0.1}`.
            As a shorthand, you can also provide a `(position, lr)` list or tuple.
When units are `"steps"`, define the position in steps, else define the position as
a float in the interval [0, 1].
            Transitions can optionally be inserted between positions, e.g. `{"transition": "cos"}`.
            If no transition is defined between two positions, `"linear"` is used.
            Options are `"linear"` and `"cos"`, or a function with the signature:
`func(last_lr, start_frame, end_frame, position, scheduler)`
As a shorthand, you can also provide just the string or callable
end
When `units` are `"time"`, this should be the expected run-time in seconds
Otherwise, this should be the maximum number of times you plan to call .step()
units
"percent", "steps", or "time". Default is "percent"
"""
self.end = end
self.units = units
self.frames = self.parse_frames(frames)
self.last_lr = 0
self.start_time = timeit.default_timer() if units == "time" else None
super().__init__(optimizer=optimizer)
def parse_frames(self, user_frames):
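        """
        Expand shorthands and insert implicit frames so the parsed list always
        starts at position 0, ends at the end position, and has a transition
        between every pair of adjacent positions.
        """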
frames = []
previous_pos = -1
end_pos = self.end if self.units == "steps" else 1
unpacked_frames = []
for frame in user_frames:
            # Allow shorthand for position, e.g. (0.2, 0.1). Strings are also
            # Sequences, so exclude them: a bare string means a transition
            if isinstance(frame, Sequence) and not isinstance(frame, str) and len(frame) == 2:
                frame = {"position": frame[0], "lr": frame[1]}
# Allow shorthand for transition
if isinstance(frame, (str, Callable)):
frame = {"transition": frame}
# Allow for "position": "end"
if frame.get("position", None) == "end":
frame["position"] = end_pos
unpacked_frames.append(frame)
for i, frame in enumerate(unpacked_frames):
first_frame = i == 0
last_frame = i == len(unpacked_frames) - 1
if first_frame:
if "position" in frame and frame["position"] != 0:
frames.append({"position": 0, "lr": 0})
frames.append({"transition": "linear"})
if "transition" in frame:
frames.append({"position": 0, "lr": 0})
frames.append(frame)
if "position" in frame:
position = frame["position"]
                assert (
                    position >= previous_pos
                ), f"position {position!r} is before the previous position {previous_pos}"
assert (
position <= end_pos
), f"position {position} is bigger than end value {end_pos}"
previous_pos = position
            if not last_frame:
                next_frame = unpacked_frames[i + 1]
                # Only insert the implicit linear transition between two
                # adjacent position frames; an explicit transition frame
                # already bridges its neighbours
                if "position" in frame and "position" in next_frame:
                    frames.append({"transition": "linear"})
if last_frame:
if "position" in frame and frame["position"] < end_pos:
frames.append({"transition": "linear"})
frames.append({"position": end_pos, "lr": 0})
if "transition" in frame:
frames.append({"position": end_pos, "lr": 0})
return frames
@staticmethod
def interpolate(a, b, pct):
return (1 - pct) * a + pct * b
def interpolate_frames(self, start_frame, transition, end_frame, position):
pos_range = end_frame["position"] - start_frame["position"]
pct_of_range = (position - start_frame["position"]) / pos_range
if transition == "linear":
return self.interpolate(
start_frame["lr"],
end_frame["lr"],
pct_of_range,
)
if transition == "cos":
pct_of_range_cos = 1 - (1 + math.cos(pct_of_range * math.pi)) / 2
return self.interpolate(
start_frame["lr"],
end_frame["lr"],
pct_of_range_cos,
)
if isinstance(transition, Callable):
return transition(self.last_lr, start_frame, end_frame, position, self)
raise ValueError(f"Unknown transition: {transition!r}")
def get_lr_at_pos(self, position):
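        """
        Find the frames that bracket `position` and interpolate between them
        using the transition that lies between those two frames.
        """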
start_frame = None
transition = None
end_frame = None
lr = None
for frame in self.frames:
if "position" in frame:
if frame["position"] == position:
lr = frame["lr"]
# Direct match, we're done
break
if frame["position"] < position:
start_frame = frame
if start_frame is not None and "transition" in frame:
transition = frame["transition"]
if (
transition is not None
and "position" in frame
and frame["position"] >= position
):
end_frame = frame
break
if lr is None:
if start_frame is None or end_frame is None:
print(f"No matching frames at position {position}, using last LR.")
return self.last_lr
lr = self.interpolate_frames(start_frame, transition, end_frame, position)
# We store last_lr here so that custom transitions work with .sample_lrs()
self.last_lr = lr
return lr
@property
def progress(self):
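        """Fraction of the scheduled run completed so far, in [0, 1]."""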
if self.units == "time":
return (timeit.default_timer() - self.start_time) / self.end
return self.last_epoch / self.end
def get_lr(self):
if self.units == "percent":
position = self.last_epoch / self.end
elif self.units == "steps":
position = self.last_epoch
elif self.units == "time":
position = (timeit.default_timer() - self.start_time) / self.end
        else:
            raise ValueError(f"Unknown units {self.units!r}")
lr = self.get_lr_at_pos(position)
return [lr for _ in self.optimizer.param_groups]
def sample_lrs(self, n=100):
"""
Get a sample of the LRs that would be produced, for visualization.
This might not work well with custom transitions.
"""
# We don't want to generate a huge number of steps or affect optimizer state
# so don't use the scheduler.step() machinery.
# Instead, we loop manually and call get_lr_at_pos() directly
lrs = []
for i in range(n):
pos = i / n
if self.units == "steps":
pos *= self.end
lrs.append(self.get_lr_at_pos(pos))
self.last_lr = 0
return lrs
def print_frames(self):
for frame in self.frames:
print(frame)
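
# A minimal usage sketch of the keyframe API (the helper name below is
# illustrative, not part of the original code): linear warmup to lr_max,
# then a cosine decay to zero at the end of training.
def get_cosine_schedule_with_warmup_keyframes(optimizer, lr_max, num_warmup_steps, num_training_steps):
    return KeyframeLR(
        optimizer=optimizer,
        units="steps",
        frames=[
            {"position": 0, "lr": 0.0},
            {"position": num_warmup_steps, "lr": lr_max},
            {"transition": "cos"},
            {"position": "end", "lr": 0.0},
        ],
        end=num_training_steps,
    )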
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    # LambdaLR multiplies the optimizer's base LR by the factor returned from
    # lr_lambda, so the lambda returns a value in [0, 1] rather than an
    # absolute learning rate
    def lr_lambda(current_step):
        decay = max(0.0, 1.0 - float(current_step) / float(num_training_steps))
        warmup = min(1.0, float(current_step) / float(num_warmup_steps))
        return decay * warmup
    return LambdaLR(optimizer, lr_lambda, last_epoch)
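
# Quick check of the warmup/decay factors (a sketch; the numbers here are
# hypothetical and unrelated to the rest of this file):
#
#   sched = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1_000)
#   [sched.lr_lambdas[0](step) for step in (0, 100, 1_000)]
#   # -> [0.0, 0.9, 0.0]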
def get_exponential_schedule_with_warmup(optimizer, lr_max, lr_end, num_warmup_steps, num_training_steps):
    # Linear warmup to lr_max, then exponential decay toward lr_end: on each
    # step the custom transition shrinks the remaining gap (last_lr - lr_end)
    # by 0.1%, so the LR approaches lr_end asymptotically
    scheduler = KeyframeLR(
        optimizer=optimizer,
        units="steps",
        frames=[
            {"position": 0, "lr": 0.0},
            {"position": num_warmup_steps, "lr": lr_max},
            {"transition": lambda last_lr, *_: lr_end + (last_lr - lr_end) * 0.999},
        ],
        end=num_training_steps,
    )
    return scheduler
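
# To preview a KeyframeLR schedule without stepping a real optimizer, sample
# it directly (a sketch; as sample_lrs notes, custom transitions that depend
# on last_lr are only approximated):
#
#   scheduler = get_exponential_schedule_with_warmup(
#       optimizer, lr_max=1e-4, lr_end=1e-6,
#       num_warmup_steps=1_000, num_training_steps=10_000,
#   )
#   plt.plot(scheduler.sample_lrs(n=200))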
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    lr_max = 1e-4
    lr_end = 1e-5
    power = 1.0  # power=1.0 makes polynomial decay equivalent to linear decay
    num_warmup_steps = 4_000
    num_training_steps = 10_000

    model = AutoModel.from_pretrained("bert-base-uncased")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_max)

    # Pick one scheduler to plot:
    # KeyframeLR
    # scheduler = get_exponential_schedule_with_warmup(optimizer, lr_max=1e-4, lr_end=1e-6, num_warmup_steps=1000, num_training_steps=10000)
    # transformers LR schedulers
    # scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    scheduler = transformers.get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, lr_end=lr_end, power=power)

    lrs = []
    for i in range(num_training_steps):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    plt.plot(lrs)
    plt.xlabel("step")
    plt.ylabel("learning rate")
    plt.show()