File size: 96 Bytes
03022ee
 
1
2
3
import torch
# Lookup table from short precision tags to torch dtype objects.
# Insertion order (bf16, fp16, fp32) is kept, in case callers iterate the keys.
dtype_map = dict(
    bf16=torch.bfloat16,  # bfloat16: 8-bit exponent, 7-bit mantissa
    fp16=torch.float16,   # IEEE 754 half precision
    fp32=torch.float32,   # IEEE 754 single precision
)