03022ee
1
2
3
import torch

# Map short precision names to torch dtypes:
# "bf16" -> bfloat16, "fp16" -> float16, "fp32" -> float32.
# NOTE(review): presumably used to translate a user/config precision string
# into a dtype; unknown keys raise KeyError on lookup — confirm callers
# validate input (e.g. via `choices=dtype_map.keys()`).
dtype_map = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}