from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces
import re
from peft import PeftModel

# Load the base model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="kernels-community/vllm-flash-attention3"
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    
    # Load the LoRA adapter
    try:
        model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
        print("✅ LoRA model loaded successfully!")
    except Exception as lora_error:
        print(f"⚠️ LoRA adapter failed to load: {lora_error}")
        print("🔄 Falling back to base model...")
        model = base_model
        
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise e
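
# Optional (a sketch, left commented out): the LoRA weights could be folded into the
# base model with PeftModel.merge_and_unload() to drop the adapter indirection at
# inference time. Whether that helps latency here is an assumption, not measured.
# model = model.merge_and_unload()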

def format_conversation_history(chat_history):
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            content = content[0]["text"] if content and "text" in content[0] else str(content)
        messages.append({"role": role, "content": content})
    return messages
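
# Illustrative sketch (not executed): the Gradio "messages" history normalized above.
# Content may arrive as a plain string or, for rich inputs, as a list of parts with a
# "text" field; the shapes below are assumed examples, not captured from a real run.
#     input:  [{"role": "user", "content": [{"text": "Hi"}]},
#              {"role": "assistant", "content": "Hello!"}]
#     output: [{"role": "user", "content": "Hi"},
#              {"role": "assistant", "content": "Hello!"}]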

def create_harmony_prompt(messages, reasoning_level="medium"):
    """
    Create a proper Harmony format prompt for GPT-OSS-20B
    Based on the Harmony format from https://github.com/openai/harmony
    """
    # Start with system message in Harmony format
    system_content = f"""You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-01-28

Reasoning: {reasoning_level}

# Valid channels: analysis, commentary, final. Channel must be included for every message."""
    
    # Build the prompt in Harmony format
    prompt_parts = []
    
    # Add system message
    prompt_parts.append(f"<|start|>system<|message|>{system_content}<|end|>")
    
    # Add conversation messages
    for message in messages:
        role = message["role"]
        content = message["content"]
        
        if role == "system":
            # Skip system messages as we already added the main one
            continue
        elif role == "user":
            prompt_parts.append(f"<|start|>user<|message|>{content}<|end|>")
        elif role == "assistant":
            prompt_parts.append(f"<|start|>assistant<|message|>{content}<|end|>")
    
    # Add the generation prompt
    prompt_parts.append("<|start|>assistant")
    
    return "\n".join(prompt_parts)

@spaces.GPU(duration=60)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    new_message = {"role": "user", "content": input_data}
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history + [new_message]
    
    # Extract reasoning level from system prompt
    reasoning_level = "medium"
    if "reasoning:" in system_prompt.lower():
        if "high" in system_prompt.lower():
            reasoning_level = "high"
        elif "low" in system_prompt.lower():
            reasoning_level = "low"
    
    # Create Harmony format prompt
    prompt = create_harmony_prompt(messages, reasoning_level)
    
    # Create streamer for proper streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Prepare generation kwargs
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "use_cache": True
    }
    
    # Tokenize input using the Harmony format
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
    thread.start()
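
    # For reference, a hand-written example (an assumption based on the Harmony
    # format, not captured from the model) of the continuation the loop below parses:
    #     <|channel|>analysis<|message|>Thinking step by step...<|end|>
    #     <|start|>assistant<|channel|>final<|message|>Here is the answer.<|end|>
    # analysis/commentary feed the collapsible thinking trace; final is the reply.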
    
    # Stream the response and parse the Harmony channels as they arrive.
    # The continuation is expected to contain segments of the form
    # <|channel|>NAME<|message|>CONTENT<|end|>.
    current_content = ""

    for chunk in streamer:
        current_content += chunk

        # Re-parse the accumulated text from scratch on every chunk so that
        # previously seen content is never appended twice.
        thinking = ""
        final = ""
        for segment in current_content.split("<|channel|>")[1:]:
            marker_pos = segment.find("<|message|>")
            if marker_pos == -1:
                continue
            channel = segment[:marker_pos].strip()
            content = segment[marker_pos + len("<|message|>"):].split("<|end|>")[0]
            if channel.startswith("analysis") or channel.startswith("commentary"):
                thinking += content
            elif channel.startswith("final"):
                final += content

        # If the streamer strips the special tokens (skip_special_tokens=True),
        # no channel markers are visible; fall back to showing the raw text.
        if "<|channel|>" not in current_content:
            final = current_content

        clean_thinking = thinking.strip()
        clean_final = final.strip()

        # Format for display: collapsible thinking trace followed by the answer
        if clean_thinking or clean_final:
            formatted = f"<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
            yield formatted

demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant. Reasoning: medium",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
    ],
    examples=[
        [{"text": "Explain Newton laws clearly and concisely"}],
        [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
        [{"text": "What are the benefits of open weight AI models"}],
    ],
    cache_examples=False,
    type="messages",
    description="""
# 🙋🏻‍♂️Welcome to 🌟Tonic's gpt-oss-20b Multilingual Reasoner Demo !
Please wait a few seconds for the model to load initially. You can adjust the reasoning level in the system prompt, e.g. "Reasoning: high".
This version uses the proper Harmony format for better generation quality.
    """,
    fill_height=True,
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt"
    ),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    demo.launch(share=True)