File size: 5,161 Bytes
93eecf2
8a5680c
5c5ab7e
3fc214f
9e28e73
 
 
 
 
 
5c5ab7e
 
236ed28
 
9e28e73
 
5c5ab7e
9e28e73
5c5ab7e
9e28e73
3fc214f
9e28e73
 
 
8a5680c
fc3d288
3fc214f
5c5ab7e
 
 
 
 
 
 
 
 
 
3fc214f
5c5ab7e
 
9e28e73
 
 
 
 
 
5c5ab7e
9e28e73
 
5c5ab7e
9e28e73
 
5c5ab7e
 
 
 
 
9e28e73
5c5ab7e
 
9e28e73
3fc214f
9e28e73
3fc214f
 
5c5ab7e
3fc214f
9e28e73
5c5ab7e
9e28e73
 
5c5ab7e
9e28e73
 
 
35711c6
5c5ab7e
9e28e73
5c5ab7e
 
9e28e73
 
 
5c5ab7e
 
9e28e73
35711c6
5c5ab7e
3fc214f
9e28e73
43a1e87
5c5ab7e
9e28e73
5c5ab7e
 
9e28e73
 
5c5ab7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e28e73
5c5ab7e
 
9c8cf1a
9e28e73
5c5ab7e
9e28e73
5c5ab7e
 
 
 
fc3d288
9e28e73
5c5ab7e
 
 
9e28e73
 
5c5ab7e
 
9e28e73
 
 
5c5ab7e
 
9e28e73
 
5c5ab7e
 
35711c6
9e28e73
5c5ab7e
9e28e73
 
 
93eecf2
9e28e73
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os

# ---------------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------------
# Everything in this app runs on CPU, so model size directly drives RAM use and
# latency. The 1B base is the lightest Llama 3.2 variant; larger bases (3B, 8B)
# will be substantially slower and more likely to exhaust memory.
# NOTE(review): this base checkpoint is pre-quantized with bitsandbytes 4-bit;
# confirm the installed bitsandbytes build can load it on CPU. If it cannot, a
# non-quantized 1B base (e.g. "unsloth/Llama-3.2-1B-Instruct") is the usual
# fallback -- the LoRA adapter below must still match the base architecture.
BASE_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
LORA_ADAPTER_ID = "JPQ24/Natural-synthesis-llama-3.2-1b"

# ---------------------------------------------------------------------------
# LOAD MODEL (State Initialization)
# ---------------------------------------------------------------------------
# Loaded once at import time; every request handled below reuses these objects.
print("System: Initializing CPU Load Sequence...")

# Read the optional Hugging Face access token once and reuse it for every hub
# download. None means anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN")

# 1. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)

# 2. Load Base Model
print("System: Loading Base Model into RAM...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="cpu",
    # float32 is the safe default on CPU; torch.bfloat16 halves weight memory
    # and can be faster if the host CPU supports it.
    torch_dtype=torch.float32,
    # Keeps peak RAM during loading close to one copy of the model instead of
    # first allocating a full randomly-initialised model.
    low_cpu_mem_usage=True,
    # NOTE(review): Llama is a built-in transformers architecture, so remote
    # code should not be required; this flag permits executing Python shipped
    # in the hub repo. Consider removing it.
    trust_remote_code=True,
    token=HF_TOKEN,
)

# 3. Attach LoRA Adapter
# `model` wraps base_model with the adapter active by default; run_inference
# bypasses it per call via model.disable_adapter() for the base-model run.
print("System: Attaching LoRA Adapter...")
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_ID, token=HF_TOKEN)

print("System: Ready.")

# ---------------------------------------------------------------------------
# EXECUTION ENGINE
# ---------------------------------------------------------------------------

def run_inference(prompt, use_lora, max_new_tokens=100, temperature=0.7):
    """Generate one chat reply for *prompt*, with or without the LoRA adapter.

    Args:
        prompt: User message text. It is wrapped in a single-turn chat when the
            tokenizer's chat template can be applied; otherwise it is tokenized
            as raw text.
        use_lora: If False, the adapter is bypassed for this call via
            ``model.disable_adapter()`` so only the base weights are used.
        max_new_tokens: Upper bound on generated tokens. Small by default
            because CPU generation is slow and long runs risk timeouts.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded continuation only (prompt tokens are sliced off), with
        special tokens removed.
    """
    # 1. Input processing
    messages = [{"role": "user", "content": prompt}]
    try:
        encoded = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
    except Exception:
        # e.g. the tokenizer has no chat template: fall back to raw text.
        encoded = tokenizer(prompt, return_tensors="pt")

    # Depending on the code path / transformers version the result is either a
    # bare tensor of token ids or a dict-like BatchEncoding; normalise it.
    input_ids = encoded["input_ids"] if hasattr(encoded, "keys") else encoded
    input_ids = input_ids.to("cpu")

    # A single unpadded sequence, so every position is real input. Passing the
    # mask and a pad id explicitly stops generate() from guessing them; the
    # tokenizer may define no pad token, in which case EOS is used instead.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    # 2. Generation config (conservative for CPU)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    # 3. Execution: the adapter is active by default; bypass it for base runs.
    if use_lora:
        outputs = model.generate(**generate_kwargs)
    else:
        with model.disable_adapter():
            outputs = model.generate(**generate_kwargs)

    # 4. Output decoding: generate() returns prompt + continuation; keep only
    # the newly generated tokens.
    prompt_len = input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

def sequential_generation(prompt):
    """Run the base model, then the LoRA model, streaming progress to the UI.

    Each ``yield`` is one update to the two output boxes, so the base answer is
    shown as soon as it is ready while the LoRA answer is still generating.

    Args:
        prompt: Text from the prompt box (may be empty or None).

    Yields:
        Tuples ``(base_text, lora_text)``: status messages first, then the two
        generated responses.
    """
    # Guard: an empty prompt would still cost a full CPU generation per model.
    if not prompt or not prompt.strip():
        yield "Please enter a prompt.", ""
        return

    # --- PHASE 1: BASE MODEL ---
    yield "Generating Base Model response... (Please wait)", "Waiting for Base to finish..."
    base_result = run_inference(prompt, use_lora=False)

    # --- INTERMEDIATE: CLEANUP ---
    # Best-effort nudge to reclaim Python objects left over from the first run
    # before starting the second. (transformers' generate() runs without
    # gradient tracking, so there is no autograd graph to free here.)
    gc.collect()

    # --- PHASE 2: LORA MODEL ---
    # Show the base result immediately so the user can read it meanwhile.
    yield base_result, "Generating LoRA response... (Please wait)"
    lora_result = run_inference(prompt, use_lora=True)

    # --- FINAL: COMPLETE ---
    yield base_result, lora_result

# ---------------------------------------------------------------------------
# INTERFACE
# ---------------------------------------------------------------------------
# NOTE(review): neither class below is attached to a component via
# `elem_classes=` in this layout, so these rules currently have no effect.
custom_css = (
    "\n"
    ".container { max-width: 1100px; margin: auto; }\n"
    ".output-box { height: 400px; overflow-y: scroll; }\n"
)

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header: title, the two checkpoints being compared, and the run order.
    gr.Markdown("# Sequential Model Comparison (CPU)")
    gr.Markdown(f"**Architecture:** Base (`{BASE_MODEL_ID}`) + Adapter (`{LORA_ADAPTER_ID}`)")
    gr.Markdown("ℹ️ **Process:** This space runs the Base Model first, clears memory, and then runs the LoRA Model.")

    # Input row: prompt box next to the button that kicks off both runs.
    with gr.Row():
        input_text = gr.Textbox(
            label="Prompt",
            placeholder="e.g. Write a poem about rust...",
            lines=2,
        )
        submit_btn = gr.Button("Start Comparison", variant="primary")

    # Output row: read-only results side by side, base on the left.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Base Model Output")
            output_base = gr.Textbox(
                label="Base Result",
                lines=10,
                interactive=False,
            )
        with gr.Column():
            gr.Markdown("### 2. LoRA Model Output")
            output_lora = gr.Textbox(
                label="Fine-Tuned Result",
                lines=10,
                interactive=False,
            )

    # sequential_generation is a generator: every tuple it yields is written
    # to (output_base, output_lora) in that order.
    submit_btn.click(
        sequential_generation,
        inputs=[input_text],
        outputs=[output_base, output_lora],
    )

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()