Spaces:
Running
Running
import os | |
import tempfile | |
import time | |
from typing import List, Tuple | |
import gradio as gr | |
import torch | |
import torchaudio | |
import spaces | |
from dataclasses import dataclass | |
from generator import Segment, load_csm_1b | |
from huggingface_hub import login | |
# Tắt tính năng compile của torch để tránh lỗi triton | |
torch._dynamo.config.suppress_errors = True | |
# Kiểm tra xem có GPU không và cấu hình thiết bị phù hợp | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"Sử dụng thiết bị: {device}") | |
# Đăng nhập vào Hugging Face Hub nếu có token | |
def login_huggingface(): | |
hf_token = os.environ.get("HF_TOKEN") | |
if hf_token: | |
print("Đang đăng nhập vào Hugging Face Hub...") | |
login(token=hf_token) | |
print("Đã đăng nhập thành công!") | |
else: | |
print("Không tìm thấy HF_TOKEN trong biến môi trường. Một số mô hình có thể không truy cập được.") | |
# Đăng nhập khi khởi động | |
login_huggingface() | |
# Biến toàn cục để theo dõi trạng thái mô hình | |
generator = None | |
model_loaded = False | |
# Hàm tải mô hình được gọi trong ZeroGPU | |
def initialize_model(): | |
global generator, model_loaded | |
if not model_loaded: | |
print("Đang tải mô hình CSM-1B trong GPU...") | |
generator = load_csm_1b(device="cuda") | |
model_loaded = True | |
print("Đã tải xong mô hình!") | |
return generator | |
# Hàm lấy mô hình đã tải | |
def get_model(): | |
global generator, model_loaded | |
if not model_loaded: | |
return initialize_model() | |
return generator | |
# Hàm chuyển đổi âm thanh thành tensor | |
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]: | |
waveform, sample_rate = torchaudio.load(audio_path) | |
waveform = waveform.mean(dim=0) # Chuyển stereo thành mono nếu cần | |
return waveform, sample_rate | |
# Hàm lưu tensor âm thanh thành file | |
def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str: | |
temp_dir = tempfile.gettempdir() | |
output_path = os.path.join(temp_dir, f"csm1b_output_{int(time.time())}.wav") | |
torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate) | |
return output_path | |
# Hàm tạo âm thanh từ văn bản sử dụng ZeroGPU | |
def generate_speech( | |
text: str, | |
speaker_id: int, | |
context_audio_path1: str = None, | |
context_text1: str = None, | |
context_speaker1: int = 0, | |
context_audio_path2: str = None, | |
context_text2: str = None, | |
context_speaker2: int = 1, | |
max_duration_ms: float = 30000, | |
temperature: float = 0.9, | |
top_k: int = 50, | |
progress=gr.Progress() | |
) -> str: | |
# Lấy mô hình đã tải | |
generator = get_model() | |
# Chuẩn bị ngữ cảnh (context) | |
context = [] | |
progress(0.1, "Đang xử lý ngữ cảnh...") | |
# Xử lý ngữ cảnh 1 | |
if context_audio_path1 and context_text1: | |
waveform, sample_rate = audio_to_tensor(context_audio_path1) | |
# Resample nếu cần | |
if sample_rate != generator.sample_rate: | |
waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate) | |
context.append(Segment(speaker=context_speaker1, text=context_text1, audio=waveform)) | |
# Xử lý ngữ cảnh 2 | |
if context_audio_path2 and context_text2: | |
waveform, sample_rate = audio_to_tensor(context_audio_path2) | |
# Resample nếu cần | |
if sample_rate != generator.sample_rate: | |
waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate) | |
context.append(Segment(speaker=context_speaker2, text=context_text2, audio=waveform)) | |
progress(0.3, "Đang tạo âm thanh...") | |
# Tạo âm thanh từ văn bản | |
audio = generator.generate( | |
text=text, | |
speaker=speaker_id, | |
context=context, | |
max_audio_length_ms=max_duration_ms, | |
temperature=temperature, | |
topk=top_k | |
) | |
progress(0.8, "Đang lưu âm thanh...") | |
# Lưu âm thanh thành file | |
output_path = save_audio(audio, generator.sample_rate) | |
progress(1.0, "Hoàn thành!") | |
return output_path | |
# Hàm tạo âm thanh đơn giản không có ngữ cảnh | |
def generate_speech_simple( | |
text: str, | |
speaker_id: int, | |
max_duration_ms: float = 30000, | |
temperature: float = 0.9, | |
top_k: int = 50, | |
progress=gr.Progress() | |
) -> str: | |
# Lấy mô hình đã tải | |
generator = get_model() | |
progress(0.3, "Đang tạo âm thanh...") | |
# Tạo âm thanh từ văn bản | |
audio = generator.generate( | |
text=text, | |
speaker=speaker_id, | |
context=[], # Không có ngữ cảnh | |
max_audio_length_ms=max_duration_ms, | |
temperature=temperature, | |
topk=top_k | |
) | |
progress(0.8, "Đang lưu âm thanh...") | |
# Lưu âm thanh thành file | |
output_path = save_audio(audio, generator.sample_rate) | |
progress(1.0, "Hoàn thành!") | |
return output_path | |
# Tạo giao diện Gradio | |
def create_demo(): | |
with gr.Blocks(title="CSM-1B Text-to-Speech") as demo: | |
gr.Markdown("# CSM-1B Text-to-Speech Demo") | |
gr.Markdown("Mô hình CSM-1B (Collaborative Speech Model) là một mô hình text-to-speech tiên tiến có khả năng tạo giọng nói tự nhiên từ văn bản.") | |
with gr.Tab("Tạo âm thanh đơn giản"): | |
with gr.Row(): | |
with gr.Column(): | |
text_input = gr.Textbox( | |
label="Văn bản cần chuyển thành giọng nói", | |
placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...", | |
lines=5 | |
) | |
speaker_id = gr.Number( | |
label="ID người nói", | |
value=0, | |
precision=0, | |
minimum=0, | |
maximum=10 | |
) | |
with gr.Row(): | |
max_duration = gr.Slider( | |
label="Thời lượng tối đa (ms)", | |
minimum=1000, | |
maximum=90000, | |
value=30000, | |
step=1000 | |
) | |
temperature = gr.Slider( | |
label="Temperature", | |
minimum=0.1, | |
maximum=1.5, | |
value=0.9, | |
step=0.1 | |
) | |
top_k = gr.Slider( | |
label="Top-K", | |
minimum=1, | |
maximum=100, | |
value=50, | |
step=1 | |
) | |
generate_btn = gr.Button("Tạo âm thanh") | |
with gr.Column(): | |
output_audio = gr.Audio(label="Âm thanh đầu ra", type="filepath") | |
with gr.Tab("Tạo âm thanh với ngữ cảnh"): | |
gr.Markdown("Tính năng này cho phép bạn cung cấp các đoạn âm thanh và văn bản làm ngữ cảnh để mô hình tạo ra âm thanh phù hợp hơn.") | |
with gr.Row(): | |
with gr.Column(): | |
context_text1 = gr.Textbox(label="Văn bản ngữ cảnh 1", lines=2) | |
context_audio1 = gr.Audio(label="Âm thanh ngữ cảnh 1", type="filepath") | |
context_speaker1 = gr.Number(label="ID người nói 1", value=0, precision=0) | |
context_text2 = gr.Textbox(label="Văn bản ngữ cảnh 2", lines=2) | |
context_audio2 = gr.Audio(label="Âm thanh ngữ cảnh 2", type="filepath") | |
context_speaker2 = gr.Number(label="ID người nói 2", value=1, precision=0) | |
text_input_context = gr.Textbox( | |
label="Văn bản cần chuyển thành giọng nói", | |
placeholder="Nhập văn bản bạn muốn chuyển thành giọng nói...", | |
lines=3 | |
) | |
speaker_id_context = gr.Number( | |
label="ID người nói", | |
value=0, | |
precision=0 | |
) | |
with gr.Row(): | |
max_duration_context = gr.Slider( | |
label="Thời lượng tối đa (ms)", | |
minimum=1000, | |
maximum=90000, | |
value=30000, | |
step=1000 | |
) | |
temperature_context = gr.Slider( | |
label="Temperature", | |
minimum=0.1, | |
maximum=1.5, | |
value=0.9, | |
step=0.1 | |
) | |
top_k_context = gr.Slider( | |
label="Top-K", | |
minimum=1, | |
maximum=100, | |
value=50, | |
step=1 | |
) | |
generate_context_btn = gr.Button("Tạo âm thanh với ngữ cảnh") | |
with gr.Column(): | |
output_audio_context = gr.Audio(label="Âm thanh đầu ra", type="filepath") | |
# Thêm tab cấu hình Hugging Face | |
with gr.Tab("Cấu hình"): | |
gr.Markdown("### Cấu hình Hugging Face Token") | |
gr.Markdown(""" | |
Để sử dụng mô hình CSM-1B, bạn cần có quyền truy cập vào mô hình trên Hugging Face. | |
Bạn có thể cấu hình token của mình bằng cách: | |
1. Tạo token tại [Hugging Face Settings](https://huggingface.co/settings/tokens) | |
2. Đặt biến môi trường `HF_TOKEN` với giá trị là token của bạn | |
Lưu ý: Trong Hugging Face Spaces, bạn có thể đặt biến môi trường trong phần Cài đặt của Space. | |
""") | |
hf_token_input = gr.Textbox( | |
label="Hugging Face Token (Chỉ sử dụng trong phiên này)", | |
placeholder="Nhập token của bạn...", | |
type="password" | |
) | |
def set_token(token): | |
if token: | |
os.environ["HF_TOKEN"] = token | |
login(token=token) | |
return "Đã đặt token thành công! Bạn có thể tải mô hình bây giờ." | |
return "Token không hợp lệ. Vui lòng nhập token hợp lệ." | |
set_token_btn = gr.Button("Đặt Token") | |
token_status = gr.Textbox(label="Trạng thái", interactive=False) | |
set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status) | |
# Thêm tab thông tin về ZeroGPU | |
with gr.Tab("Thông tin GPU"): | |
gr.Markdown("### Thông tin về ZeroGPU") | |
gr.Markdown(""" | |
Ứng dụng này sử dụng ZeroGPU của Hugging Face Spaces để tối ưu việc sử dụng GPU. | |
ZeroGPU giúp giải phóng bộ nhớ GPU khi không sử dụng, giúp tiết kiệm tài nguyên và cải thiện hiệu suất. | |
Khi bạn tạo âm thanh, GPU sẽ được sử dụng tự động và giải phóng sau khi hoàn thành. | |
Lưu ý: Trong môi trường ZeroGPU, CUDA không được khởi tạo trong quá trình chính, mà chỉ trong các hàm có decorator @spaces.GPU. | |
""") | |
def check_gpu(): | |
if torch.cuda.is_available(): | |
gpu_name = torch.cuda.get_device_name(0) | |
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) | |
return f"GPU: {gpu_name}\nBộ nhớ: {gpu_memory:.2f} GB" | |
else: | |
return "Không tìm thấy GPU. Ứng dụng sẽ chạy trên CPU." | |
check_gpu_btn = gr.Button("Kiểm tra GPU") | |
gpu_info = gr.Textbox(label="Thông tin GPU", interactive=False) | |
check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info) | |
# Thêm nút tải mô hình | |
load_model_btn = gr.Button("Tải mô hình") | |
model_status = gr.Textbox(label="Trạng thái mô hình", interactive=False) | |
def load_model_and_report(): | |
global model_loaded | |
if model_loaded: | |
return "Mô hình đã được tải trước đó!" | |
else: | |
initialize_model() | |
return "Mô hình đã được tải thành công!" | |
load_model_btn.click(fn=load_model_and_report, inputs=None, outputs=model_status) | |
# Kết nối các thành phần | |
generate_btn.click( | |
fn=generate_speech_simple, | |
inputs=[ | |
text_input, | |
speaker_id, | |
max_duration, | |
temperature, | |
top_k | |
], | |
outputs=output_audio | |
) | |
generate_context_btn.click( | |
fn=generate_speech, | |
inputs=[ | |
text_input_context, | |
speaker_id_context, | |
context_audio1, | |
context_text1, | |
context_speaker1, | |
context_audio2, | |
context_text2, | |
context_speaker2, | |
max_duration_context, | |
temperature_context, | |
top_k_context | |
], | |
outputs=output_audio_context | |
) | |
return demo | |
# Khởi chạy ứng dụng | |
if __name__ == "__main__": | |
demo = create_demo() | |
demo.queue().launch() |