# AdGPT / ppo_tune.py
# Author: goodmodeler - commit 696ae63 ("ADD: LLM techs") - 883 bytes
from trl import PPOTrainer, PPOConfig
from peft import PeftModel
import torch, random, json, glob
from diffusers import StableDiffusionPipeline
from reward_model import CLIPModel, CLIPProcessor
# --- Prompts ----------------------------------------------------------------
def _load_prompts(path):
    """Return the stripped, non-blank lines of the UTF-8 text file ``path``.

    Blank lines are skipped so an empty string is never sampled as a prompt.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file contains no non-blank line, because the
            training loop needs at least one prompt to sample from.
    """
    # `with` closes the handle deterministically (the bare open() leaked it).
    with open(path, encoding="utf-8") as fh:
        found = [text for text in (line.strip() for line in fh) if text]
    if not found:
        raise ValueError(f"no prompts found in {path!r}")
    return found


# Read prompts before loading any model so a missing/empty prompt file fails
# fast instead of after the (slow) checkpoint loads.
prompts = _load_prompts("prompt.txt")

# --- Models -----------------------------------------------------------------
# Reward model from the local "rm" checkpoint: eval mode, fp16, on the GPU.
rm = CLIPModel.from_pretrained("rm").eval().half().cuda()
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Stable Diffusion pipeline from the local fine-tuned checkpoint, fp16, on GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model", torch_dtype=torch.float16
).to("cuda")

# --- PPO --------------------------------------------------------------------
# NOTE(review): TRL's PPOTrainer is designed around language-model policies
# (token queries/responses); handing it a diffusion UNet is unlikely to work
# with any TRL release. TRL's DDPOTrainer targets diffusion models -- confirm
# against the installed trl version before relying on this setup.
ppo_cfg = PPOConfig(batch_size=1, learning_rate=1e-6, target_kl=0.2)
trainer = PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)
def _clip_reward(prompt, image):
    """Score one (prompt, image) pair with the reward model ``rm``.

    Returns the scalar at ``logits[0, 0]`` of ``rm``'s output as a Python float.
    """
    inputs = proc(text=prompt, images=image, return_tensors="pt").to("cuda")
    # rm's weights were cast with .half(), while the processor emits float32
    # pixels; cast explicitly instead of relying on the model to do it.
    inputs["pixel_values"] = inputs["pixel_values"].to(next(rm.parameters()).dtype)
    # Scoring only: building an autograd graph here would just waste memory.
    with torch.no_grad():
        # NOTE(review): `.logits` assumes reward_model.CLIPModel exposes that
        # field; transformers' stock CLIPModel returns `logits_per_image`.
        return rm(**inputs).logits[0, 0].item()


# The loop body was previously unindented, which made the script unparseable.
for step in range(500):
    p = random.choice(prompts)
    # NOTE(review): diffusers pipelines typically sample under torch.no_grad(),
    # so no log-probs/gradients reach pipe.unet from this call -- confirm.
    img = pipe(p, num_inference_steps=20).images[0]
    reward = _clip_reward(p, img)
    # NOTE(review): TRL's PPOTrainer.step takes (queries, responses, scores)
    # as lists of tensors, not prompts=/rewards= keywords -- verify this call
    # against the installed trl version.
    trainer.step(prompts=[p], rewards=[reward])
pipe.save_pretrained("nyc-ad-model-rlhf")