| """ | |
| TinyLlama Mental Health Fine-Tuning Comparison | |
| Compare base model vs LoRA fine-tuned model responses | |
| """ | |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# ============================================
# CONFIGURATION
# ============================================
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Device: {DEVICE}")

# ============================================
# LOAD TOKENIZER
# ============================================
print("📦 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
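# Note: LLaMA-family tokenizers ship without a dedicated pad token, so EOS
# is reused for padding above (standard practice for causal LMs).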

# ============================================
# LOAD BASE MODEL
# ============================================
print("📦 Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True
)
base_model = base_model.to(DEVICE)
base_model.eval()

# ============================================
# LOAD LORA MODEL
# ============================================
print("📦 Loading LoRA model...")
ft_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True
)
# Attach the LoRA adapter weights stored in this repo's root directory (".")
ft_model = PeftModel.from_pretrained(ft_model, ".")
ft_model = ft_model.to(DEVICE)
ft_model.eval()
print("✅ Both models ready!")

# ============================================
# GENERATION FUNCTION
# ============================================
def generate_response(prompt, use_finetuning, temperature):
    """
    Generate a response using either the base or the fine-tuned model.

    Args:
        prompt: User question
        use_finetuning: True = with LoRA adapter, False = base model only
        temperature: Sampling temperature for generation

    Returns:
        Generated response text
    """
    # Select which model to use
    model = ft_model if use_finetuning else base_model
    print(f"{'🟢 Using FINE-TUNED model' if use_finetuning else '🔴 Using BASE model'}")
    print(f"🌡️ Temperature: {temperature}")

    # Format the prompt with chat-style tags
    full_text = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
    # Tokenize
    inputs = tokenizer(full_text, return_tensors="pt").to(DEVICE)

    # Generate
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode, then keep only the assistant's part of the transcript
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("<|assistant|>")[-1].strip()
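        # Equivalent sketch: slice off the prompt tokens before decoding,
        # which also works if the tag text ever appears in the user's input:
        # new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        # response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()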
        return response
    except Exception as e:
        print(f"❌ Error: {e}")
        return f"Error: {str(e)}"

# ============================================
# COMPARISON FUNCTION
# ============================================
def compare_models(question, temperature):
    """
    Generate responses from both models for side-by-side comparison.

    Args:
        question: User input question
        temperature: Sampling temperature passed to both models

    Returns:
        Tuple of (base_response, finetuned_response)
    """
    if not question.strip():
        return "⚠️ Please enter a question first", "⚠️ Please enter a question first"

    # Response from the BASE model (no fine-tuning)
    base_response = generate_response(question, use_finetuning=False, temperature=temperature)

    # Response from the FINE-TUNED model (with LoRA)
    ft_response = generate_response(question, use_finetuning=True, temperature=temperature)

    return base_response, ft_response

# ============================================
# GRADIO INTERFACE
# ============================================
with gr.Blocks(title="TinyLlama: Base vs Fine-Tuned", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 TinyLlama: Base vs Fine-Tuned Comparison

    Compare responses from the **original model** vs the **LoRA fine-tuned version**
    trained on mental health conversations.

    The fine-tuned model has been trained to provide more empathetic and helpful
    responses to mental health-related questions.
    """)

    with gr.Row():
        input_text = gr.Textbox(
            label="Your Question",
            placeholder="Example: I'm feeling very anxious lately, what can I do?",
            lines=3
        )

    with gr.Row():
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="🌡️ Temperature",
            info="Lower = more focused, Higher = more creative"
        )

    btn = gr.Button("🚀 Generate Responses", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📦 Base Model (No Fine-Tuning)")
            gr.Markdown("*Original TinyLlama without any mental health training*")
            output_base = gr.Textbox(
                label="Response",
                lines=10,
                show_copy_button=True
            )
        with gr.Column():
            gr.Markdown("### ✨ Fine-Tuned Model (With LoRA)")
            gr.Markdown("*Trained on a mental health counseling dataset*")
            output_ft = gr.Textbox(
                label="Response",
                lines=10,
                show_copy_button=True
            )

    # Connect the button to the comparison function
    btn.click(
        fn=compare_models,
        inputs=[input_text, temperature_slider],
        outputs=[output_base, output_ft]
    )

    # Example questions
    gr.Examples(
        examples=[
            ["I'm feeling worthless and don't know what to do"],
            ["How can I deal with anxiety?"],
            ["I can't sleep at night because of my thoughts"],
            ["I feel like everyone would be better off without me"],
            ["What are some ways to manage stress?"],
        ],
        inputs=input_text,
        label="💡 Try these examples"
    )
| gr.Markdown(""" | |
| --- | |
| ### βΉοΈ About This Demo | |
| - **Model**: TinyLlama-1.1B-Chat-v1.0 | |
| - **Fine-Tuning**: LoRA (Low-Rank Adaptation) | |
| - **Dataset**: Mental health counseling conversations | |
| - **Parameters Trained**: ~1.1% of total model (LoRA adapters only) | |
| """) | |

# ============================================
# LAUNCH APP
# ============================================
if __name__ == "__main__":
    demo.launch()
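    # For local runs outside a Hugging Face Space, a temporary public link
    # can also be created with demo.launch(share=True), a standard Gradio
    # option that is not needed when the app is hosted on a Space.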