# app.py — uploaded to the Hugging Face Hub by "loocorez" (commit 178677d, verified)
import os
# Point all Hugging Face caches at /tmp — presumably because the default cache
# location is not writable inside the Space container (TODO confirm).
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf/transformers")
from transformers import AutoModel
from huggingface_hub import hf_hub_download
import torch
import gradio as gr
import pickle
# Model repo on the Hub; also hosts the custom modeling code and tokenizer.pkl.
MODEL_ID = "loocorez/nanochat-base-d20-test"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model via Auto* with trust_remote_code.
# NOTE(review): trust_remote_code executes Python shipped in the repo —
# acceptable only because this repo is trusted by the author.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
model.eval()
# Load tokenizer.pkl directly (avoid AutoTokenizer mapping issues)
tok_path = hf_hub_download(MODEL_ID, filename="tokenizer.pkl")
class PklTokenizer:
    """Thin wrapper around a pickled encoder exposing a tiktoken-style API
    (`encode_single_token`, `encode_ordinary`, `decode`)."""

    def __init__(self, pkl_file):
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # only safe because the pickle comes from a trusted model repo.
        with open(pkl_file, "rb") as fh:
            self.enc = pickle.load(fh)
        self._bos = self.enc.encode_single_token("<|bos|>")

    def get_bos_token_id(self):
        """Return the id of the '<|bos|>' special token."""
        return self._bos

    def encode(self, text, prepend=None):
        """Encode `text` to token ids, optionally prefixing a single id."""
        token_ids = self.enc.encode_ordinary(text)
        return token_ids if prepend is None else [prepend] + token_ids

    def decode(self, ids):
        """Decode a list of token ids back into text."""
        return self.enc.decode(ids)
# Module-level tokenizer instance shared by the generation callback below.
tokenizer = PklTokenizer(tok_path)
def complete(prompt, max_new_tokens=64):
    """Greedily decode a continuation of `prompt`.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
            Gradio sliders may deliver this as a float, so it is coerced
            to int before use.

    Returns:
        The decoded text of the full sequence (BOS + prompt + completion).
    """
    # Bug fix: range() rejects floats; Gradio Slider values are not
    # guaranteed to arrive as int.
    max_new_tokens = int(max_new_tokens)
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            # No KV cache: the whole sequence is re-run every step. Fine for a
            # short demo; switch to model.generate()/caching for long outputs.
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Greedy decoding: take the argmax over the last position's logits.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())
# --- Gradio UI: a single prompt box, token-budget slider, and output box. ---
with gr.Blocks() as demo:
    gr.Markdown("# NanoChat Transformers Demo (BASE d20)")
    inp = gr.Textbox(value="The capital of Belgium is ")
    max_toks = gr.Slider(1, 256, value=64, step=1, label="Max new tokens")
    out = gr.Textbox()
    btn = gr.Button("Generate")
    # Wire the button to the greedy-decoding helper defined above.
    btn.click(complete, [inp, max_toks], [out])
demo.launch()