# app.py — Hugging Face Space by p1atdev
# Last commit: 26ea8db "Update app.py" (verified)
# Ensure flash-attn is importable, installing it into the running interpreter on
# first boot if it is missing.
try:
    import flash_attn
except ImportError:
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use the current interpreter's pip rather than whatever `pip` is first
        # on PATH, so the package lands where this process will import it from.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend (not replace) the environment: passing a bare dict would drop
        # PATH, HOME, CUDA_HOME, etc. from the pip subprocess.
        # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably tells flash-attn's
        # setup.py not to compile the CUDA extension from source — confirm
        # against flash-attn's setup.py.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        # Raise CalledProcessError right away if pip fails, instead of failing
        # later with a less informative ImportError.
        check=True,
    )
    import flash_attn

    print("flash-attn installed.")
import os
import time
import spaces
import torch
from transformers import (
AutoModelForPreTraining,
AutoProcessor,
AutoConfig,
PreTrainedTokenizerFast,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr
# Hugging Face repo id of the model to serve; must come from the environment.
MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and an empty value would otherwise fail later with a confusing hub error.
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to a Hugging Face model repo id."
    )
# Downloads model.safetensors from the repo. NOTE(review): MODEL_PATH is not
# referenced elsewhere in this file — the model is loaded via
# from_pretrained(MODEL_NAME); presumably this just pre-warms the HF cache.
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
# Prefer Apple's MPS backend when available; otherwise use CUDA (no CPU fallback).
DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")
)
# Substrings marking decoder-vocabulary tokens that must never be generated.
BAD_WORD_KEYWORDS = ["(medium)", " text", "(style)"]
def get_bad_words_ids(
    tokenizer: "PreTrainedTokenizerFast",
    keywords: "Iterable[str] | None" = None,
) -> "list[list[int]]":
    """Collect the ids of vocabulary tokens containing any banned keyword.

    The result is shaped for ``model.generate(bad_words_ids=...)``: every banned
    token becomes its own single-token sequence ``[token_id]``.

    Args:
        tokenizer: Tokenizer whose ``vocab`` mapping (token string -> id) is scanned.
        keywords: Substrings to look for in each token string. When omitted,
            the module-level ``BAD_WORD_KEYWORDS`` is used (looked up at call time).

    Returns:
        A list of ``[token_id]`` lists, sorted by id so the result does not
        depend on the vocab mapping's iteration order.
    """
    if keywords is None:
        keywords = BAD_WORD_KEYWORDS
    # Materialize once: `any(...)` below re-iterates it for every token, which
    # would silently exhaust a one-shot iterator after the first token.
    keywords = tuple(keywords)
    banned = sorted(
        token_id
        for token, token_id in tokenizer.vocab.items()
        if any(word in token for word in keywords)
    )
    return [[token_id] for token_id in banned]
def prepare_models():
    """Load the tagging model and its processor, ready for inference on ``DEVICE``.

    Returns:
        ``(model, processor)``: the model loaded in bfloat16 with remote code,
        its decoder's ``use_cache`` flag enabled, switched to eval mode and moved
        to ``DEVICE``; and the matching processor.
    """
    load_kwargs = {"trust_remote_code": True}

    net = AutoModelForPreTraining.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, **load_kwargs
    )
    net.decoder_model.use_cache = True

    proc = AutoProcessor.from_pretrained(MODEL_NAME, **load_kwargs)

    # nn.Module.eval() returns the module itself, so it chains into .to().
    # torch.compile(model) is intentionally left disabled.
    net = net.eval().to(DEVICE)
    return net, proc
def demo():
    """Load the model, run one warm-up generation, then build and launch the UI.

    The UI turns a free-text description into a comma-separated list of
    danbooru-style tags produced by the model's decoder.
    """
    model, processor = prepare_models()
    decoder_tokenizer = processor.decoder_tokenizer
    ban_ids = get_bad_words_ids(decoder_tokenizer)

    # UI label -> value written into the <|translate:...|> prompt token.
    translation_mode_map = {
        "translate": "exact",
        "translate + extend": "approx",
    }

    @spaces.GPU(duration=5)
    @torch.inference_mode()
    def generate_tags(
        text: str,
        auto_detect: bool,
        mode: str = "translate",
        copyright_tags: str = "",
        length: str = "short",
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        """Generate tags for ``text``.

        ``mode`` must be a key of ``translation_mode_map``. ``max_new_tokens``
        and ``top_k`` are cast to ``int`` because Gradio numeric widgets may
        deliver floats.

        Returns:
            ``[tags, timing]``: the generated tags joined with ", " and an
            elapsed-time message for the generation call.
        """
        # Decoder prompt: fixed control tokens, the requested length and
        # translation mode, then any user-supplied copyright tags.
        tag_text = (
            "<|bos|>"
            f"<|aspect_ratio:tall|><|rating:general|><|length:{length}|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            f"<|translate:{translation_mode_map[mode]}|><|input_end|>"
            "<copyright>" + (copyright_tags or "").strip()
        )
        if not auto_detect:
            # Close the copyright section and leave the character section
            # empty so generation continues after <general>. With auto_detect,
            # <copyright> is left open for the model to continue.
            tag_text += "</copyright><character></character><general>"

        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            encoder_input_ids=inputs["encoder_input_ids"].to(model.device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(model.device),
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            # Tokens are decoded one per tag below, so this forbids repeated tags.
            no_repeat_ngram_size=1,
            eos_token_id=decoder_tokenizer.eos_token_id,
            pad_token_id=decoder_tokenizer.pad_token_id,
            bad_words_ids=ban_ids,
        )
        elapsed = time.perf_counter() - start_time

        # batch_decode over the single 1-D output sequence yields one string per
        # token id (assuming processor.batch_decode delegates to the decoder
        # tokenizer); blank strings left by skipped special tokens are dropped.
        decoded = ", ".join(
            tag
            for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
            if tag.strip() != ""
        )
        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    # Run one generation before serving requests.
    print("warming up...")
    print(generate_tags("Hatsune Miku is looking at viewer.", True))
    print("done.")

    with gr.Blocks() as ui:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    text = gr.Text(
                        label="Text",
                        info="Enter a prompt in natural language (currently only English is supported). But maybe danbooru tags are also supported.",
                        lines=4,
                        placeholder="A girl with fox ears and tail in maid costume is looking at viewer.",
                    )
                    auto_detect = gr.Checkbox(
                        label="Auto detect copyright tags.", value=False
                    )
                    copyright_tags = gr.Textbox(
                        label="Copyright tags",
                        info="You can specify copyright tags manually. This must be valid danbooru tags.",
                        placeholder="e.g.) vocaloid, blue archive",
                    )
                    length = gr.Dropdown(
                        label="Length",
                        choices=[
                            "very_short",
                            "short",
                            "long",
                            "very_long",
                        ],
                        value="short",
                    )
                    translation_mode = gr.Radio(
                        label="Translation mode",
                        choices=list(translation_mode_map.keys()),
                        value=list(translation_mode_map.keys())[0],
                    )
                    translate_btn = gr.Button(value="Translate", variant="primary")

                    with gr.Accordion(label="Advanced", open=False):
                        # precision=0 makes Gradio deliver an int, as generate() expects.
                        max_new_tokens = gr.Number(
                            label="Max new tokens", value=128, precision=0
                        )
                        do_sample = gr.Checkbox(label="Do sample", value=False)
                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.3,
                            step=0.1,
                        )
                        # step=1: with minimum=1 a step of 10 only allowed
                        # 1, 11, ..., 91, so neither the default (10) nor the
                        # maximum (100) was selectable.
                        top_k = gr.Slider(
                            label="Top k",
                            minimum=1,
                            maximum=100,
                            value=10,
                            step=1,
                        )
                        top_p = gr.Slider(
                            label="Top p",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                        )

                with gr.Column():
                    output_translation = gr.Textbox(
                        label="Output", lines=4, interactive=False
                    )
                    time_elapsed = gr.Markdown(value="")

            gr.Examples(
                examples=[
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "very_short",
                        "translate",
                    ],
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "猫耳少女のポートレート。:3 ",
                        False,
                        "",
                        "very_short",
                        "translate + extend",
                    ],
                    [
                        "学園アイドルマスター。ジャージを着た篠澤広が疲れ切っており、床に座って笑いながらこっちを見ている",
                        True,
                        "",
                        "short",
                        "translate",
                    ],
                    [
                        "ガールズバンドクライの井芹ニナと桃華。シンプル背景。小指を立ててこっちを向いている。feet out of frame",
                        True,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "夜の暗い路地で、黒い服に身を包んだ女がこっちを振り返っている。白いシャツとネクタイ、ジャケットに、手袋をしている",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "一人の少女の横顔で、全体的に赤い雰囲気。髪は肩までの長さで、横を向いている。",
                        False,
                        "",
                        "short",
                        "translate + extend",
                    ],
                    [
                        "二人の少女がいる。一人は、blonde hair で long hair、もう一人は brown hair で short hair。二人とも制服。少なくとも片方はブレザーを着ている。場所は教室で、窓から日差しが差し込んでいる。cowboy shot。一人は机に座っていて、もう一人は立っている。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                ],
                # Column order of each example row above.
                inputs=[text, auto_detect, copyright_tags, length, translation_mode],
            )

        gr.on(
            triggers=[
                translate_btn.click,
            ],
            fn=generate_tags,
            # Order must match generate_tags' positional parameters.
            inputs=[
                text,
                auto_detect,
                translation_mode,
                copyright_tags,
                length,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[
                output_translation,
                time_elapsed,
            ],
        )

    ui.launch()


if __name__ == "__main__":
    demo()