# Spaces:
# Running
# Running
# train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)
import argparse
import os

import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from diffusers import (
    BitsAndBytesConfig,
    DDPMScheduler,
    DreamBoothLoraTrainer,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from peft import LoraConfig
# Command-line interface: the only option is the training image directory.
parser = argparse.ArgumentParser(
    description="QLoRA + DeepSpeed DreamBooth fine-tuning (Stable Diffusion).",
)
parser.add_argument(
    "--data",
    default="./nyc_ads_dataset",
    help="Directory containing your training images.",
)
args = parser.parse_args()
# LoRA configuration (noted by the original author as QLoRA-compatible).
# NOTE(review): target_modules must name submodules that actually exist in the
# model the adapter is attached to. "q_proj"/"v_proj" look like text-encoder
# style names; diffusers UNet attention projections are usually "to_q"/"to_v".
# Confirm against the model the trainer wraps before relying on these values.
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
# Load SD-1.5 with 4-bit quantization.
# The pipeline loader does not take `load_in_4bit=` or a plain dict as
# `quantization_config`; bitsandbytes quantization is configured per model.
# So the UNet is loaded in 4-bit NF4 with a BitsAndBytesConfig (same settings
# as before: nf4, double quantization, fp16 compute) and then injected into the
# pipeline, while the remaining components load in fp16.
sd15_repo = "runwayml/stable-diffusion-v1-5"
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
unet_4bit = UNet2DConditionModel.from_pretrained(
    sd15_repo,
    subfolder="unet",
    quantization_config=bnb_cfg,
    torch_dtype=torch.float16,
)
pipe = StableDiffusionPipeline.from_pretrained(
    sd15_repo,
    unet=unet_4bit,
    torch_dtype=torch.float16,
)
# DreamBooth LoRA trainer.
# NOTE(review): confirm that `DreamBoothLoraTrainer` is actually exported by
# the installed diffusers version; the import will fail otherwise and this
# call would need to be replaced by a real training loop or script.
# NOTE(review): the quantized `pipe` loaded earlier in this script is not
# passed to the trainer; confirm how the trainer obtains its base model.
trainer = DreamBoothLoraTrainer(
    instance_data_root=args.data,
    instance_prompt="a photo of an urbanad nyc",
    lora_config=lora_cfg,
    output_dir="./nyc-ad-model",
    max_train_steps=400,
    train_batch_size=1,
    gradient_checkpointing=True,
)
# DeepSpeed ZeRO-3 acceleration / GPU-memory partitioning.
# Accelerator has no `deepspeed_config` keyword; DeepSpeed settings are passed
# through a DeepSpeedPlugin. The JSON config file must already exist at this
# path before the script is launched.
accelerator = Accelerator(
    mixed_precision="fp16",
    deepspeed_plugin=DeepSpeedPlugin(hf_ds_config="./ds_config_zero3.json"),
)
# Start training, handing the configured accelerator to the trainer.
trainer.train(accelerator)