# OSUM-EChat / patches / modelling_fm_infer_gpu.py
# Provenance (HuggingFace page header): author xlgeng, commit 841f290
# ("开始部署" — "start deployment"), 702 bytes.
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import f5_tts
from f5_tts.model.backbones.dit_mask import DiT as DiT_
_GPU_FM_TORCH_COMPILE = True
class GPUDiT(DiT_):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fast_forward = torch.compile(self.fast_forward, dynamic=False, fullgraph=True) \
if _GPU_FM_TORCH_COMPILE else self.fast_forward
# ===================================================================
print("========================= DO FM PATCH ============================")
# ===================================================================
f5_tts.model.backbones.dit_mask.DiT = GPUDiT