Spaces:
Running
on
Zero
Running
on
Zero
File size: 702 Bytes
841f290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import f5_tts
from f5_tts.model.backbones.dit_mask import DiT as DiT_
_GPU_FM_TORCH_COMPILE = True
class GPUDiT(DiT_):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fast_forward = torch.compile(self.fast_forward, dynamic=False, fullgraph=True) \
if _GPU_FM_TORCH_COMPILE else self.fast_forward
# ===================================================================
print("========================= DO FM PATCH ============================")
# ===================================================================
f5_tts.model.backbones.dit_mask.DiT = GPUDiT |