# Install flash-attn at runtime if it is missing (skip the CUDA build for a faster install).
try:
    import flash_attn
except ImportError:
    import subprocess

    print("Installing flash-attn...")
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )
    import flash_attn

    print("flash-attn installed.")

import os
import time

import spaces
import torch
from transformers import (
    AutoModelForPreTraining,
    AutoProcessor,
    AutoConfig,
    PreTrainedTokenizerFast,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr

MODEL_NAME = os.environ.get("MODEL_NAME", None)
assert MODEL_NAME is not None
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")

# Prefer Apple MPS when available, otherwise assume CUDA.
DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")
)

# Tokens containing these substrings are banned during generation.
BAD_WORD_KEYWORDS = ["(medium)", " text", "(style)"]


def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
    # Collect token ids whose tokens contain any banned keyword,
    # in the nested-list format expected by `bad_words_ids`.
    ids = [
        [id]
        for token, id in tokenizer.vocab.items()
        if any(word in token for word in BAD_WORD_KEYWORDS)
    ]
    return ids


def prepare_models():
    # Load the encoder-decoder tag translation model and its processor.
    model = AutoModelForPreTraining.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    model.decoder_model.use_cache = True
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    model.eval()
    model = model.to(DEVICE)
    # model = torch.compile(model)

    return model, processor


def demo():
    model, processor = prepare_models()
    ban_ids = get_bad_words_ids(processor.decoder_tokenizer)

    translation_mode_map = {
        "translate": "exact",
        "translate + extend": "approx",
    }

    @spaces.GPU(duration=5)
    @torch.inference_mode()
    def generate_tags(
        text: str,
        auto_detect: bool,
        mode: str = "translate",
        copyright_tags: str = "",
        length: str = "short",
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        # Build the decoder prompt: control tokens for aspect ratio, rating,
        # length and translation mode, followed by any user-supplied copyright tags.
        tag_text = (
            "<|bos|>"
            f"<|aspect_ratio:tall|><|rating:general|><|length:{length}|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            f"<|translate:{translation_mode_map[mode]}|><|input_end|>"
            "" + copyright_tags.strip()
        )
        if not auto_detect:
            tag_text += ""

        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        start_time = time.time()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            encoder_input_ids=inputs["encoder_input_ids"].to(model.device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(model.device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=1,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
            bad_words_ids=ban_ids,
        )
        elapsed = time.time() - start_time

        # Join the generated tags into a comma-separated string, dropping empties.
        decoded = ", ".join(
            [
                tag
                for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
                if tag.strip() != ""
            ]
        )

        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    # warmup
    print("warming up...")
    print(generate_tags("Hatsune Miku is looking at viewer.", True))
    print("done.")

    with gr.Blocks() as ui:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    text = gr.Text(
                        label="Text",
                        info="Enter a prompt in natural language (currently only English is supported). Danbooru tags may also work.",
                        lines=4,
                        placeholder="A girl with fox ears and tail in maid costume is looking at viewer.",
                    )
                    auto_detect = gr.Checkbox(
                        label="Auto detect copyright tags.", value=False
                    )
                    copyright_tags = gr.Textbox(
                        label="Copyright tags",
                        info="You can specify copyright tags manually. These must be valid danbooru tags.",
                        placeholder="e.g.) vocaloid, blue archive",
                    )
                    length = gr.Dropdown(
                        label="Length",
                        choices=[
                            "very_short",
                            "short",
                            "long",
                            "very_long",
                        ],
                        value="short",
                    )
                    translation_mode = gr.Radio(
                        label="Translation mode",
                        choices=list(translation_mode_map.keys()),
                        value=list(translation_mode_map.keys())[0],
                    )

                    translate_btn = gr.Button(value="Translate", variant="primary")

                    with gr.Accordion(label="Advanced", open=False):
                        max_new_tokens = gr.Number(label="Max new tokens", value=128)
                        do_sample = gr.Checkbox(label="Do sample", value=False)
                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.3,
                            step=0.1,
                        )
                        top_k = gr.Slider(
                            label="Top k",
                            minimum=1,
                            maximum=100,
                            value=10,
                            step=10,
                        )
                        top_p = gr.Slider(
                            label="Top p",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                        )

                with gr.Column():
                    output_translation = gr.Textbox(
                        label="Output", lines=4, interactive=False
                    )
                    # output_extension = gr.Textbox(label="Output (extension)", lines=4, interactive=False)
                    time_elapsed = gr.Markdown(value="")

            gr.Examples(
                examples=[
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "very_short",
                        "translate",
                    ],
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "猫耳少女のポートレート。:3 ",
                        False,
                        "",
                        "very_short",
                        "translate + extend",
                    ],
                    [
                        "学園アイドルマスター。ジャージを着た篠澤広が疲れ切っており、床に座って笑いながらこっちを見ている",
                        True,
                        "",
                        "short",
                        "translate",
                    ],
                    [
                        "ガールズバンドクライの井芹ニナと桃華。シンプル背景。小指を立ててこっちを向いている。feet out of frame",
                        True,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "夜の暗い路地で、黒い服に身を包んだ女がこっちを振り返っている。白いシャツとネクタイ、ジャケットに、手袋をしている",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "一人の少女の横顔で、全体的に赤い雰囲気。髪は肩までの長さで、横を向いている。",
                        False,
                        "",
                        "short",
                        "translate + extend",
                    ],
                    [
                        "二人の少女がいる。一人は、blonde hair で long hair、もう一人は brown hair で short hair。二人とも制服。少なくとも片方はブレザーを着ている。場所は教室で、窓から日差しが差し込んでいる。cowboy shot。一人は机に座っていて、もう一人は立っている。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                ],
                inputs=[text, auto_detect, copyright_tags, length, translation_mode],
            )

        gr.on(
            triggers=[
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                translation_mode,
                copyright_tags,
                length,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[
                output_translation,
                # output_extension,
                time_elapsed,
            ],
        )

    ui.launch()


if __name__ == "__main__":
    demo()
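
# Usage sketch (not part of the original script; the file name "app.py" is an
# assumption, as is the placeholder repo id). The model repo must be supplied
# via the MODEL_NAME environment variable, since the script asserts it is set
# before downloading model.safetensors:
#
#   MODEL_NAME=<user>/<tag-translation-model> python app.py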