import random
from itertools import repeat
from typing import Iterator, List

import numpy as np
import torch
|
|
|
def set_seed(seed: int) -> None:
    """Seed the random number generators used by this module.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    default generator. When CUDA is available, the CUDA generators are seeded
    as well.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed the current device first, then every device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
|
|
|
def grad_status(model) -> Iterator[bool]:
    """Return a generator over the ``requires_grad`` flag of each parameter.

    ``model.parameters()`` is called as soon as this function is invoked; the
    individual flags are read lazily while the returned generator is consumed.

    Args:
        model: Object exposing a ``parameters()`` method whose items have a
            ``requires_grad`` attribute.

    Returns:
        A generator yielding one flag per parameter, in ``parameters()`` order.
    """
    return (par.requires_grad for par in model.parameters())
|
|
|
def lmap(f, x):
    """Apply ``f`` to every item of ``x`` and return the results as a list.

    Args:
        f: Callable taking one item of ``x``.
        x: Iterable of items to transform.

    Returns:
        A new list holding ``f(item)`` for each item, in iteration order.
    """
    return list(map(f, x))
|
|
|
def assert_all_frozen(model) -> None:
    """Check that no parameter of ``model`` requires a gradient.

    Args:
        model: Object accepted by ``grad_status``.

    Raises:
        AssertionError: If at least one parameter has ``requires_grad`` set.
            The message reports the share of offending parameters and the
            total parameter count.
    """
    model_grads: List[bool] = list(grad_status(model))
    if not any(model_grads):
        return

    npars = len(model_grads)
    n_require_grad = sum(int(flag) for flag in model_grads)
    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O.
    raise AssertionError(
        f"{n_require_grad / npars:.1%} of {npars} weights require grad"
    )
|
|
|
def split_dense_inputs(model_input: dict, chunk_size: int) -> List[dict]:
    """Split a single-argument model input into chunks along dimension 0.

    ``model_input`` must contain exactly one entry. Its value is a mapping
    whose values each expose ``split(chunk_size, dim=0)`` (presumably
    ``torch.Tensor`` objects -- confirm against callers). Every value is split
    and the i-th pieces are regrouped under their original names.

    Args:
        model_input: One-entry dict of the form ``{arg_key: {name: value}}``.
        chunk_size: Chunk size forwarded to each value's ``split`` call.

    Returns:
        A list of dicts ``{arg_key: {name: chunk}}``, one per chunk. If the
        values yield different numbers of chunks, the result stops at the
        shortest one (``zip`` semantics).

    Raises:
        AssertionError: If ``model_input`` does not hold exactly one entry.
    """
    # Raised explicitly instead of via `assert` so the precondition is still
    # enforced when Python runs with -O.
    if len(model_input) != 1:
        raise AssertionError(
            f"expected exactly one model input entry, got {len(model_input)}"
        )
    arg_key, arg_val = next(iter(model_input.items()))

    keys = list(arg_val.keys())
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]

    # zip(*chunked_tensors) yields, per chunk index, one piece from every
    # value, in the same order as `keys`.
    return [
        {arg_key: dict(zip(keys, pieces))} for pieces in zip(*chunked_tensors)
    ]
|
|
|
def get_dense_rep(x):
    """Return ``x.q_reps``, falling back to ``x.p_reps`` when it is ``None``.

    Args:
        x: Object exposing ``q_reps`` and ``p_reps`` attributes.

    Returns:
        ``x.q_reps`` if it is not ``None``; otherwise ``x.p_reps``.
    """
    return x.p_reps if x.q_reps is None else x.q_reps
|