morris-bot / test_enhanced_model.py
eusholli's picture
Upload folder using huggingface_hub
599c2c0 verified
"""
Test script for the enhanced Iain Morris model
Tests diverse topics to ensure the model works beyond telecom
"""
import json
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
# Module-level logger used by every check in this script.
logger = logging.getLogger(__name__)
# Emit INFO and above through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
def _load_json(path):
    """Read and parse the JSON file at *path*, decoding it as UTF-8."""
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode any non-ASCII characters in the dataset.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def test_enhanced_model():
    """Validate the enhanced training dataset and the optional validation set.

    Steps:
      1. Load ``data/enhanced_train_dataset.json`` and log how many examples
         it holds.
      2. Count the examples whose string form mentions a telecom keyword and
         log the telecom / non-telecom split.
      3. If the dataset is non-empty, preview the first 200 characters of the
         first message of the first example (treated as the system prompt)
         and log whether each expected style marker appears in it.
      4. Try to load ``data/improved_val_dataset.json``; a missing file only
         produces a warning.

    No model is loaded and no text is generated here.

    Returns:
        bool: False if the training dataset is missing or unreadable, if the
        system prompt cannot be extracted from the first example, or if the
        validation file exists but cannot be parsed; True otherwise.
    """
    telecom_terms = ('telecom', '5g', 'network', 'operator', 'spectrum')
    # Phrases the enhanced system prompt is expected to contain.
    improvements = (
        "PROVOCATIVE DOOM-LADEN OPENINGS",
        "SIGNATURE DARK ANALOGIES",
        "CYNICAL WIT & EXPERTISE",
        "What could possibly go wrong?",
    )

    # Check if enhanced training data exists
    try:
        enhanced_data = _load_json('data/enhanced_train_dataset.json')
    except FileNotFoundError:
        logger.error("Enhanced training dataset not found!")
        return False
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(f"Enhanced training dataset is not valid JSON: {e}")
        return False

    logger.info(f"Enhanced dataset contains {len(enhanced_data)} examples")

    # Analyze the dataset composition: an example counts as "telecom" when
    # any keyword occurs anywhere in its lower-cased string representation.
    telecom_count = 0
    for example in enhanced_data:
        content = str(example).lower()
        if any(term in content for term in telecom_terms):
            telecom_count += 1
    non_telecom_count = len(enhanced_data) - telecom_count
    logger.info(f"Dataset composition: {telecom_count} telecom examples, {non_telecom_count} non-telecom examples")

    # Check system prompt
    if enhanced_data:
        try:
            system_prompt = enhanced_data[0]['messages'][0]['content']
            preview = system_prompt[:200] + "..."
        except (KeyError, IndexError, TypeError) as e:
            # Unexpected schema: report a failed check instead of crashing.
            logger.error(f"Could not read system prompt from first example: {e!r}")
            return False

        logger.info("System prompt preview:")
        logger.info(preview)

        # Check for key improvements
        for improvement in improvements:
            if improvement in system_prompt:
                logger.info(f"✓ Found improvement: {improvement}")
            else:
                logger.warning(f"✗ Missing improvement: {improvement}")

    # Check if improved validation data exists
    try:
        val_data = _load_json('data/improved_val_dataset.json')
    except FileNotFoundError:
        logger.warning("Improved validation dataset not found, using original")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(f"Improved validation dataset is not valid JSON: {e}")
        return False
    else:
        logger.info(f"Improved validation dataset contains {len(val_data)} examples")

    logger.info("Enhanced model validation completed successfully!")
    return True
def test_model_loading():
    """Check that the base tokenizer loads and report the compute device.

    Despite the name, only the tokenizer of the base model is loaded here;
    the model weights are not. The device is then chosen in the order
    MPS, CUDA, CPU and the choice is logged.

    Returns:
        bool: True if the tokenizer loaded and device detection completed,
        False if any step raised (the exception is logged with traceback).
    """
    model_name = "HuggingFaceH4/zephyr-7b-beta"
    try:
        # Test base model loading
        logger.info(f"Testing model loading: {model_name}")
        # NOTE(review): trust_remote_code=True allows code shipped in the hub
        # repository to run locally — confirm this model actually needs it.
        AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        logger.info("✓ Tokenizer loaded successfully")

        # Test device detection. Builds of torch without an MPS backend have
        # no `torch.backends.mps` attribute, so look it up defensively rather
        # than letting an AttributeError be reported as a loading failure.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            device = "mps"
            logger.info("✓ MPS (Apple Silicon) acceleration available")
        elif torch.cuda.is_available():
            device = "cuda"
            logger.info("✓ CUDA acceleration available")
        else:
            device = "cpu"
            logger.info("✓ Using CPU (no acceleration)")

        logger.info(f"Will use device: {device}")
        return True
    except Exception as e:
        # Top-level boundary for this check: record the traceback and report
        # failure to the caller instead of propagating.
        logger.exception(f"Model loading test failed: {e}")
        return False
def main():
    """Run the dataset and model-loading checks and log an overall verdict.

    Both checks always run, even when the first one fails, so a single
    invocation reports every problem.

    Returns:
        bool: True only if both checks returned a truthy result.
    """
    logger.info("Starting enhanced model tests...")

    # Test 1: Enhanced dataset validation
    logger.info("\n=== Test 1: Enhanced Dataset Validation ===")
    dataset_ok = test_enhanced_model()

    # Test 2: Model loading
    logger.info("\n=== Test 2: Model Loading Test ===")
    loading_ok = test_model_loading()

    # Summary
    logger.info("\n=== Test Summary ===")
    if not (dataset_ok and loading_ok):
        logger.error("✗ Some tests failed. Check the issues above.")
        return False

    logger.info("✓ All tests passed! Ready for enhanced training.")
    return True
if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 otherwise. SystemExit is
    # raised directly because the bare `exit()` helper is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    success = main()
    raise SystemExit(0 if success else 1)