csm-1b-gradio-v2 / test_model.py
A Le Thanh Son
fix
6d75162
raw
history blame
2.45 kB
import os
import torch
import torchaudio
import spaces
from generator import Segment, load_csm_1b
from huggingface_hub import login
def login_huggingface():
"""Đăng nhập vào Hugging Face Hub sử dụng token từ biến môi trường hoặc nhập từ người dùng"""
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
print("Không tìm thấy HF_TOKEN trong biến môi trường.")
hf_token = input("Vui lòng nhập Hugging Face token của bạn: ")
if hf_token:
print("Đang đăng nhập vào Hugging Face Hub...")
login(token=hf_token)
print("Đã đăng nhập thành công!")
return True
else:
print("Không có token. Một số mô hình có thể không truy cập được.")
return False
@spaces.GPU
def generate_test_audio(text, speaker_id, device):
"""Tạo âm thanh kiểm tra sử dụng ZeroGPU"""
generator = load_csm_1b(device=device)
print("Đã tải xong mô hình!")
print(f"Đang tạo âm thanh cho văn bản: '{text}'")
audio = generator.generate(
text=text,
speaker=speaker_id,
context=[],
max_audio_length_ms=10000,
temperature=0.9,
topk=50
)
return audio, generator.sample_rate
def test_model():
print("Kiểm tra mô hình CSM-1B...")
# Đăng nhập vào Hugging Face Hub
login_huggingface()
# Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Sử dụng thiết bị: {device}")
# Tải mô hình CSM-1B và tạo âm thanh
print("Đang tải mô hình CSM-1B...")
try:
# Sử dụng ZeroGPU để tạo âm thanh
text = "Xin chào, đây là bài kiểm tra mô hình CSM-1B."
speaker_id = 0
audio, sample_rate = generate_test_audio(text, speaker_id, device)
# Lưu âm thanh thành file
output_path = "test_output.wav"
torchaudio.save(output_path, audio.unsqueeze(0), sample_rate)
print(f"Đã lưu âm thanh vào file: {output_path}")
print("Kiểm tra hoàn tất!")
except Exception as e:
print(f"Lỗi khi kiểm tra mô hình: {e}")
print("Vui lòng kiểm tra lại token và quyền truy cập của bạn.")
if __name__ == "__main__":
test_model()