Spaces:
Sleeping
Sleeping
# flux_train.py | |
import os | |
from collections import OrderedDict | |
from huggingface_hub import whoami | |
import sys | |
sys.path.append("/app/ai-toolkit") # Tell Python to look here | |
from toolkit.job import run_job | |
from toolkit.job import run_job | |
def update_config_push_to_hub(config, push_to_hub: bool, slugged_lora_name: str, *, username=None, private: bool = True):
    """Enable or disable Hugging Face Hub upload in a job config, in place.

    Writes into ``config["config"]["process"][0]["save"]``:

    * ``push_to_hub`` -- always set to *push_to_hub*.
    * ``hf_repo_id`` and ``hf_private`` -- set only when *push_to_hub* is true.

    Args:
        config: Job dict such as the one produced by ``build_job``; it must
            already contain the ``config -> process[0] -> save`` mapping.
        push_to_hub: Whether the trained LoRA should be uploaded to the Hub.
        slugged_lora_name: Repository name (the part after ``owner/``).
        username: Repository owner (user or organisation). When ``None`` the
            logged-in account name is looked up via ``whoami()``.
        private: Value written to ``hf_private`` (defaults to ``True``).

    Returns:
        None; *config* is mutated in place.

    Raises:
        RuntimeError: If *username* is not given and the logged-in account
            cannot be determined. The underlying error is chained as
            ``__cause__`` so the real reason is visible in the traceback.
    """
    save = config["config"]["process"][0]["save"]
    save["push_to_hub"] = push_to_hub
    if not push_to_hub:
        return

    if username is None:
        try:
            username = whoami()["name"]
        except Exception as err:
            # Boundary with the Hub client: any failure here means we cannot
            # build a repo id, so surface one clear error but keep the cause.
            raise RuntimeError(
                "Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?"
            ) from err

    save["hf_repo_id"] = f"{username}/{slugged_lora_name}"
    save["hf_private"] = private
def build_job(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Build the ai-toolkit job description for a FLUX.1-dev LoRA training run.

    Args:
        concept: Prompt used for the periodic sample images (the only entry in
            ``sample.prompts``).
        training_path: Folder used as the dataset ``folder_path``; captions are
            read with the ``txt`` extension.
        lora_name: Name written to ``config.name`` and ``meta.name``. A Hub-safe
            slug of it becomes the repository name when *push_to_hub* is true.
        push_to_hub: Whether to add Hugging Face Hub upload settings to the
            ``save`` section (via ``update_config_push_to_hub``).

    Returns:
        The nested ``OrderedDict`` job configuration, ready for ``run_job``.

    Raises:
        ValueError: If *push_to_hub* is true but *lora_name* yields an empty
            repository name.
    """
    import re  # only needed for the slug below

    # Replacing spaces alone is not enough: Hub repo names may only contain
    # letters, digits, '-', '_' and '.', and must not start or end with '-' or
    # '.'. An invalid name would otherwise fail only at upload time, after
    # training has already finished.
    slugged_lora_name = re.sub(r"[^a-z0-9._-]", "_", lora_name.lower()).strip(".-")
    if push_to_hub and not slugged_lora_name:
        raise ValueError(f"lora_name {lora_name!r} does not yield a valid Hugging Face repo name")

    job = OrderedDict([
        ('job', 'extension'),
        ('config', OrderedDict([
            ('name', lora_name),
            ('process', [
                OrderedDict([
                    ('type', 'sd_trainer'),
                    ('training_folder', '/tmp/output'),
                    ('device', 'cuda:0'),
                    ('network', OrderedDict([
                        ('type', 'lora'),
                        ('linear', 16),
                        ('linear_alpha', 16)
                    ])),
                    ('save', OrderedDict([
                        ('dtype', 'float16'),
                        ('save_every', 250),
                        ('max_step_saves_to_keep', 4),
                        # push_to_hub keys added later by update_config_push_to_hub()
                    ])),
                    ('datasets', [
                        OrderedDict([
                            ('folder_path', training_path),
                            ('caption_ext', 'txt'),
                            ('caption_dropout_rate', 0.05),
                            ('shuffle_tokens', False),
                            ('cache_latents_to_disk', True),
                            ('resolution', [512, 768, 1024])
                        ])
                    ]),
                    ('train', OrderedDict([
                        ('batch_size', 1),
                        ('steps', 2000),
                        ('gradient_accumulation_steps', 1),
                        ('train_unet', True),
                        ('train_text_encoder', False),
                        ('content_or_style', 'balanced'),
                        ('gradient_checkpointing', True),
                        ('noise_scheduler', 'flowmatch'),
                        ('optimizer', 'adamw8bit'),
                        ('lr', 1e-4),
                        ('ema_config', OrderedDict([
                            ('use_ema', True),
                            ('ema_decay', 0.99)
                        ])),
                        ('dtype', 'bf16')
                    ])),
                    ('model', OrderedDict([
                        ('name_or_path', 'black-forest-labs/FLUX.1-dev'),
                        ('is_flux', True),
                        ('quantize', True),
                    ])),
                    ('sample', OrderedDict([
                        ('sampler', 'flowmatch'),
                        ('sample_every', 250),
                        ('width', 1024),
                        ('height', 1024),
                        ('prompts', [concept]),
                        ('neg', ''),
                        ('seed', 42),
                        ('walk_seed', True),
                        ('guidance_scale', 4),
                        ('sample_steps', 20)
                    ]))
                ])
            ])
        ])),
        ('meta', OrderedDict([
            ('name', lora_name),
            ('version', '1.0')
        ]))
    ])

    # Add push-to-Hub settings (always writes save.push_to_hub).
    update_config_push_to_hub(job, push_to_hub, slugged_lora_name)
    return job
def run_training(concept="ohamlab style", training_path="/tmp/data", lora_name="ohami_filter_autorun", push_to_hub=False):
    """Assemble a training job with ``build_job`` and hand it to ``run_job``.

    All arguments are forwarded unchanged to ``build_job``; see that function
    for their meaning. Nothing is returned.
    """
    run_job(
        build_job(
            concept=concept,
            training_path=training_path,
            lora_name=lora_name,
            push_to_hub=push_to_hub,
        )
    )