"""
Test the enhanced model specifically for Iain Morris style improvements
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_enhanced_style():
    """Test the enhanced model for specific Iain Morris style elements"""
    
    # Load the enhanced model
    logger.info("Loading enhanced model...")
    
    base_model_name = "HuggingFaceH4/zephyr-7b-beta"
    model_path = "models/iain-morris-model-enhanced"
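    # NOTE: model_path assumes a prior enhanced fine-tuning run saved a LoRA
    # adapter to models/iain-morris-model-enhanced; adjust to your own layout.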
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Pick the target device: prefer CUDA, then Apple Silicon (MPS), then CPU
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

    # Load base model; device_map="auto" only applies to CUDA, so MPS/CPU get an explicit move
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16 if device != "cpu" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True
    )
    if device != "cuda":
        base_model = base_model.to(device)
    
    # Load the fine-tuned LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()
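    # Optional: merging the adapter into the base weights via
    # model = model.merge_and_unload() can speed up generation; left commented
    # out here since this test only needs the wrapped PeftModel.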
    
    # Test prompts designed to trigger Iain Morris style
    test_prompts = [
        "Write about the disaster of modern smartphone launches",
        "Describe the catastrophe of online shopping during Black Friday", 
        "Write about the train wreck of modern customer service",
        "Discuss the collision course of social media and democracy",
        "Write about cryptocurrency - what could possibly go wrong?",
        "Describe the kiss of death for traditional retail",
        "Write about the explosion of subscription services"
    ]
    
    # System prompt (should match the one used during the enhanced fine-tuning run)
    system_prompt = """You are Iain Morris, a razor-sharp British writer with zero tolerance for BS. Your writing style is distinctive for:

PROVOCATIVE DOOM-LADEN OPENINGS:
- Always lead with conflict, failure, or impending disaster
- Use visceral, dramatic scenarios that grab readers by the throat
- Frame mundane topics as battles, collisions, or catastrophes
- Open with vivid imagery that establishes immediate tension

SIGNATURE DARK ANALOGIES:
- Compare situations to train wrecks, explosions, collisions
- Use physical, visceral metaphors for abstract problems
- Reference pop culture disasters and failures
- Turn simple concepts into dramatic, often dark imagery

CYNICAL WIT & EXPERTISE:
- Deliver insights with biting sarcasm and parenthetical snark
- Assume readers are intelligent but skeptical
- Quote figures, then immediately undercut them
- Use technical knowledge as a weapon of wit

DISTINCTIVE PHRASES:
- "What could possibly go wrong?"
- "kiss of death," "train wreck," "collision course"
- Parenthetical asides for extra snark
- British expressions and dry humor

Write with the assumption that everything is either failing, about to fail, or succeeding despite obvious flaws."""

    logger.info("Testing enhanced style elements...")
    
    for i, prompt in enumerate(test_prompts, 1):
        logger.info(f"\n--- Test {i}: {prompt} ---")
        
        # Format the conversation
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]
        
        # Apply chat template
        formatted_prompt = tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
        )
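        # For zephyr-7b-beta, the chat template renders messages as <|system|>,
        # <|user|>, and <|assistant|> blocks, which is why the reply is later
        # extracted by splitting on "<|assistant|>".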
        
        # Tokenize and move the inputs onto the same device as the model
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        
        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        
        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract just the assistant's response
        if "<|assistant|>" in response:
            response = response.split("<|assistant|>")[-1].strip()
        
        print(f"\nResponse:\n{response[:500]}...")
        
        # Check for style elements
        style_elements = {
            "doom_opening": any(word in response.lower()[:100] for word in ["disaster", "catastrophe", "collapse", "meltdown", "crisis", "nightmare"]),
            "dark_analogies": any(phrase in response.lower() for phrase in ["train wreck", "collision", "explosion", "crash", "burning", "sinking"]),
            "signature_phrase": "what could possibly go wrong" in response.lower(),
            "parenthetical_snark": "(" in response and ")" in response,
            "cynical_tone": any(word in response.lower() for word in ["irony", "meanwhile", "of course", "naturally", "predictably"])
        }
        
        print(f"\nStyle Analysis:")
        for element, found in style_elements.items():
            status = "✓" if found else "✗"
            print(f"  {status} {element.replace('_', ' ').title()}")
        
        print("-" * 60)

if __name__ == "__main__":
    test_enhanced_style()
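
# Example invocation (assuming this script is saved as test_enhanced_style.py
# and the adapter directory above exists):
#   python test_enhanced_style.py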