# rl_training/app.py
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
import wandb

# Authenticate with Weights & Biases via the WANDB_API_KEY environment
# variable instead of a hardcoded key (never commit API keys to a repo).
wandb.login()

# Force-disable flash attention so the script also runs on hardware
# without flash-attn support.
os.environ["FLASH_ATTENTION_FORCE_DISABLED"] = "1"

# Load the TL;DR summarization dataset used for GRPO training.
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"

# Load the base model with automatic dtype selection and device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Attach LoRA adapters so only a small fraction of parameters is trained
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints the summary itself; wrapping it in print() would just print None
# Reward function: favor completions close to the ideal length,
# penalizing the deviation linearly (length measured in characters)
ideal_length = 50

def reward_len(completions, **kwargs):
    return [-abs(ideal_length - len(completion)) for completion in completions]
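# Sanity check of the reward shaping (illustrative, not in the original
# script): a 50-character completion scores 0, a 60-character one is
# penalized by 10.
assert reward_len(["x" * 50]) == [0]
assert reward_len(["x" * 60]) == [-10]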
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_prompt_length=512,
    max_completion_length=96,
    num_generations=8,
    num_train_epochs=1,
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
    bf16=False,
    fp16=True,  # mixed precision; requires a GPU
    optim="adamw_torch_fused",  # NOT "adamw_8bit", which needs bitsandbytes
)
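# Sanity check (single-process assumption, not in the original script):
# GRPO groups num_generations completions per prompt, so the per-device
# batch should divide evenly into those groups.
assert training_args.per_device_train_batch_size % training_args.num_generations == 0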
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_len],
    args=training_args,
    train_dataset=dataset["train"],
)
# Train model, logging metrics to the "GRPO" W&B project
wandb.init(project="GRPO")
trainer.train()
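# Persist the trained LoRA adapter and close out the W&B run
# ("GRPO/final" is an illustrative output path, not from the original script).
trainer.save_model("GRPO/final")
wandb.finish()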