import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os
# ---------------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------------
# WARNING: CPU inference is slow, so keep prompts and generation lengths short.
# NOTE: The base checkpoint below is bitsandbytes 4-bit quantized; if it fails to
# load on a CPU-only Space, switch to an unquantized checkpoint such as
# "unsloth/Llama-3.2-1B-Instruct".
BASE_MODEL_ID = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
LORA_ADAPTER_ID = "JPQ24/Natural-synthesis-llama-3.2-1b"
# ---------------------------------------------------------------------------
# LOAD MODEL (State Initialization)
# ---------------------------------------------------------------------------
print("System: Initializing CPU Load Sequence...")
# 1. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_ID,
    token=os.environ.get("HF_TOKEN")
)
# 2. Load Base Model
# low_cpu_mem_usage=True is critical here to load weights sequentially
print("System: Loading Base Model into RAM...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.float32,  # Safe default. Use torch.bfloat16 if your Space supports it for speed.
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN")
)
# 3. Attach LoRA Adapter
print("System: Attaching LoRA Adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_ADAPTER_ID,
    token=os.environ.get("HF_TOKEN")
)
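# Design note (added): the adapter is intentionally not merged into the base
# weights via model.merge_and_unload(), because run_inference() below toggles it
# per request with disable_adapter() to produce the base-model comparison.
model.eval()  # Explicit eval mode; from_pretrained already defaults to this, so it is only a safeguard.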
print("System: Ready.")
# ---------------------------------------------------------------------------
# EXECUTION ENGINE
# ---------------------------------------------------------------------------
def run_inference(prompt, use_lora):
    """
    Core computation unit: runs one generation pass with the LoRA adapter
    either enabled or disabled, depending on `use_lora`.
    """
    # 1. Input Processing
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    except Exception:
        # Fallback for tokenizers that ship without a chat template
        inputs = tokenizer(prompt, return_tensors="pt").input_ids
    # Ensure inputs are on CPU
    inputs = inputs.to("cpu")
    # 2. Generation Config (Conservative for CPU)
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=100,  # Keep short to prevent timeouts
        do_sample=True,
        temperature=0.7,
    )
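    # Added tweak (assumption: the tokenizer defines an eos token, as Llama 3.2
    # instruct checkpoints do): pass an attention mask and an explicit
    # pad_token_id so generate() does not warn about falling back to eos.
    generate_kwargs["attention_mask"] = torch.ones_like(inputs)
    generate_kwargs["pad_token_id"] = tokenizer.eos_token_id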
    # 3. Execution (With Context Switching)
    if not use_lora:
        # CONTEXT A: BASE MODEL
        # Temporarily disable the LoRA layers so only the base weights are used
        with model.disable_adapter():
            outputs = model.generate(**generate_kwargs)
    else:
        # CONTEXT B: LORA MODEL
        # Use the active adapter
        outputs = model.generate(**generate_kwargs)
    # 4. Output Decoding (skip the prompt tokens, decode only the new ones)
    response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return response
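# Example for local testing only (hypothetical prompt, not wired into the UI):
#   print(run_inference("Summarize LoRA in one sentence.", use_lora=True))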
def sequential_generation(prompt):
    """
    Orchestrator for sequential execution:
    Step 1 (base) -> Cleanup -> Step 2 (LoRA)
    """
    # --- PHASE 1: BASE MODEL ---
    yield "Generating Base Model response... (Please wait)", "Waiting for Base to finish..."
    base_result = run_inference(prompt, use_lora=False)
    # --- INTERMEDIATE: CLEANUP ---
    # Force a garbage collection between passes to release tensors and other
    # objects left over from the first generation before starting the second.
    gc.collect()
    # --- PHASE 2: LORA MODEL ---
    # Yield the first result so the user can read it while the second pass runs
    yield base_result, "Generating LoRA response... (Please wait)"
    lora_result = run_inference(prompt, use_lora=True)
    # --- FINAL: COMPLETE ---
    yield base_result, lora_result
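# Note (added): Gradio treats generator callbacks as streaming handlers, so each
# `yield` above pushes an intermediate update to both output textboxes.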
# ---------------------------------------------------------------------------
# INTERFACE
# ---------------------------------------------------------------------------
custom_css = """
.container { max-width: 1100px; margin: auto; }
.output-box { height: 400px; overflow-y: scroll; }
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Sequential Model Comparison (CPU)")
    gr.Markdown(f"**Architecture:** Base (`{BASE_MODEL_ID}`) + Adapter (`{LORA_ADAPTER_ID}`)")
    gr.Markdown("ℹ️ **Process:** This Space runs the Base Model first, clears memory, and then runs the LoRA Model.")
    with gr.Row():
        input_text = gr.Textbox(label="Prompt", placeholder="e.g. Write a poem about rust...", lines=2)
        submit_btn = gr.Button("Start Comparison", variant="primary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Base Model Output")
            output_base = gr.Textbox(label="Base Result", lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### 2. LoRA Model Output")
            output_lora = gr.Textbox(label="Fine-Tuned Result", lines=10, interactive=False)
    submit_btn.click(
        fn=sequential_generation,
        inputs=input_text,
        outputs=[output_base, output_lora]
    )
if __name__ == "__main__":
    demo.launch()
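# Assumed Space dependencies (a requirements.txt sketch, versions unpinned):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate     # needed for device_map / low_cpu_mem_usage
#   bitsandbytes   # the base checkpoint is bnb-4bit quantized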