Spaces:
Sleeping
Sleeping
""" | |
Test script for the enhanced Iain Morris model | |
Tests diverse topics to ensure the model works beyond telecom | |
""" | |
import json | |
import logging | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
import torch | |
# Configure the root logger at INFO level so every check below is visible
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Prompts spanning everyday (non-telecom) topics plus a couple of telecom ones,
# meant for exercising the fine-tuned model's generations.
# NOTE(review): test_enhanced_model does not send these to the model (they were
# an unused local there); kept at module level so a generation test can use them.
TEST_PROMPTS = [
    "Write about the absurdity of modern dating apps",
    "Describe the chaos of working from home",
    "Write about social media and its impact on society",
    "Discuss the wellness industry and its promises",
    "Write about the state of modern air travel",
    "Analyze the gig economy",
    "Write about student debt and higher education",
    "Discuss the housing market crisis",
    # Also test some telecom topics to ensure we didn't lose that capability
    "Write about 5G network rollout challenges",
    "Discuss the latest smartphone technology trends",
]

# Lower-case substrings that mark a training example as telecom-related.
_TELECOM_TERMS = ('telecom', '5g', 'network', 'operator', 'spectrum')


def _is_telecom_example(example):
    """Return True if the example's non-system text mentions a telecom term.

    Chat-style examples (a dict holding a ``messages`` list) are classified on
    their non-system messages only. The system prompt (presumably the same for
    every example) would otherwise be part of every example's text, so a single
    telecom keyword in it would put every example in the telecom bucket.
    Examples of any other shape are classified on ``str(example)``.
    """
    messages = example.get('messages') if isinstance(example, dict) else None
    if isinstance(messages, list):
        # Assumes chat messages look like {"role": ..., "content": ...}
        # (TODO confirm schema); messages without a "system" role are kept.
        text = ' '.join(
            str(message.get('content', '')) if isinstance(message, dict) else str(message)
            for message in messages
            if not (isinstance(message, dict) and message.get('role') == 'system')
        )
    else:
        text = str(example)
    text = text.lower()
    return any(term in text for term in _TELECOM_TERMS)


def test_enhanced_model():
    """Validate the enhanced training and validation datasets on disk.

    Steps:
    - Read ``data/enhanced_train_dataset.json`` and log its size.
    - Log a rough telecom / non-telecom split of its examples.
    - Preview the first example's system prompt and check it for the
      expected style markers.
    - Report on ``data/improved_val_dataset.json``.

    No model is loaded or queried here.

    Returns:
        True if the training dataset was read and inspected. False if:
        - the training dataset is missing or is not valid JSON;
        - its first example lacks a ``messages[0]['content']`` entry;
        - the validation dataset exists but is not valid JSON.
        Missing style markers only produce warnings.
    """
    # JSON files are UTF-8; don't depend on the platform's locale encoding.
    try:
        with open('data/enhanced_train_dataset.json', 'r', encoding='utf-8') as f:
            enhanced_data = json.load(f)
    except FileNotFoundError:
        logger.error("Enhanced training dataset not found!")
        return False
    except json.JSONDecodeError as e:
        logger.error(f"Enhanced training dataset is not valid JSON: {e}")
        return False

    logger.info(f"Enhanced dataset contains {len(enhanced_data)} examples")

    # Analyze the dataset composition
    telecom_count = sum(1 for example in enhanced_data if _is_telecom_example(example))
    non_telecom_count = len(enhanced_data) - telecom_count
    logger.info(
        f"Dataset composition: {telecom_count} telecom examples, "
        f"{non_telecom_count} non-telecom examples"
    )

    # Check system prompt (taken from the first message of the first example)
    if enhanced_data:
        try:
            system_prompt = enhanced_data[0]['messages'][0]['content']
        except (KeyError, IndexError, TypeError) as e:
            logger.error(f"Could not read the system prompt from the first example: {e!r}")
            return False

        logger.info("System prompt preview:")
        logger.info(system_prompt[:200] + "...")

        # Style markers the enhanced system prompt is expected to contain.
        improvements = [
            "PROVOCATIVE DOOM-LADEN OPENINGS",
            "SIGNATURE DARK ANALOGIES",
            "CYNICAL WIT & EXPERTISE",
            "What could possibly go wrong?",
        ]
        for improvement in improvements:
            if improvement in system_prompt:
                logger.info(f"β Found improvement: {improvement}")
            else:
                logger.warning(f"β Missing improvement: {improvement}")

    # The improved validation set is optional: a missing file is only a warning,
    # but a present-yet-corrupt file is treated as a failure.
    try:
        with open('data/improved_val_dataset.json', 'r', encoding='utf-8') as f:
            val_data = json.load(f)
    except FileNotFoundError:
        logger.warning("Improved validation dataset not found, using original")
    except json.JSONDecodeError as e:
        logger.error(f"Improved validation dataset is not valid JSON: {e}")
        return False
    else:
        logger.info(f"Improved validation dataset contains {len(val_data)} examples")

    logger.info("Enhanced model validation completed successfully!")
    return True
def _detect_device():
    """Pick the accelerator to report, preferring MPS, then CUDA, then CPU.

    Logs which option was found and returns "mps", "cuda" or "cpu".
    """
    # Older PyTorch builds have no torch.backends.mps; treat that as
    # "not available" instead of failing the whole loading check.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("β MPS (Apple Silicon) acceleration available")
        return "mps"
    if torch.cuda.is_available():
        logger.info("β CUDA acceleration available")
        return "cuda"
    logger.info("β Using CPU (no acceleration)")
    return "cpu"


def test_model_loading(model_name="HuggingFaceH4/zephyr-7b-beta", *, trust_remote_code=False):
    """Check that the base model's tokenizer loads and report the compute device.

    Only the tokenizer is fetched; the model weights are not loaded here.

    Args:
        model_name: Hugging Face Hub repo id of the base model.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
            It is off by default because enabling it lets the repository
            execute arbitrary Python code on this machine. Turn it on only
            for repos that need custom code and that you have reviewed.

    Returns:
        True if the tokenizer loaded and device detection ran. Otherwise
        False, with the exception and its traceback logged.
    """
    try:
        logger.info(f"Testing model loading: {model_name}")
        AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        logger.info("β Tokenizer loaded successfully")

        device = _detect_device()
        logger.info(f"Will use device: {device}")
        return True
    except Exception as e:
        # Boundary of this check: any failure is reported as a failed test,
        # with the traceback kept for diagnosis.
        logger.exception(f"Model loading test failed: {e}")
        return False
def main():
    """Run the dataset check and the model-loading check, then summarise.

    Both checks always run, even when the first one fails.

    Returns:
        True if every check succeeded, False otherwise.
    """
    logger.info("Starting enhanced model tests...")

    # (banner, check) pairs, executed in order.
    checks = (
        ("\n=== Test 1: Enhanced Dataset Validation ===", test_enhanced_model),
        ("\n=== Test 2: Model Loading Test ===", test_model_loading),
    )
    results = []
    for banner, check in checks:
        logger.info(banner)
        results.append(check())

    logger.info("\n=== Test Summary ===")
    if not all(results):
        logger.error("β Some tests failed. Check the issues above.")
        return False
    logger.info("β All tests passed! Ready for enhanced training.")
    return True
if __name__ == "__main__":
    success = main()
    # Raising SystemExit is what sys.exit() does; unlike the site-provided
    # exit() helper it is always available (e.g. under `python -S` or in
    # frozen executables). Exit status 0 means all checks passed.
    raise SystemExit(0 if success else 1)