# flux_train.py
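"""Train a FLUX.1-dev LoRA with ai-toolkit.

Builds the ai-toolkit job config in code rather than from a YAML file, then
executes it with toolkit.job.run_job. Optionally pushes the trained LoRA to
the Hugging Face Hub as a private repo.
"""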
import os
from collections import OrderedDict
from huggingface_hub import whoami
import sys
sys.path.append("/app/ai-toolkit")  # make the ai-toolkit package importable
from toolkit.job import run_job


def update_config_push_to_hub(config, push_to_hub: bool, slugged_lora_name: str):
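    """Enable or disable pushing the trained LoRA to the Hugging Face Hub.

    When push_to_hub is True, the save step is pointed at a private repo
    named <username>/<slugged_lora_name> for the logged-in account.
    """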
config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
    if push_to_hub:
        try:
            username = whoami()["name"]
        except Exception as e:
            raise RuntimeError(
                "Error trying to retrieve your username. Are you sure you are logged in to Hugging Face?"
            ) from e
config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
config["config"]["process"][0]["save"]["hf_private"] = True
def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
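    """Build the ai-toolkit 'extension' job config for a FLUX.1-dev LoRA training run."""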
slugged_lora_name = lora_name.lower().replace(" ", "_")
job = OrderedDict([
('job', 'extension'),
('config', OrderedDict([
('name', lora_name),
('process', [
OrderedDict([
('type', 'sd_trainer'),
('training_folder', '/tmp/output'),
('device', 'cuda:0'),
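                    # LoRA adapter: rank (linear) 16 with alpha 16, i.e. 1.0 effective scaling.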
('network', OrderedDict([
('type', 'lora'),
('linear', 16),
('linear_alpha', 16)
])),
('save', OrderedDict([
('dtype', 'float16'),
('save_every', 250),
('max_step_saves_to_keep', 4),
# push_to_hub keys added later by update_config_push_to_hub()
])),
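                    # Dataset: images with sidecar .txt captions; latents cached to disk,
                    # bucketed at 512/768/1024 resolutions.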
('datasets', [
OrderedDict([
('folder_path', training_path),
('caption_ext', 'txt'),
('caption_dropout_rate', 0.05),
('shuffle_tokens', False),
('cache_latents_to_disk', True),
('resolution', [512, 768, 1024])
])
]),
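                    # Training: 2000 steps, bf16, 8-bit AdamW, flow-matching noise schedule;
                    # only the transformer ("unet") is trained, the text encoder stays frozen.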
('train', OrderedDict([
('batch_size', 1),
('steps', 2000),
('gradient_accumulation_steps', 1),
('train_unet', True),
('train_text_encoder', False),
('content_or_style', 'balanced'),
('gradient_checkpointing', True),
('noise_scheduler', 'flowmatch'),
('optimizer', 'adamw8bit'),
('lr', 1e-4),
('ema_config', OrderedDict([
('use_ema', True),
('ema_decay', 0.99)
])),
('dtype', 'bf16')
])),
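                    # Base model: FLUX.1-dev, quantized to reduce VRAM usage.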
('model', OrderedDict([
('name_or_path', 'black-forest-labs/FLUX.1-dev'),
('is_flux', True),
('quantize', True),
])),
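                    # Periodic sample images generated from the concept prompt.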
('sample', OrderedDict([
('sampler', 'flowmatch'),
('sample_every', 250),
('width', 1024),
('height', 1024),
('prompts', [concept]),
('neg', ''),
('seed', 42),
('walk_seed', True),
('guidance_scale', 4),
('sample_steps', 20)
]))
])
])
])),
('meta', OrderedDict([
('name', lora_name),
('version', '1.0')
]))
])
# Add push to Hub config if requested
update_config_push_to_hub(job, push_to_hub, slugged_lora_name)
return job
def run_training(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
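    """Build the job config and launch training with ai-toolkit."""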
job = build_job(concept, training_path, lora_name, push_to_hub)
run_job(job)
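

if __name__ == "__main__":
    # Minimal CLI entry-point sketch (not part of the original script): the
    # env var names below are assumptions chosen for illustration. Assumes you
    # are already logged in via `huggingface-cli login` if pushing to the Hub.
    run_training(
        concept=os.environ.get("CONCEPT", "ohamlab style"),
        training_path=os.environ.get("TRAINING_PATH", "/tmp/data"),
        lora_name=os.environ.get("LORA_NAME", "ohami_filter_autorun"),
        push_to_hub=os.environ.get("PUSH_TO_HUB", "0") == "1",
    )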