# -*- coding: utf-8 -*-
# Author: Lart Pang (https://github.com/lartpang)
import torch
import torch.nn.functional as F
def rescale_2x(x: torch.Tensor, scale_factor=2):
    """Resample ``x`` bilinearly by ``scale_factor`` (2 unless overridden).

    Thin wrapper around ``F.interpolate`` called with ``mode="bilinear"``
    and ``align_corners=False``.

    Args:
        x: Tensor handed to ``F.interpolate`` unchanged.
        scale_factor: Multiplier forwarded as ``scale_factor``.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    interp_kwargs = dict(
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
    return F.interpolate(x, **interp_kwargs)
def resize_to(x: torch.Tensor, tgt_hw: tuple):
    """Bilinearly resize ``x`` to the spatial size ``tgt_hw``.

    Thin wrapper around ``F.interpolate`` called with ``mode="bilinear"``
    and ``align_corners=False``.

    Args:
        x: Tensor handed to ``F.interpolate`` unchanged.
        tgt_hw: Target size forwarded as the ``size`` argument.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    resized = F.interpolate(
        x,
        size=tgt_hw,
        mode="bilinear",
        align_corners=False,
    )
    return resized
def clip_grad(params, mode, clip_cfg: dict):
    """Clip the gradients of ``params`` in place, by norm or by value.

    Args:
        params: Parameters (or an iterable of them) forwarded to the
            ``torch.nn.utils`` clipping helpers.
        mode: ``"norm"`` to use ``clip_grad_norm_`` or ``"value"`` to use
            ``clip_grad_value_``.
        clip_cfg: Settings for the chosen mode. ``"norm"`` requires
            ``max_norm`` and optionally accepts ``norm_type`` (default
            ``2.0``); ``"value"`` requires ``clip_value``.

    Returns:
        In ``"norm"`` mode, the total gradient norm reported by
        ``clip_grad_norm_`` (handy for logging); ``None`` in ``"value"``
        mode.

    Raises:
        ValueError: If the key required by ``mode`` is missing from
            ``clip_cfg``.
        NotImplementedError: If ``mode`` is neither ``"norm"`` nor
            ``"value"``.
    """
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        # Hand the computed norm back instead of discarding it so callers
        # can monitor gradient magnitude; ignoring it keeps old behavior.
        return torch.nn.utils.clip_grad_norm_(
            params,
            max_norm=clip_cfg["max_norm"],
            norm_type=clip_cfg.get("norm_type", 2.0),
        )

    if mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg["clip_value"])
        return None

    # Same exception type as before, but now it says what went wrong.
    raise NotImplementedError(
        f"Unsupported clip mode {mode!r}; expected 'norm' or 'value'."
    )
|