|
import os |
|
# Default the Hugging Face cache locations to directories under /tmp/hf.
# setdefault() leaves any value that is already present in the environment
# untouched. This runs ahead of the transformers / huggingface_hub imports
# (presumably so those libraries see the paths when they load -- confirm).
for _cache_var, _cache_dir in (
    ("HF_HOME", "/tmp/hf"),
    ("HF_HUB_CACHE", "/tmp/hf/hub"),
    ("TRANSFORMERS_CACHE", "/tmp/hf/transformers"),
):
    os.environ.setdefault(_cache_var, _cache_dir)
# Drop the loop temporaries so the module namespace stays unchanged.
del _cache_var, _cache_dir
|
|
|
from transformers import AutoModel |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
import gradio as gr |
|
import pickle |
|
|
|
# Hub repository that holds both the model weights and the tokenizer pickle.
MODEL_ID = "loocorez/nanochat-base-d20-test"

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
# Load the checkpoint and move it to the selected device.
# NOTE: trust_remote_code=True allows Python code shipped in the model
# repository to run locally; only use it with a repository you trust.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
# Switch to evaluation mode; this script only performs inference.
model.eval()

# Local path of the pickled tokenizer stored in the same repository.
tok_path = hf_hub_download(repo_id=MODEL_ID, filename="tokenizer.pkl")
|
|
|
class PklTokenizer:
    """Small adapter around an encoder object restored from a pickle file.

    The unpickled object must provide ``encode_single_token``,
    ``encode_ordinary`` and ``decode`` (presumably a tiktoken-style
    encoding -- confirm against the file that is actually shipped).
    """

    def __init__(self, pkl_file):
        """Unpickle the encoder from *pkl_file* and remember the BOS id."""
        # SECURITY: unpickling can execute arbitrary code contained in the
        # file, so only load tokenizer pickles from a trusted source.
        with open(pkl_file, "rb") as handle:
            encoder = pickle.load(handle)
        self.enc = encoder
        # Look the BOS id up once so later calls do not repeat the lookup.
        self._bos = encoder.encode_single_token("<|bos|>")

    def get_bos_token_id(self):
        """Return the token id of ``<|bos|>``."""
        return self._bos

    def encode(self, text, prepend=None):
        """Encode *text* to token ids, optionally led by the id *prepend*."""
        token_ids = self.enc.encode_ordinary(text)
        if prepend is None:
            return token_ids
        return [prepend] + token_ids

    def decode(self, ids):
        """Turn the token ids *ids* back into text."""
        return self.enc.decode(ids)
|
|
|
# Module-wide tokenizer instance built from the downloaded pickle.
tokenizer = PklTokenizer(pkl_file=tok_path)
|
|
|
def complete(prompt, max_new_tokens=64):
    """Greedily continue *prompt* by up to *max_new_tokens* tokens.

    The prompt is encoded with the BOS id prepended. For every new token
    the model is run on the whole sequence built so far and the highest
    scoring token at the last position is appended.

    Args:
        prompt: Text to continue. ``None`` is treated as an empty prompt.
        max_new_tokens: Number of tokens to append. The value is coerced
            with ``int()`` because UI sliders may deliver it as a float,
            which ``range()`` rejects. Values <= 0 append nothing.

    Returns:
        The decoded text of the full sequence: the BOS token, the prompt
        and every generated token.
    """
    steps = int(max_new_tokens)
    text = "" if prompt is None else prompt

    input_ids = tokenizer.encode(text, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    with torch.inference_mode():
        for _ in range(steps):
            outputs = model(input_ids=ids)
            # The model may return a mapping or an object with a .logits field.
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Greedy decoding: take the arg-max over the vocabulary at the
            # last position, keeping the dim so it can be concatenated.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(ids[0].tolist())
|
|
|
# Single-page UI: a prompt box, a token-count slider, an output box and a
# button that runs `complete` on the two inputs.
with gr.Blocks() as demo:
    gr.Markdown(value="# NanoChat Transformers Demo (BASE d20)")
    inp = gr.Textbox(value="The capital of Belgium is ")
    max_toks = gr.Slider(
        minimum=1,
        maximum=256,
        value=64,
        step=1,
        label="Max new tokens",
    )
    out = gr.Textbox()
    btn = gr.Button(value="Generate")
    # The input order matches complete(prompt, max_new_tokens).
    btn.click(fn=complete, inputs=[inp, max_toks], outputs=[out])

# Start the web server as soon as the module is executed.
demo.launch()
|
|
|
|
|
|