# flux_train.py
"""Build and launch an ai-toolkit LoRA training job for FLUX.1-dev."""
import os
import sys
from collections import OrderedDict

from huggingface_hub import whoami

# ai-toolkit is imported from a source checkout rather than an installed
# package, so its directory must be on sys.path before importing `toolkit`.
# NOTE(review): assumes the checkout lives at /app/ai-toolkit — confirm per deployment.
sys.path.append("/app/ai-toolkit")

from toolkit.job import run_job  # noqa: E402  (requires the sys.path entry above)


def _get_hf_username() -> str:
    """Return the logged-in Hugging Face username.

    Raises:
        RuntimeError: If ``whoami()`` fails or its result has no ``"name"``;
            the underlying exception is chained as ``__cause__``.
    """
    try:
        return whoami()["name"]
    except Exception as err:
        # Chain the original error so the real cause (missing token,
        # network failure, ...) stays visible in the traceback.
        raise RuntimeError(
            "Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?"
        ) from err


def update_config_push_to_hub(config, push_to_hub: bool, slugged_lora_name: str):
    """Fill in the Hugging Face Hub upload options of a job config, in place.

    Always writes ``push_to_hub`` into the ``save`` section of the first
    process. When pushing is enabled it also writes ``hf_repo_id`` as
    ``"<username>/<slugged_lora_name>"`` and sets ``hf_private`` to True.

    Args:
        config: Job config as produced by ``build_job``; must contain
            ``config["config"]["process"][0]["save"]``.
        push_to_hub: Whether the trainer should upload saved checkpoints.
        slugged_lora_name: Repository name used in ``hf_repo_id``.

    Raises:
        ValueError: If ``push_to_hub`` is true and ``slugged_lora_name`` is
            empty or contains ``/`` (it would not form a valid repo id).
        RuntimeError: If the Hugging Face username cannot be retrieved.
    """
    save_cfg = config["config"]["process"][0]["save"]

    if not push_to_hub:
        save_cfg["push_to_hub"] = push_to_hub
        return

    # Fail fast: an invalid repo name would otherwise only surface when the
    # trainer tries to upload, i.e. after training has already run.
    if not slugged_lora_name or "/" in slugged_lora_name:
        raise ValueError(
            f"Invalid Hugging Face repo name {slugged_lora_name!r}: "
            "must be non-empty and must not contain '/'"
        )

    # Resolve the username before touching the config so a failure does not
    # leave it half-updated (push_to_hub set but no repo id).
    username = _get_hf_username()
    save_cfg["push_to_hub"] = push_to_hub
    save_cfg["hf_repo_id"] = f"{username}/{slugged_lora_name}"
    save_cfg["hf_private"] = True


def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Build the ai-toolkit ``extension`` job config for a FLUX.1-dev LoRA run.

    The config has a single ``sd_trainer`` process. It trains a LoRA on
    ``black-forest-labs/FLUX.1-dev`` for 2000 steps, saving and sampling
    every 250 steps.

    Args:
        concept: Text used as the single sample prompt.
        training_path: Dataset folder. Captions are read from ``.txt`` files,
            per ``caption_ext``.
        lora_name: Job/LoRA name. It is lower-cased, with spaces replaced by
            underscores, to form the Hub repo name.
        push_to_hub: Whether to upload saved checkpoints to a private Hub repo.

    Returns:
        OrderedDict: The job config, with its ``save`` section completed by
        ``update_config_push_to_hub``.

    Raises:
        ValueError, RuntimeError: Propagated from ``update_config_push_to_hub``
            when ``push_to_hub`` is true.
    """
    slugged_lora_name = lora_name.lower().replace(" ", "_")

    network = OrderedDict([
        ("type", "lora"),
        ("linear", 16),
        ("linear_alpha", 16),
    ])

    save = OrderedDict([
        ("dtype", "float16"),
        ("save_every", 250),
        ("max_step_saves_to_keep", 4),
        # push_to_hub keys added later by update_config_push_to_hub()
    ])

    datasets = [
        OrderedDict([
            ("folder_path", training_path),
            ("caption_ext", "txt"),
            ("caption_dropout_rate", 0.05),
            ("shuffle_tokens", False),
            ("cache_latents_to_disk", True),
            ("resolution", [512, 768, 1024]),
        ]),
    ]

    train = OrderedDict([
        ("batch_size", 1),
        ("steps", 2000),
        ("gradient_accumulation_steps", 1),
        ("train_unet", True),
        ("train_text_encoder", False),
        ("content_or_style", "balanced"),
        ("gradient_checkpointing", True),
        ("noise_scheduler", "flowmatch"),
        ("optimizer", "adamw8bit"),
        ("lr", 1e-4),
        ("ema_config", OrderedDict([
            ("use_ema", True),
            ("ema_decay", 0.99),
        ])),
        ("dtype", "bf16"),
    ])

    model = OrderedDict([
        ("name_or_path", "black-forest-labs/FLUX.1-dev"),
        ("is_flux", True),
        ("quantize", True),
    ])

    sample = OrderedDict([
        ("sampler", "flowmatch"),
        ("sample_every", 250),
        ("width", 1024),
        ("height", 1024),
        ("prompts", [concept]),
        ("neg", ""),
        ("seed", 42),
        ("walk_seed", True),
        ("guidance_scale", 4),
        ("sample_steps", 20),
    ])

    process = OrderedDict([
        ("type", "sd_trainer"),
        ("training_folder", "/tmp/output"),
        ("device", "cuda:0"),
        ("network", network),
        ("save", save),
        ("datasets", datasets),
        ("train", train),
        ("model", model),
        ("sample", sample),
    ])

    job = OrderedDict([
        ("job", "extension"),
        ("config", OrderedDict([
            ("name", lora_name),
            ("process", [process]),
        ])),
        ("meta", OrderedDict([
            ("name", lora_name),
            ("version", "1.0"),
        ])),
    ])

    # Add push to Hub config if requested
    update_config_push_to_hub(job, push_to_hub, slugged_lora_name)
    return job


def run_training(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Build the FLUX LoRA job config and hand it to ai-toolkit's ``run_job``.

    Arguments are forwarded unchanged to ``build_job``. Exceptions from
    building the config or from ``run_job`` propagate to the caller.
    """
    job = build_job(concept, training_path, lora_name, push_to_hub)
    run_job(job)