loocorez committed on
Commit
178677d
·
verified ·
1 Parent(s): 0507169

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Cache locations are set *before* the Hugging Face libraries are imported,
# presumably so they are already in place when those libraries read their
# cache settings. ``setdefault`` leaves any value that the environment
# already provides untouched.
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
# NOTE(review): newer transformers releases appear to treat TRANSFORMERS_CACHE
# as a legacy variable superseded by HF_HOME; kept here for older versions —
# confirm against the installed version.
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")

from transformers import AutoModel
from huggingface_hub import hf_hub_download
import torch
import gradio as gr
import pickle
11
+
12
# Hub repository that provides both the model weights and tokenizer.pkl.
MODEL_ID = "loocorez/nanochat-base-d20-test"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model via Auto* with trust_remote_code.
# SECURITY NOTE: trust_remote_code=True lets transformers execute Python code
# shipped in the MODEL_ID repository; only use it with a repository you trust.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
# Switch to evaluation mode; this app only runs inference.
model.eval()

# Load tokenizer.pkl directly (avoid AutoTokenizer mapping issues).
tok_path = hf_hub_download(MODEL_ID, filename="tokenizer.pkl")
23
+
24
class PklTokenizer:
    """Thin wrapper around a pickled encoder object.

    The unpickled object must provide ``encode_single_token``,
    ``encode_ordinary`` and ``decode``.
    NOTE(review): this looks like a tiktoken ``Encoding`` — confirm against
    the file published in the model repository.
    """

    def __init__(self, pkl_file: str) -> None:
        """Load the encoder from ``pkl_file`` and cache the BOS token id.

        Args:
            pkl_file: Path to the pickled encoder on local disk.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
        # the file. The file is downloaded from a model repository, so only
        # point this at a repository you trust.
        with open(pkl_file, "rb") as f:
            self.enc = pickle.load(f)
        # Resolve the special token once so later lookups are a plain read.
        self._bos = self.enc.encode_single_token("<|bos|>")

    def get_bos_token_id(self) -> int:
        """Return the id of the ``<|bos|>`` token cached at construction."""
        return self._bos

    def encode(self, text: str, prepend: "int | None" = None) -> "list[int]":
        """Encode ``text`` into token ids via ``encode_ordinary``.

        Args:
            text: The string to tokenize.
            prepend: Optional token id placed in front of the encoded ids.
                ``None`` (the default) prepends nothing; ``0`` is a valid id
                and is prepended.

        Returns:
            A new list of token ids.
        """
        # Copy into a list so prepending works for any iterable the encoder
        # returns and never mutates the encoder's own result.
        ids = list(self.enc.encode_ordinary(text))
        if prepend is not None:
            ids = [prepend, *ids]
        return ids

    def decode(self, ids: "list[int]") -> str:
        """Turn a sequence of token ids back into text."""
        return self.enc.decode(ids)
38
+
39
# Single module-level tokenizer built from the downloaded tokenizer.pkl.
tokenizer = PklTokenizer(tok_path)
40
+
41
def complete(prompt: str, max_new_tokens: int = 64) -> str:
    """Greedily extend ``prompt`` and return the decoded sequence.

    The prompt is encoded with the BOS token id in front, then one token is
    appended per step by taking the argmax of the logits at the last
    position. Each step runs the model over the whole sequence so far.

    Args:
        prompt: Text to continue.
        max_new_tokens: Number of tokens to append. Coerced with ``int()``
            because UI sliders may deliver the value as a float, which
            ``range()`` rejects.

    Returns:
        The decoded text of the full id sequence: the BOS id, the prompt
        tokens and the newly generated tokens.
    """
    steps = int(max_new_tokens)
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    # Shape (1, len(input_ids)): a batch holding the single prompt.
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.inference_mode():
        for _ in range(steps):
            outputs = model(input_ids=ids)
            # The model may hand back a mapping or an object exposing .logits.
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Greedy choice at the final position; keepdim keeps a trailing
            # dimension so the result concatenates along dim 1.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())
51
+
52
# Minimal UI: a prompt box, a token-count slider, an output box and a button
# that feeds the prompt and slider value to complete().
with gr.Blocks() as demo:
    gr.Markdown("# NanoChat Transformers Demo (BASE d20)")
    inp = gr.Textbox(value="The capital of Belgium is ")
    max_toks = gr.Slider(1, 256, value=64, step=1, label="Max new tokens")
    out = gr.Textbox()
    btn = gr.Button("Generate")
    # Inputs are passed positionally to complete(prompt, max_new_tokens).
    btn.click(complete, [inp, max_toks], [out])

# Start the web server as soon as the module is executed.
demo.launch()
61
+
62
+