File size: 2,724 Bytes
b4c6154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os

# Persist every Hugging Face cache under /data.
# NOTE: these variables must be set BEFORE anything imports huggingface_hub
# (transformers does, and gradio is expected to as well), because the hub
# library resolves its cache paths when it is first imported. Setting them
# after the imports leaves the default cache location in effect.
BASE = "/data"
os.makedirs(BASE, exist_ok=True)
os.environ.setdefault("HF_HOME", f"{BASE}/hf_home")
os.environ.setdefault("HF_HUB_CACHE", f"{BASE}/hf_home/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", f"{BASE}/hf_home/transformers")
os.environ.setdefault("XDG_CACHE_HOME", f"{BASE}/hf_home")

# Deliberately imported after the cache environment is configured (see above).
import torch  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# Model repository and optional revision, both overridable via environment.
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
# An unset OR empty REV means "no pinned revision"; normalise both to None so
# that an empty string is never forwarded as a revision name.
REV = os.getenv("REV") or None

print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")
# bfloat16 on GPU, float32 on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)

print("[INFO] Loading model...")
# Load once at import time, move to the selected device, switch to eval mode.
model = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
).to(device).eval()

def quick_infer(q: str):
    """Generate a short reply to a single user prompt.

    Args:
        q: The prompt text. A blank or whitespace-only prompt short-circuits.

    Returns:
        The decoded continuation (special tokens skipped, surrounding
        whitespace stripped), or "" when the prompt is blank.
    """
    if not q.strip():
        return ""

    chat = [{"role": "user", "content": q}]
    encoded = tok.apply_chat_template(
        chat,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    prompt_ids = encoded.input_ids.to(device)
    # Key step per the original author: the mask is cast to bool before use.
    mask = encoded.attention_mask.to(device).bool()
    prompt_len = prompt_ids.shape[1]

    with torch.no_grad():
        result = model.diffusion_generate(
            prompt_ids,
            attention_mask=mask,
            max_new_tokens=64,
            steps=64,
            temperature=0.0,
            return_dict_in_generate=True,
        )

    # Keep only the tokens that come after the prompt.
    new_tokens = result.sequences[0][prompt_len:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()

def self_check():
    """Run one forward pass through the model to confirm it is callable.

    Returns:
        A status string for the UI: a message starting with "OK:" when the
        forward pass completes, or "ERR: <repr of the exception>" when any
        step raises.
    """
    try:
        msgs = [{"role":"system","content":"只输出一个数字"},{"role":"user","content":"Compute: 1+1"}]
        enc = tok.apply_chat_template(msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False)
        input_ids = enc["input_ids"].to(device)
        # Same bool cast of the mask that quick_infer applies.
        attention_mask = enc["attention_mask"].to(device).bool()
        # Smoke test only: disable autograd so no graph is recorded for the
        # forward pass, matching how quick_infer calls the model.
        with torch.no_grad():
            model(input_ids=input_ids, attention_mask=attention_mask)
        return "OK: forward() 可用(Dream 未必提供 labels->loss,属正常)"
    except Exception as e:  # UI boundary: report the failure instead of raising
        return f"ERR: {repr(e)}"

# Gradio UI: one row for the self-check, one row for prompt -> generation.
# Components are created in the order they should appear, so statement order
# inside the Blocks context matters.
with gr.Blocks() as demo:
    gr.Markdown("## Dream Minimal App  \n- 先点 Self-check  \n- 再试一次推理")
    # Row 1: button that runs self_check() and shows its status string.
    with gr.Row():
        btn = gr.Button("Self-check")
        out = gr.Textbox(label="Result")
    btn.click(fn=self_check, inputs=None, outputs=out)

    # Row 2: prompt box, output box and the button that runs quick_infer().
    with gr.Row():
        q = gr.Textbox(label="Prompt", value="Compute: 1+1")
        a = gr.Textbox(label="Output")
        go = gr.Button("Generate")
    go.click(fn=quick_infer, inputs=q, outputs=a)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()