# flux_train.py
import os
from collections import OrderedDict
from huggingface_hub import whoami
import sys
sys.path.append("/app/ai-toolkit") # Tell Python to look here
from toolkit.job import run_job
from toolkit.job import run_job
def update_config_push_to_hub(config, push_to_hub: bool, slugged_lora_name: str):
    """Toggle Hugging Face Hub upload settings on an ai-toolkit job config.

    Mutates ``config`` in place: always records ``push_to_hub`` under the
    first process's ``save`` section; when enabled, also sets the target
    repo id (``<username>/<slugged_lora_name>``) and marks the repo private.

    Args:
        config: ai-toolkit job config (nested dict with config/process[0]/save).
        push_to_hub: whether trained weights should be uploaded to the Hub.
        slugged_lora_name: already-slugified LoRA name used in the repo id.

    Raises:
        RuntimeError: if the Hub username cannot be resolved (not logged in).
    """
    save_cfg = config["config"]["process"][0]["save"]
    save_cfg["push_to_hub"] = push_to_hub
    if push_to_hub:
        try:
            username = whoami()["name"]
        except Exception as err:
            # Chain the original error so the underlying login/network
            # failure stays visible in the traceback.
            raise RuntimeError(
                "Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?"
            ) from err
        save_cfg["hf_repo_id"] = f"{username}/{slugged_lora_name}"
        save_cfg["hf_private"] = True
def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Assemble an ai-toolkit 'extension' job config for FLUX.1-dev LoRA training.

    Each section of the job schema is built as its own OrderedDict and then
    composed into the final job. Hub-upload keys are filled in afterwards by
    update_config_push_to_hub() when ``push_to_hub`` is requested.

    Args:
        concept: sample prompt used for periodic preview generation.
        training_path: folder holding the training images + .txt captions.
        lora_name: human-readable run name; also slugified for the Hub repo.
        push_to_hub: whether to add Hugging Face upload settings.

    Returns:
        OrderedDict: complete job config ready for ``run_job``.
    """
    slugged_lora_name = lora_name.lower().replace(" ", "_")

    network_cfg = OrderedDict([
        ('type', 'lora'),
        ('linear', 16),
        ('linear_alpha', 16),
    ])

    # push_to_hub keys are appended later by update_config_push_to_hub()
    save_cfg = OrderedDict([
        ('dtype', 'float16'),
        ('save_every', 250),
        ('max_step_saves_to_keep', 4),
    ])

    dataset_cfg = OrderedDict([
        ('folder_path', training_path),
        ('caption_ext', 'txt'),
        ('caption_dropout_rate', 0.05),
        ('shuffle_tokens', False),
        ('cache_latents_to_disk', True),
        ('resolution', [512, 768, 1024]),
    ])

    train_cfg = OrderedDict([
        ('batch_size', 1),
        ('steps', 2000),
        ('gradient_accumulation_steps', 1),
        ('train_unet', True),
        ('train_text_encoder', False),
        ('content_or_style', 'balanced'),
        ('gradient_checkpointing', True),
        ('noise_scheduler', 'flowmatch'),
        ('optimizer', 'adamw8bit'),
        ('lr', 1e-4),
        ('ema_config', OrderedDict([
            ('use_ema', True),
            ('ema_decay', 0.99),
        ])),
        ('dtype', 'bf16'),
    ])

    model_cfg = OrderedDict([
        ('name_or_path', 'black-forest-labs/FLUX.1-dev'),
        ('is_flux', True),
        ('quantize', True),
    ])

    sample_cfg = OrderedDict([
        ('sampler', 'flowmatch'),
        ('sample_every', 250),
        ('width', 1024),
        ('height', 1024),
        ('prompts', [concept]),
        ('neg', ''),
        ('seed', 42),
        ('walk_seed', True),
        ('guidance_scale', 4),
        ('sample_steps', 20),
    ])

    process_cfg = OrderedDict([
        ('type', 'sd_trainer'),
        ('training_folder', '/tmp/output'),
        ('device', 'cuda:0'),
        ('network', network_cfg),
        ('save', save_cfg),
        ('datasets', [dataset_cfg]),
        ('train', train_cfg),
        ('model', model_cfg),
        ('sample', sample_cfg),
    ])

    job = OrderedDict([
        ('job', 'extension'),
        ('config', OrderedDict([
            ('name', lora_name),
            ('process', [process_cfg]),
        ])),
        ('meta', OrderedDict([
            ('name', lora_name),
            ('version', '1.0'),
        ])),
    ])

    # Add push to Hub config if requested
    update_config_push_to_hub(job, push_to_hub, slugged_lora_name)
    return job
def run_training(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Build the FLUX LoRA training job config and hand it to ai-toolkit.

    Thin entry point: composes build_job() with run_job(); all arguments are
    forwarded unchanged to build_job().
    """
    run_job(build_job(concept, training_path, lora_name, push_to_hub))