"""Gradio demo: generate movie plots with a minGPT checkpoint from the Hub.

The script downloads a model configuration and a checkpoint from the
Hugging Face repository named by ``REPO_ID``, keeps the most recently used
model cached in ``state``, and exposes a small Gradio UI (``demo``) around
``generate``.
"""

import gc
import json

import gradio as gr
import tiktoken
import torch
from huggingface_hub import hf_hub_download

from mingpt.model import GPT

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "to0ony/final-thesis-plotgen"

# Single-slot cache: the currently loaded model, the checkpoint file name it
# was loaded from, and the shared GPT-2 tokenizer.
state = {"model": None, "model_name": None, "enc": tiktoken.get_encoding("gpt2")}

# Keys copied (as ints) from config.json onto the minGPT config object.
_CONFIG_KEYS = ("vocab_size", "block_size", "n_layer", "n_head", "n_embd")


def _release_cached_model():
    """Drop the cached model and hand its memory back before loading another."""
    if state["model"] is None:
        return
    state["model"] = None
    state["model_name"] = None
    gc.collect()
    if torch.cuda.is_available():
        # Return the freed blocks to the driver so the next model fits.
        torch.cuda.empty_cache()


def load_model(model_name: str) -> GPT:
    """Return the model for *model_name*, loading it from the Hub if needed.

    The model is cached in ``state``; asking for the same checkpoint again
    returns the cached instance. Asking for a different checkpoint releases
    the cached one first, so two models are never resident at the same time.

    Args:
        model_name: File name of the checkpoint inside ``REPO_ID``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        KeyError: If ``config.json`` or the checkpoint lacks an expected key.
    """
    cached = state["model"]
    if cached is not None and state["model_name"] == model_name:
        return cached

    # Fetch both files before touching the cache: if a download fails, the
    # previously loaded model stays usable.
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename=model_name)

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    gcfg = GPT.get_default_config()
    # NOTE(review): minGPT presumably accepts either a preset model_type or
    # explicit dimensions, not both - confirm against mingpt/model.py.
    gcfg.model_type = None
    for key in _CONFIG_KEYS:
        setattr(gcfg, key, int(cfg[key]))

    # Free the old model before allocating the new one to keep peak memory
    # (especially GPU memory) at a single model.
    _release_cached_model()

    model = GPT(gcfg)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file. Acceptable only while REPO_ID is a trusted repository;
    # switch to weights_only=True if the checkpoint contains plain tensors.
    checkpoint = torch.load(mdl_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    state["model_name"] = model_name
    return model


@torch.inference_mode()
def generate(prompt, model_choice, max_new_tokens=200, temperature=0.7, top_k=50):
    """Generate text that continues *prompt*.

    Args:
        prompt: Text to continue.
        model_choice: Checkpoint file name passed to ``load_model``.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature.
        top_k: Restrict sampling to the k most likely tokens; ``0`` or
            ``None`` disables the restriction.

    Returns:
        The decoded token sequence returned by ``model.generate``.

    Raises:
        gr.Error: If the prompt encodes to no tokens (empty prompt).
    """
    enc = state["enc"]
    # disallowed_special=() makes text such as "<|endoftext|>" in user input
    # encode as ordinary text instead of raising ValueError.
    tokens = enc.encode(prompt or "", disallowed_special=())
    if not tokens:
        raise gr.Error("Prompt is empty - enter some text to generate from.")

    model = load_model(model_choice)

    # Convert before comparing so fractional values below 1 count as "off".
    k = int(top_k) if top_k is not None else 0

    x = torch.tensor([tokens], dtype=torch.long, device=DEVICE)
    y = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=k if k > 0 else None,
        do_sample=True,
    )
    return enc.decode(y[0].tolist())


# Gradio UI
with gr.Blocks(title="🎬 PlotGen") as demo:
    gr.Markdown("## 🎬 PlotGen\nUnesi prompt i generiraj radnju filma.")

    model_choice = gr.Dropdown(
        choices=["cmu-plots-model.pt", "cmu-plots-model-enchanced.pt"],
        value="cmu-plots-model-enchanced.pt",
        label="Model",
    )

    prompt = gr.Textbox(
        label="Prompt",
        lines=5,
        # The line break after "E.g. " is written as an explicit escape.
        placeholder="E.g. \nA young detective arrives in a coastal town...",
    )
    max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")

    btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=15)

    btn.click(generate, [prompt, model_choice, max_new_tokens, temperature, top_k], output)

if __name__ == "__main__":
    demo.launch()