"""Evaluate Llama-3.2-1B-Instruct on a small GSM8K subset.

For every question the model is prompted to reason step by step and finish
with an ``Answer: <number>`` line.  The numeric prediction is extracted from
the model's completion and written, together with the reference answer and
the full decoded text, to a CSV file.
"""

import csv
import json
import re
from warnings import filterwarnings

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
DATASET_SPLIT = "train[:100]"  # small subset for debugging
CSV_FILE = "gsm8k_llama3_results_1.csv"
MAX_NEW_TOKENS = 300
TEMPERATURE = 0.7
DEBUG_EXAMPLES = 5  # how many leading examples are echoed to stdout

# A number, optionally signed, with optional thousands separators ("1,250.50")
# or a bare leading decimal point (".5").
_NUMBER = r"[-+]?(?:\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d*\.?\d+)"
# Explicit final-answer line; a currency sign before the number is tolerated.
ANSWER_RE = re.compile(r"Answer:\s*\$?\s*(" + _NUMBER + r")")
# Fallback: any number in the text (signed decimals, comma-grouped, or ints).
FALLBACK_RE = re.compile(r"[-+]?\d*\.\d+|\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+")


def build_prompt(question: str) -> str:
    """Return the fixed-format prompt sent to the model for *question*."""
    return (
        f"Question: {question}\n\n"
        "Please solve this step-by-step and finally answer in this format:\n"
        "Answer: \n"
    )


def parse_true_answer(answer_field: str) -> str:
    """Return the reference answer that follows ``####`` in a GSM8K record.

    Thousands separators are removed so the value is directly comparable
    with the output of :func:`extract_answer`.
    """
    return answer_field.split("####")[-1].strip().replace(",", "")


def extract_answer(completion: str) -> str:
    """Extract the predicted number from the model's *completion*.

    The first ``Answer: <number>`` occurrence wins; if there is none, the
    last number appearing anywhere in the text is used.  Thousands
    separators are stripped.  Returns ``"N/A"`` when no number is found.
    """
    match = ANSWER_RE.search(completion)
    if match:
        return match.group(1).replace(",", "")
    numbers = FALLBACK_RE.findall(completion)
    return numbers[-1].replace(",", "") if numbers else "N/A"


def main() -> None:
    """Run the evaluation loop and write one CSV row per example."""
    # The ML stack is imported here rather than at module level so that the
    # pure helpers above can be imported without loading it.
    import torch
    from datasets import load_dataset
    from tqdm import tqdm
    from transformers import AutoModelForCausalLM, AutoTokenizer

    filterwarnings("ignore")

    # Model setup
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, device_map="auto"
    )

    # Load subset of GSM8K dataset for debugging
    dataset = load_dataset("gsm8k", "main", split=DATASET_SPLIT)

    # The context manager guarantees the CSV is flushed and closed even if
    # generation raises part-way through the loop.
    with open(CSV_FILE, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(
            ["question", "true_answer", "predicted_answer", "full_response"]
        )

        for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
            question = example["question"]
            true_answer = parse_true_answer(example["answer"])

            inputs = tokenizer(
                build_prompt(question), return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    # temperature only takes effect when sampling is enabled;
                    # state it explicitly instead of relying on the
                    # checkpoint's generation config.
                    do_sample=True,
                    temperature=TEMPERATURE,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Full text (prompt + completion) is what gets logged and saved.
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only from the newly generated tokens: the prompt itself
            # contains "Answer:" and the question's numbers, which would
            # otherwise be picked up as the prediction.
            prompt_len = inputs["input_ids"].shape[1]
            completion = tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            predicted_answer = extract_answer(completion)

            # Print few examples for debugging
            if idx < DEBUG_EXAMPLES:
                print("=" * 50)
                print(f"Question: {question}")
                print(f"Response: {response}")
                print(f"True Answer: {true_answer}")
                print(f"Predicted Answer: {predicted_answer}")

            writer.writerow([question, true_answer, predicted_answer, response])

    print("Evaluation complete. Results saved to:", CSV_FILE)


if __name__ == "__main__":
    main()