况兑 committed on
Commit b4c6154 · 1 Parent(s): a769d64

stabilize: cache to /data + bool attn_mask + minimal app

Files changed (3)
  1. README.md +1 -1
  2. app_min.py +68 -0
  3. requirements.txt +4 -3
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
 colorTo: purple
 sdk: gradio
 sdk_version: 5.49.0
-app_file: loss_probe.py
+app_file: app_min.py
 pinned: false
 license: mit
 ---
app_min.py ADDED
@@ -0,0 +1,73 @@
+import os
+
+# Persist the HF cache to /data. This must run before transformers/huggingface_hub
+# are imported, because they read these variables at import time.
+BASE = "/data"
+os.makedirs(BASE, exist_ok=True)
+os.environ.setdefault("HF_HOME", f"{BASE}/hf_home")
+os.environ.setdefault("HF_HUB_CACHE", f"{BASE}/hf_home/hub")
+os.environ.setdefault("TRANSFORMERS_CACHE", f"{BASE}/hf_home/transformers")
+os.environ.setdefault("XDG_CACHE_HOME", f"{BASE}/hf_home")
+
+import torch
+import gradio as gr
+from transformers import AutoModel, AutoTokenizer
+
+MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
+REV = os.getenv("REV", None)
+
+print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print("[INFO] Loading tokenizer...")
+tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
+
+print("[INFO] Loading model...")
+model = AutoModel.from_pretrained(
+    MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
+).to(device).eval()
+
+def quick_infer(q: str):
+    if not q.strip():
+        return ""
+    messages = [{"role": "user", "content": q}]
+    enc = tok.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True)
+    input_ids = enc.input_ids.to(device)
+    attention_mask = enc.attention_mask.to(device).bool()  # key: cast the mask to bool
+    with torch.no_grad():
+        out = model.diffusion_generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=64,
+            steps=64,
+            temperature=0.0,
+            return_dict_in_generate=True,
+        )
+    text = tok.decode(out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
+    return text
+
+def self_check():
+    try:
+        msgs = [{"role": "system", "content": "Output a single number only"}, {"role": "user", "content": "Compute: 1+1"}]
+        enc = tok.apply_chat_template(msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False)
+        _ = model(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device).bool())
+        return "OK: forward() is usable (Dream does not necessarily provide labels->loss; that is normal)"
+    except Exception as e:
+        return f"ERR: {repr(e)}"
+
+with gr.Blocks() as demo:
+    gr.Markdown("## Dream Minimal App \n- Click Self-check first \n- Then try a single inference")
+    with gr.Row():
+        btn = gr.Button("Self-check")
+        out = gr.Textbox(label="Result")
+    btn.click(fn=self_check, inputs=None, outputs=out)
+
+    with gr.Row():
+        q = gr.Textbox(label="Prompt", value="Compute: 1+1")
+        a = gr.Textbox(label="Output")
+        go = gr.Button("Generate")
+    go.click(fn=quick_infer, inputs=q, outputs=a)
+
+if __name__ == "__main__":
+    demo.launch()
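Note: huggingface_hub resolves HF_HOME and the cache paths once, at import time, which is why the file sets the environment variables before importing transformers. A quick way to confirm the cache really landed under /data (a minimal sketch, assuming the same Space environment):

    # Run after the env-var block at the top of app_min.py.
    from huggingface_hub import constants

    print(constants.HF_HOME)       # expected: /data/hf_home
    print(constants.HF_HUB_CACHE)  # expected: /data/hf_home/hub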
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-gradio>=4.44.0
 transformers==4.46.2
+torch==2.5.1
+gradio==5.49.0
 accelerate>=1.0.0
-bitsandbytes>=0.43.1
 huggingface_hub>=0.25.0
-torch>=2.4.0
+httpx[socks]
+socksio
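The httpx[socks] and socksio additions matter when the Space has to reach the Hub through a SOCKS proxy: httpx rejects socks5:// proxy URLs unless the socksio backend is installed. A minimal sketch of the call they unlock (the proxy address is hypothetical):

    # Without socksio installed, httpx raises an error for SOCKS proxy URLs.
    import httpx

    client = httpx.Client(proxy="socks5://127.0.0.1:1080")  # hypothetical local proxy
    print(client.get("https://huggingface.co").status_code)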