import gradio as gr
import spaces
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    model_dir = "nanochat-students/base-d20"

    # Load the config via the Transformers Auto classes
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

    # Load the model and move it to the target device.
    # low_cpu_mem_usage=False avoids meta-device issues with custom model code.
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=False,
    )
    model = model.to(device)
    model.eval()

    # Load the tokenizer via AutoTokenizer (trust_remote_code uses tokenizer_nanogpt)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, config=config)
    return tokenizer, model


tokenizer, model = load_model()


@spaces.GPU
def generate(prompt):
    # Encode the prompt, prepending the BOS token (nanochat tokenizer API)
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    max_new_tokens = 50
    with torch.inference_mode():
        # Greedy decoding: append the argmax token one step at a time
        for _ in range(max_new_tokens):
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Only take the logits for the last token
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
            # Optional: stop early on the EOS token
            # if next_token.item() == eos_token_id:
            #     break
    decoded = tokenizer.decode(ids[0].tolist())
    return decoded


gr.Interface(
    fn=generate,
    inputs=gr.Text(),
    outputs=gr.Text(),
).launch()