# MSRNet / utils/pt_utils.py
# linaa98's picture
# Update utils/pt_utils.py
# 24a1704 verified
#Author: Lart Pang (https://github.com/lartpang)
import logging
import os
import random
import numpy as np
import torch
from torch import nn
from torch.backends import cudnn
LOGGER = logging.getLogger("main")
def customized_worker_init_fn(worker_id):
    """Seed NumPy's global RNG from the current PyTorch initial seed.

    The value returned by ``torch.initial_seed()`` is reduced modulo ``2**32``
    before being handed to ``np.random.seed``.

    Presumably meant to be passed as a ``worker_init_fn`` to a data loader --
    confirm against the caller.

    Args:
        worker_id: Accepted for signature compatibility; not used.
    """
    np.random.seed(torch.initial_seed() % (1 << 32))
def set_seed_for_lib(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    ``PYTHONHASHSEED`` is also written to ``os.environ`` as ``str(seed)``.

    Args:
        seed: Value forwarded unchanged to every seeding function.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The seeding calls are independent of one another, so apply them in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def initialize_seed_cudnn(seed, deterministic):
    """Seed all RNG libraries and configure the cuDNN backend flags.

    Args:
        seed (int): A non-negative value is used as a fixed seed. A negative
            value requests a seed drawn from NumPy's global RNG in
            ``[0, 2**32)``.
        deterministic (bool): When ``True``, ``cudnn.deterministic`` is
            enabled and ``cudnn.benchmark`` is disabled; when ``False`` the
            opposite is set. ``cudnn.enabled`` is always set to ``True``.

    Raises:
        AssertionError: If ``seed`` is not an ``int`` or ``deterministic`` is
            not a ``bool``.
    """
    assert isinstance(deterministic, bool) and isinstance(seed, int)

    if seed >= 0:
        LOGGER.info(f"We will use a fixed seed {seed}")
    else:
        # Request int64 explicitly: with the platform-default dtype the bound
        # 2**32 does not fit where the default integer is 32-bit, and randint
        # raises ValueError. int() keeps the seed a plain Python integer.
        seed = int(np.random.randint(2**32, dtype=np.int64))
        LOGGER.info(f"We will use a random seed {seed}")
    set_seed_for_lib(seed)

    if not deterministic:
        LOGGER.info("We will use `torch.backends.cudnn.benchmark`")
    else:
        LOGGER.info("We will not use `torch.backends.cudnn.benchmark`")

    cudnn.enabled = True
    # benchmark and deterministic are set as exact opposites of each other.
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic
def to_device(data, device="cuda"):
    """Recursively move tensors inside ``data`` to ``device``.

    Args:
        data: A ``torch.Tensor``, or a tuple/list/dict (arbitrarily nested)
            whose leaves are tensors.
        device: Target device passed to ``Tensor.to``. Defaults to ``"cuda"``.

    Returns:
        The moved tensor, a ``dict`` with the same keys, or a ``list`` of moved
        items. Note that tuples are returned as lists.

    Raises:
        TypeError: If ``data`` (or any nested item) is of an unsupported type.
    """
    if isinstance(data, torch.Tensor):
        # Transfers are requested with non_blocking=True.
        return data.to(device=device, non_blocking=True)

    if isinstance(data, dict):
        moved = {}
        for key, value in data.items():
            moved[key] = to_device(value, device)
        return moved

    if isinstance(data, (list, tuple)):
        return [to_device(element, device) for element in data]

    raise TypeError(f"Unsupported type {type(data)}. Only support Tensor or tuple/list/dict containing Tensors.")
def frozen_bn_stats(model, freeze_affine=False):
    """Set all the batch-norm layers of ``model`` to eval mode.

    Matches ``BatchNorm1d``/``2d``/``3d`` and ``SyncBatchNorm``, so models
    whose batch-norm layers were converted for distributed training are
    handled as well. Other modules are left untouched.

    Args:
        model (nn.Module): Model whose batch-norm layers are switched to eval
            mode (in place).
        freeze_affine (bool): If ``True``, additionally call
            ``requires_grad_(False)`` on each matched layer so its parameters
            stop receiving gradients. Defaults to ``False``.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for module in model.modules():
        if not isinstance(module, bn_types):
            continue
        module.eval()
        if freeze_affine:
            module.requires_grad_(False)