""" Simple Inference Script for TinyLlama This script demonstrates how to use a fine-tuned TinyLlama model for text generation without requiring all the training dependencies. """ import os import argparse import json import time def parse_args(): parser = argparse.ArgumentParser(description="Run inference with a TinyLlama model") parser.add_argument( "--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", help="Path to the model directory or HuggingFace model name" ) parser.add_argument( "--prompt", type=str, default=None, help="Text prompt for generation" ) parser.add_argument( "--prompt_file", type=str, default=None, help="File containing multiple prompts (one per line)" ) parser.add_argument( "--max_new_tokens", type=int, default=256, help="Maximum number of tokens to generate" ) parser.add_argument( "--temperature", type=float, default=0.7, help="Sampling temperature" ) parser.add_argument( "--output_file", type=str, default="generated_outputs.json", help="File to save generated outputs" ) parser.add_argument( "--interactive", action="store_true", help="Run in interactive mode" ) return parser.parse_args() def format_prompt_for_chat(prompt): """Format a prompt for chat completion""" return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" def main(): args = parse_args() try: # Import libraries here to make the error messages clearer import torch from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError: print("Error: Required libraries not installed.") print("Please install them with: pip install torch transformers") return print(f"Loading model from {args.model_path}...") # Load model and tokenizer try: model = AutoModelForCausalLM.from_pretrained( args.model_path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True ) tokenizer = AutoTokenizer.from_pretrained(args.model_path) # Move model to GPU if available device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() print(f"Model loaded successfully on {device}") except Exception as e: print(f"Error loading model: {e}") return if args.interactive: print("\n=== Interactive Mode ===") print("Type 'exit' or 'quit' to end the session") print("Type your prompts and press Enter.\n") while True: user_input = input("\nYou: ") if user_input.lower() in ["exit", "quit"]: break # Format prompt for chat formatted_prompt = format_prompt_for_chat(user_input) # Tokenize input inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device) # Generate response start_time = time.time() with torch.no_grad(): outputs = model.generate( inputs.input_ids, max_new_tokens=args.max_new_tokens, temperature=args.temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id ) # Decode response full_response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract just the assistant's response try: # This handles the case where the model properly follows the formatting assistant_response = full_response.split("<|im_start|>assistant\n")[1].split("<|im_end|>")[0] except: # Fallback for when the model doesn't follow formatting perfectly assistant_response = full_response.replace(user_input, "").strip() gen_time = time.time() - start_time tokens_per_second = len(outputs[0]) / gen_time print(f"\nAssistant: {assistant_response}") print(f"\n[Generated {len(outputs[0])} tokens in {gen_time:.2f}s - {tokens_per_second:.2f} tokens/s]") else: # Get prompts prompts = [] if args.prompt: prompts.append(args.prompt) elif 
args.prompt_file: with open(args.prompt_file, 'r', encoding='utf-8') as f: prompts = [line.strip() for line in f if line.strip()] else: print("Error: Either --prompt or --prompt_file must be provided") return results = [] print(f"Processing {len(prompts)} prompts...") for i, prompt in enumerate(prompts): print(f"Processing prompt {i+1}/{len(prompts)}") # Format prompt for chat formatted_prompt = format_prompt_for_chat(prompt) # Tokenize input inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device) # Generate response with torch.no_grad(): outputs = model.generate( inputs.input_ids, max_new_tokens=args.max_new_tokens, temperature=args.temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id ) # Decode response full_response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract just the assistant's response try: assistant_response = full_response.split("<|im_start|>assistant\n")[1].split("<|im_end|>")[0] except: assistant_response = full_response.replace(prompt, "").strip() results.append({ "prompt": prompt, "response": assistant_response }) # Save results with open(args.output_file, 'w', encoding='utf-8') as f: json.dump(results, f, ensure_ascii=False, indent=2) print(f"Generated {len(results)} responses and saved to {args.output_file}") if __name__ == "__main__": main()
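
# Example invocations (a usage sketch; the filename "run_inference.py", the
# prompts file "prompts.txt", and the local checkpoint path below are assumed
# for illustration and are not defined by this script):
#
#   # Single prompt against the default TinyLlama chat model
#   python run_inference.py --prompt "Explain what a transformer model is."
#
#   # Batch mode: one prompt per line, results saved as JSON
#   python run_inference.py --prompt_file prompts.txt --output_file generated_outputs.json
#
#   # Interactive chat with a locally fine-tuned checkpoint
#   python run_inference.py --model_path ./my-finetuned-tinyllama --interactive --temperature 0.8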