"""Gradio probe app for the Dream diffusion LM.

Loads the tokenizer/model once at import time, then exposes two tools:

- a self-check reporting whether the model's forward pass accepts
  ``labels`` and returns a ``loss`` (useful before attempting training);
- a quick inference box driven by the model's ``diffusion_generate``.
"""

import os

import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Model selection is overridable via environment variables for easy A/B runs.
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
REV = os.getenv("REV", None)
print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
# bf16 on GPU, fp32 on CPU (CPU bf16 support varies across torch builds).
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
model = (
    AutoModel.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
    )
    .to(device)
    .eval()
)


def check_loss() -> str:
    """Probe whether ``model(..., labels=...)`` yields a ``loss`` attribute.

    Builds a tiny chat transcript, runs one forward pass with
    ``labels=input_ids``, and reports the result as a human-readable string.
    Any exception is caught and returned as text, because this is a
    diagnostic surfaced directly in the UI.
    """
    msgs = [
        {"role": "system", "content": "只输出一个数字"},
        {"role": "user", "content": "Compute: 1+1"},
    ]
    enc = tok.apply_chat_template(
        msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False
    )
    # Ensure dtype/device are correct; a bool attention_mask is broadly compatible.
    input_ids = enc["input_ids"].to(device)
    attn = enc.get("attention_mask", None)
    if attn is not None:
        attn = attn.to(device).to(torch.bool)
    labels = input_ids.clone()
    try:
        out = model(input_ids=input_ids, attention_mask=attn, labels=labels)
        has_loss = getattr(out, "loss", None) is not None
        return f"[CHECK] supports labels->loss? {has_loss} | type={type(out)}"
    except Exception as e:  # deliberate broad catch: report, don't crash the UI
        return f"[CHECK] raised: {repr(e)}"


def quick_infer(q: str) -> str:
    """Run a short deterministic ``diffusion_generate`` round-trip for *q*.

    Returns the decoded completion text, or an empty string when the
    prompt is blank.
    """
    if not q.strip():
        return ""
    messages = [{"role": "user", "content": q}]
    inputs = tok.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device).to(torch.bool)
    with torch.no_grad():
        out = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=64,
            steps=64,
            temperature=0.0,  # greedy / deterministic decoding
            return_dict_in_generate=True,
        )
    # Drop the prompt tokens; decode only the newly generated tail.
    text = tok.decode(
        out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return text


with gr.Blocks() as demo:
    gr.Markdown(
        "## Dream Loss Probe\n- 点击 **Run self-check** 看是否有 `loss`\n- 右侧可用 `diffusion_generate` 试跑"
    )
    with gr.Row():
        check_btn = gr.Button("Run self-check")
        check_out = gr.Textbox(label="Result")
    check_btn.click(fn=check_loss, inputs=None, outputs=check_out)
    with gr.Row():
        q = gr.Textbox(label="Quick inference prompt", value="Compute: 1+1")
        a = gr.Textbox(label="Model output")
    run = gr.Button("Generate (diffusion_generate)")
    run.click(fn=quick_infer, inputs=q, outputs=a)

if __name__ == "__main__":
    demo.launch()