Spaces:
Running
Running
# Copyright Alibaba Inc. All Rights Reserved. | |
from transformers import Wav2Vec2Model, Wav2Vec2Processor | |
from .model import FantasyTalkingAudioConditionModel | |
from .utils import get_audio_features | |
import gc, torch | |
def parse_audio(
    audio_path,
    num_frames,
    fps=23,
    device="cuda",
    wav2vec_model_dir="ckpts/wav2vec",
    proj_model_path="ckpts/fantasy_proj_model.safetensors",
):
    """Encode an audio file into FantasyTalking audio-conditioning chunks.

    Loads the wav2vec encoder/processor and the FantasyTalking
    ``AudioProjModel``, extracts wav2vec features for ``audio_path`` via
    ``get_audio_features``, projects them, and splits the projected sequence
    with ``FantasyTalkingAudioConditionModel.split_audio_sequence`` /
    ``split_tensor_with_padding``. All models are created for this call only
    and are released before returning (also when an error is raised) so the
    CUDA caching allocator can hand their memory back.

    Args:
        audio_path: Audio file passed to ``get_audio_features``.
        num_frames: Number of video frames the audio is aligned to.
        fps: Frame rate forwarded to ``get_audio_features``.
        device: Device the models run on while encoding.
        wav2vec_model_dir: Directory loaded with ``from_pretrained`` for both
            the wav2vec processor and model.
        proj_model_path: Checkpoint loaded into ``AudioProjModel``.

    Returns:
        Tuple ``(audio_proj_split, audio_context_lens)`` as returned by
        ``split_tensor_with_padding``.
    """
    # Heavy / optional dependencies are imported lazily, as in the original.
    from accelerate import init_empty_weights
    from mmgp import offload

    # Relative import, consistent with the module-level ``from .model import``.
    from .model import AudioProjModel

    fantasytalking = wav2vec = proj_model = None
    try:
        # Scoped no-grad: the previous torch.set_grad_enabled(False) disabled
        # autograd process-wide and never restored it for the caller.
        with torch.no_grad():
            # Only its sequence split / padding helpers are used below.
            fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(
                device
            )

            # Build the projection model without allocating real weights, then
            # populate them from the checkpoint.
            with init_empty_weights():
                proj_model = AudioProjModel(768, 2048)
            offload.load_model_data(proj_model, proj_model_path)
            proj_model.to("cpu").eval().requires_grad_(False)

            wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
            wav2vec = (
                Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu")
                .eval()
                .requires_grad_(False)
            )
            wav2vec.to(device)
            proj_model.to(device)

            audio_wav2vec_fea = get_audio_features(
                wav2vec, wav2vec_processor, audio_path, fps, num_frames
            )
            audio_proj_fea = proj_model(audio_wav2vec_fea)
            pos_idx_ranges = fantasytalking.split_audio_sequence(
                audio_proj_fea.size(1), num_frames=num_frames
            )
            audio_proj_split, audio_context_lens = (
                fantasytalking.split_tensor_with_padding(
                    audio_proj_fea, pos_idx_ranges, expand_length=4
                )
            )  # [b,21,9+8,768]
    finally:
        # Drop every model reference (including the split-helper model, which
        # the original kept alive) before emptying the CUDA cache; runs on the
        # error path too so a failed call does not pin GPU memory.
        fantasytalking = wav2vec = proj_model = None
        gc.collect()
        torch.cuda.empty_cache()
    return audio_proj_split, audio_context_lens