from ssllm_hf import SSLLMForCausalLM, SSLLMConfig
import tiktoken
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Initialize model with config
config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)

# Download and load model weights
model_path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
state_dict = load_file(model_path)
# strict=False tolerates keys that differ between the checkpoint and the model
model.load_state_dict(state_dict, strict=False)

# Setup device and eval mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

# Initialize tokenizer
tokenizer = tiktoken.get_encoding('cl100k_base')


def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
    # Encode the prompt into a (1, seq_len) tensor of token ids
    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
    attention_mask = torch.ones_like(input_ids)

    # Generate with the model
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=100257,  # <|endoftext|> in cl100k_base
            eos_token_id=100257,
        )

    # Decode only the new tokens (skip the prompt)
    new_tokens = outputs[0][input_ids.shape[1]:].tolist()
    generated = tokenizer.decode(new_tokens)
    print(f"{prompt}{generated}")
    print(f"\nTokens generated: {len(new_tokens)}")


if __name__ == "__main__":
    prompt = "In a small village nestled between mountains,"
    print(f"PROMPT: {prompt}\n--")
    generate_text(prompt)
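
# For reference, the model.generate() call above can also be written as an
# explicit decoding loop. This is a minimal greedy-decoding sketch, assuming
# SSLLMForCausalLM follows the standard Hugging Face causal-LM forward
# interface (model(input_ids).logits has shape (batch, seq_len, vocab));
# greedy_generate is a hypothetical helper, not part of the ssllm_hf package.
@torch.no_grad()
def greedy_generate(prompt, max_new_tokens=64):
    ids = torch.tensor([tokenizer.encode(prompt)], device=device)
    for _ in range(max_new_tokens):
        logits = model(ids).logits[:, -1, :]          # logits for the next token
        next_id = logits.argmax(dim=-1, keepdim=True) # pick the most likely token
        if next_id.item() == 100257:                  # stop at <|endoftext|>
            break
        ids = torch.cat([ids, next_id], dim=-1)       # append and continue
    return tokenizer.decode(ids[0].tolist())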