# ChatTS / app_legacy.py — legacy Gradio demo for bytedance-research/ChatTS-14B
import spaces  # ZeroGPU support: @spaces.GPU requests a GPU slice per call
import subprocess

import gradio as gr
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
MODEL_NAME = "bytedance-research/ChatTS-14B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# ChatTS ships a custom processor that encodes raw time series alongside text.
# It is loaded here even though the legacy text-only path below never uses it.
processor = AutoProcessor.from_pretrained(
    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",          # place layers automatically across devices
    torch_dtype=torch.float16,  # half precision to fit the 14B weights
)
model.eval()  # inference only: disable dropout and training-mode behavior
# ─── INFERENCE ──────────────────────────────────────────────────────────────────
@spaces.GPU  # run this handler on a ZeroGPU-allocated GPU
def generate_text(prompt: str) -> str:
    """Legacy text-only path: tokenize the prompt and return the completion."""
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.2,  # low temperature keeps answers near-deterministic
        top_p=0.9,
    )
    # Drop the prompt tokens so only the newly generated text is returned.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
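
# The processor above exists because ChatTS is a time-series-native model: per the
# model card, a numeric series is passed through `timeseries=` with a `<ts><ts/>`
# placeholder in the prompt. The sketch below shows that path under those assumed
# argument names; `analyze_timeseries` is illustrative and not wired into the UI.
@spaces.GPU
def analyze_timeseries(csv_values: str, question: str) -> str:
    """Sketch of the multimodal path: one comma-separated series plus a question."""
    ts = np.array([float(v) for v in csv_values.split(",")], dtype=np.float32)
    prompt = f"I have a time series of length {len(ts)}: <ts><ts/>. {question}"
    # Assumes the ChatTS processor accepts `timeseries` and returns `input_ids`.
    inputs = processor(
        text=[prompt], timeseries=[ts], padding=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)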
demo = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, label="Prompt"),
    outputs=gr.Textbox(lines=6, label="Generated Text"),
)
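
# If the time-series sketch above were wired up, a second interface like this
# hypothetical one could expose it; it is defined but intentionally not launched.
ts_demo = gr.Interface(
    fn=analyze_timeseries,
    inputs=[
        gr.Textbox(lines=2, label="Comma-separated values"),
        gr.Textbox(lines=2, label="Question about the series"),
    ],
    outputs=gr.Textbox(lines=6, label="Analysis"),
)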
if __name__ == "__main__":
    # Clear stale ZeroGPU offload caches from earlier runs before serving.
    subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
    demo.launch()