# train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)
"""DreamBooth LoRA fine-tuning of Stable Diffusion 1.5 with DeepSpeed.

The script parses a ``--data`` directory of instance images, builds a LoRA
configuration, loads the SD-1.5 pipeline with 4-bit settings, constructs a
DreamBooth LoRA trainer and runs it under an ``Accelerator`` configured from
``./ds_config_zero3.json``.

Run as a script; importing this module has no side effects.
"""
import argparse
import os

import torch
from accelerate import Accelerator, DeepSpeedPlugin
from diffusers import DDPMScheduler, StableDiffusionPipeline
# NOTE(review): `DreamBoothLoraTrainer` does not appear among the documented
# public exports of `diffusers`; DreamBooth LoRA training is normally done via
# the `examples/dreambooth/train_dreambooth_lora.py` script. Confirm where this
# class is expected to come from before relying on this file.
from diffusers import DreamBoothLoraTrainer
from peft import LoraConfig


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data``, the directory holding the
        training images.
    """
    parser = argparse.ArgumentParser()
    # Directory containing your training images.
    parser.add_argument("--data", default="./nyc_ads_dataset")
    return parser.parse_args(argv)


def main(argv=None):
    """Configure LoRA, load the base pipeline, and launch training.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # LoRA configuration (intended to be QLoRA-compatible).
    # NOTE(review): "q_proj"/"v_proj" are the attention projection names used
    # by the CLIP text encoder; the Stable Diffusion UNet names its attention
    # projections "to_q"/"to_k"/"to_v"/"to_out.0". Confirm which sub-model the
    # trainer adapts and adjust `target_modules` accordingly.
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # Load SD-1.5 with 4-bit quantization settings.
    # NOTE(review): `load_in_4bit` plus a plain-dict `quantization_config`
    # follows the transformers calling convention; diffusers documents
    # quantization through config objects instead. Verify these kwargs against
    # the installed diffusers version.
    # NOTE(review): `pipe` is not passed to the trainer below, so as written
    # the loaded pipeline is never used; confirm how the trainer is meant to
    # receive the quantized model.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        load_in_4bit=True,
        quantization_config={
            "bnb_4bit_compute_dtype": torch.float16,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
        },
    )

    # DreamBooth LoRA trainer.
    trainer = DreamBoothLoraTrainer(
        instance_data_root=args.data,
        instance_prompt="a photo of an urbanad nyc",
        lora_config=lora_cfg,
        output_dir="./nyc-ad-model",
        max_train_steps=400,
        train_batch_size=1,
        gradient_checkpointing=True,
    )

    # DeepSpeed ZeRO-3 for speed-up / memory sharding. `Accelerator` has no
    # `deepspeed_config` parameter; a DeepSpeed JSON file is supplied through
    # a `DeepSpeedPlugin`. The config file must be placed there beforehand.
    accelerator = Accelerator(
        mixed_precision="fp16",
        deepspeed_plugin=DeepSpeedPlugin(hf_ds_config="./ds_config_zero3.json"),
    )

    # Start training.
    trainer.train(accelerator)


if __name__ == "__main__":
    main()