| | import math |
| |
|
| | import torch |
| |
|
| |
|
def identity(t, *args, **kwargs):
    """Pass-through helper: hand back ``t`` untouched.

    Any additional positional or keyword arguments are accepted and ignored,
    so this can stand in for a callable that takes extra arguments.
    """
    return t
| |
|
| |
|
def exists(x):
    """Tell whether ``x`` carries a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if x is None:
        return False
    return True
| |
|
| |
|
def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise fall back to ``d``.

    If the fallback ``d`` is callable it is invoked with no arguments and its
    result is returned; this lets callers defer building the fallback value.
    A non-callable ``d`` is returned as-is.
    """
    if not exists(val):
        # Only evaluate the fallback when it is actually needed.
        return d() if callable(d) else d
    return val
| |
|
| |
|
def has_int_squareroot(num):
    """Return True when ``num`` is a perfect square (its square root is an integer).

    The check is done in exact integer arithmetic with ``math.isqrt``. A
    floating-point round trip (``math.sqrt(num) ** 2 == num``) loses precision
    once ``num`` exceeds what a float can represent exactly, and overflows for
    very large integers.

    Args:
        num: The number to test. Integers (or objects usable as an integer
            index) are checked exactly. Floats are accepted when they hold an
            integral value, e.g. ``16.0``.

    Returns:
        bool: True if there is a non-negative integer ``r`` with
        ``r * r == num``; False otherwise, including for negative numbers and
        non-integral floats (also ``inf`` and ``nan``).
    """
    if isinstance(num, float):
        # A non-integral float (or inf/nan) cannot be the square of an integer.
        if not num.is_integer():
            return False
        num = int(num)
    if num < 0:
        # Negative numbers have no real square root.
        return False
    root = math.isqrt(num)
    return root * root == num
| |
|
| |
|
def num_to_groups(num, divisor):
    """Split ``num`` into a list of chunk sizes no larger than ``divisor``.

    The result holds ``num // divisor`` entries equal to ``divisor``, followed
    by one extra entry with the remainder when that remainder is positive.
    For example, ``num_to_groups(10, 3)`` gives ``[3, 3, 3, 1]``.
    """
    full_count = num // divisor
    leftover = num % divisor
    sizes = [divisor] * full_count
    # Only add a trailing partial chunk when something is left over.
    return sizes + [leftover] if leftover > 0 else sizes
| |
|
| |
|
| | |
| | |
| | |
| |
|
def sum_params(model: torch.nn.Module, eps: float = 1e6):
    """Count the elements of all parameters in ``model``, scaled down by ``eps``.

    Args:
        model: Module whose ``parameters()`` are counted.
        eps: Divisor applied to the raw element count. With the default of
            ``1e6`` the returned value is expressed in millions.

    Returns:
        The total number of parameter elements divided by ``eps``.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / eps
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cycle(dl):
    """Yield items from ``dl`` forever, restarting it each time it runs out.

    ``dl`` is iterated afresh on every pass rather than replaying cached
    items, so whatever ``iter(dl)`` does is repeated each time round.
    """
    while True:
        yield from dl
| |
|
| |
|
| | |
| | |
| | |
| |
|
def extract(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.

    Args:
        a: Tensor to read from; values are gathered along its last dimension.
        t: Index tensor whose first dimension is the batch size.
        x_shape: Shape of the target the result should broadcast against;
            its first entry must equal the batch size of ``t``.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with as many
        dimensions as ``x_shape`` has entries.
    """
    batch, *_rest = t.shape
    assert x_shape[0] == batch
    gathered = a.gather(-1, t)
    # One singleton dimension for every non-batch dimension of the target.
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)
| |
|
| |
|
def unnormalize(x):
    """Map values from the [-1, 1] range to [0, 1] and clamp the result.

    Computes ``(x + 1) * 0.5``; anything that lands outside [0, 1] is clipped
    to the nearest bound.
    """
    return torch.clamp((x + 1) * 0.5, min=0.0, max=1.0)
| |
|
| |
|
def normalize(x):
    """Map values from the [0, 1] range to [-1, 1] and clamp the result.

    Computes ``x * 2 - 1``; anything that lands outside [-1, 1] is clipped to
    the nearest bound.
    """
    return torch.clamp(x * 2 - 1, min=-1.0, max=1.0)
| |
|