import torch

# Lookup table from a short precision name ("fp32", "fp16", "bf16") to the
# matching torch dtype object. Insertion order is fp32, fp16, bf16.
# NOTE(review): presumably used to translate a user/config precision string
# into a dtype; verify against callers.
PRECISION_TO_TYPE = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)