# GLEN-model: src/tevatron/main_utils.py
# Committed by QuanTH02 in 3d5551b ("Phase 1")
import torch
import random
import numpy as np
from typing import List
from itertools import repeat
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same *seed*.

    When CUDA is available, the CUDA generators (current device and all
    devices) are seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def grad_status(model):
    """Return a lazy iterator of ``requires_grad`` flags, one per parameter of *model*."""
    params = model.parameters()
    return (param.requires_grad for param in params)
def lmap(f, x):
    """Apply *f* to every element of *x* and return the results as a list.

    Equivalent to ``list(map(f, x))``.
    """
    return [f(item) for item in x]
def assert_all_frozen(model):
    """Raise ``AssertionError`` unless every parameter of *model* is frozen.

    A parameter is frozen when its ``requires_grad`` flag is False. The check
    raises explicitly instead of using an ``assert`` statement, so it still
    runs when Python is started with ``-O``.

    Raises:
        AssertionError: if any parameter has ``requires_grad=True``. The
            message gives the fraction of parameters that still require grad.
    """
    model_grads: List[bool] = [par.requires_grad for par in model.parameters()]
    if any(model_grads):
        n_require_grad = sum(model_grads)
        # Note: npars counts parameter tensors, not individual scalar weights.
        npars = len(model_grads)
        raise AssertionError(
            f"{n_require_grad / npars:.1%} of {npars} weights require grad"
        )
def split_dense_inputs(model_input: dict, chunk_size: int):
    """Split a single-entry model input into a list of smaller chunks.

    *model_input* must contain exactly one key. Its value is a mapping of
    names to tensors. Each tensor is split along dim 0 into pieces of at
    most *chunk_size* rows with ``Tensor.split``. The i-th pieces are
    regrouped under the original names and wrapped back under the original
    outer key.

    If the tensors yield different numbers of pieces, ``zip`` stops at the
    shortest. An empty inner mapping produces an empty list.
    """
    assert len(model_input) == 1
    (arg_key, arg_val), = model_input.items()
    names = list(arg_val.keys())
    pieces_per_name = [arg_val[name].split(chunk_size, dim=0) for name in names]
    result = []
    for pieces in zip(*pieces_per_name):
        result.append({arg_key: dict(zip(names, pieces))})
    return result
def get_dense_rep(x):
    """Return ``x.q_reps`` if it is not None, otherwise ``x.p_reps``.

    Only an explicit ``None`` triggers the fallback. Falsy values such as
    ``0`` are returned as they are.
    """
    query_reps = x.q_reps
    return x.p_reps if query_reps is None else query_reps