Commit 48d6121 (parent: 6fff6df): time to test image conditioning

Files changed:
- docs/gradio/external_plugin--gradio_modal.md +108 -0
- vms/config.py +183 -195
- vms/ui/app_ui.py +19 -20
- vms/ui/project/services/training.py +26 -16
- vms/ui/project/tabs/train_tab.py +151 -298
docs/gradio/external_plugin--gradio_modal.md
ADDED
@@ -0,0 +1,108 @@
Project description
-------------------

`gradio_modal`
==============

[](https://pypi.org/project/gradio_modal/)

A popup modal component

Installation
------------

    pip install gradio_modal

Usage
-----

    import gradio as gr
    from gradio_modal import Modal

    with gr.Blocks() as demo:
        with gr.Tab("Tab 1"):
            text_1 = gr.Textbox(label="Input 1")
            text_2 = gr.Textbox(label="Input 2")
            text_1.submit(lambda x: x, text_1, text_2)
            show_btn = gr.Button("Show Modal")
            show_btn2 = gr.Button("Show Modal 2")
            gr.Examples(
                [["Text 1", "Text 2"], ["Text 3", "Text 4"]],
                inputs=[text_1, text_2],
            )
        with gr.Tab("Tab 2"):
            gr.Markdown("This is tab 2")
        with Modal(visible=False) as modal:
            for i in range(5):
                gr.Markdown("Hello world!")
        with Modal(visible=False) as modal2:
            for i in range(100):
                gr.Markdown("Hello world!")
        show_btn.click(lambda: Modal(visible=True), None, modal)
        show_btn2.click(lambda: Modal(visible=True), None, modal2)

    if __name__ == "__main__":
        demo.launch()

`Modal`
-------

### Initialization

| name | type | default | description |
| --- | --- | --- | --- |
| `visible` | bool | `False` | If False, the modal will be hidden. |
| `elem_id` | str \| None | `None` | An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. |
| `elem_classes` | list[str] \| str \| None | `None` | An optional string or list of strings that are assigned as the class of this component in the HTML DOM. Can be used for targeting CSS styles. |
| `allow_user_close` | bool | `True` | If True, the user can close the modal (by clicking outside, clicking the X, or pressing the escape key). |
| `render` | bool | `True` | If False, the component will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. |

### Events

| name | description |
| --- | --- |
| `blur` | This listener is triggered when the Modal is unfocused/blurred. |
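For instance, a small sketch built only from the API documented above: a modal that the user cannot dismiss by clicking outside or pressing escape, and that is instead closed from a button (the component names here are illustrative):

    import gradio as gr
    from gradio_modal import Modal

    with gr.Blocks() as demo:
        open_btn = gr.Button("Open confirmation")
        with Modal(visible=False, allow_user_close=False) as confirm_modal:
            gr.Markdown("Please confirm before continuing.")
            confirm_btn = gr.Button("Confirm")
        # show the modal on demand, then hide it again once the user confirms
        open_btn.click(lambda: Modal(visible=True), None, confirm_modal)
        confirm_btn.click(lambda: Modal(visible=False), None, confirm_modal)

    if __name__ == "__main__":
        demo.launch()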
vms/config.py
CHANGED
Removed in this change: the old NB_FRAMES_1, NB_FRAMES_9, NB_FRAMES_17, NB_FRAMES_33, NB_FRAMES_49, NB_FRAMES_65 and NB_FRAMES_161 constants, the two old training-bucket tuple lists, and the hard-coded TRAINING_PRESETS dictionary with its per-preset entries ("LTX-Video (normal)", "Wan-2.1-T2V (normal)", "Wan-2.1-I2V (Control LoRA)", "LTX-Video (Control LoRA)", the HunyuanVideo presets including an image-conditioned control LoRA preset, and an LTX-Video full-finetune preset). Each preset carried model_type, training_type, lora_rank / lora_alpha, train_steps, batch_size, learning_rate, save_iterations, training_buckets, flow_weighting_scheme, num_gpus, precomputation_items, lr_warmup_steps and, for the control presets, the frame-conditioning options. The old defaults for video_resolution_buckets in TrainingConfig and in the per-model config constructors are replaced as shown in the hunks below.
@@ -304,39 +304,71 @@ DEFAULT_VALIDATION_WIDTH = 768
 DEFAULT_VALIDATION_NB_FRAMES = 49
 DEFAULT_VALIDATION_FRAMERATE = 8

+# you should use resolutions that are multiples of 8
+# using a 16:9 ratio is also super-recommended
+
+# SD
+SD_16_9_W = 1024  # 8*128
+SD_16_9_H = 576   # 8*72
+SD_9_16_W = 576   # 8*72
+SD_9_16_H = 1024  # 8*128
+
+# MD (720p)
+MD_16_9_W = 1280  # 8*160
+MD_16_9_H = 720   # 8*90
+MD_9_16_W = 720   # 8*90
+MD_9_16_H = 1280  # 8*160
+
+# HD (1080p)
+HD_16_9_W = 1920  # 8*240
+HD_16_9_H = 1080  # 8*135
+HD_9_16_W = 1080  # 8*135
+HD_9_16_H = 1920  # 8*240
+
+# QHD (2K)
+QHD_16_9_W = 2160  # 8*270
+QHD_16_9_H = 1440  # 8*180
+QHD_9_16_W = 1440  # 8*180
+QHD_9_16_H = 2160  # 8*270
+
+# UHD (4K)
+UHD_16_9_W = 3840  # 8*480
+UHD_16_9_H = 2160  # 8*270
+UHD_9_16_W = 2160  # 8*270
+UHD_9_16_H = 3840  # 8*480
+
 # it is important that the resolution buckets properly cover the training dataset,
 # or else that we exclude from the dataset videos that are out of this range
 # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here

+NB_FRAMES_1 = 1            # 1
+NB_FRAMES_9 = 8 + 1        # 8 + 1
+NB_FRAMES_17 = 8 * 2 + 1   # 16 + 1
+NB_FRAMES_33 = 8 * 4 + 1   # 32 + 1
+NB_FRAMES_49 = 8 * 6 + 1   # 48 + 1
+NB_FRAMES_65 = 8 * 8 + 1   # 64 + 1
+NB_FRAMES_73 = 8 * 9 + 1   # 72 + 1
+NB_FRAMES_81 = 8 * 10 + 1  # 80 + 1
+NB_FRAMES_89 = 8 * 11 + 1  # 88 + 1
+NB_FRAMES_97 = 8 * 12 + 1  # 96 + 1
+NB_FRAMES_105 = 8 * 13 + 1 # 104 + 1
 NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
+NB_FRAMES_121 = 8 * 15 + 1 # 120 + 1
 NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
+NB_FRAMES_137 = 8 * 17 + 1 # 136 + 1
 NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
+NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
 NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
 NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
+NB_FRAMES_201 = 8 * 25 + 1 # 200 + 1
+NB_FRAMES_209 = 8 * 26 + 1 # 208 + 1
+NB_FRAMES_217 = 8 * 27 + 1 # 216 + 1
 NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
+NB_FRAMES_233 = 8 * 29 + 1 # 232 + 1
+NB_FRAMES_241 = 8 * 30 + 1 # 240 + 1
+NB_FRAMES_249 = 8 * 31 + 1 # 248 + 1
 NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
+NB_FRAMES_265 = 8 * 33 + 1 # 264 + 1
 NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
 NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
 NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
@@ -347,199 +379,155 @@ NB_FRAMES_369 = 8 * 46 + 1 # 368 + 1
 NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
 NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1

+# ------ HOW BUCKETS WORK ----------
+# Basically, to train or fine-tune a video model with Finetrainers, we need to specify all the possible accepted video lengths AND size combinations (buckets), in the form: (BUCKET_CONFIGURATION_1, BUCKET_CONFIGURATION_2, ..., BUCKET_CONFIGURATION_N)
+# Where a bucket is: (NUMBER_OF_FRAMES_PLUS_ONE, HEIGHT_IN_PIXELS, WIDTH_IN_PIXELS)
+# For instance, for 2 seconds of a 1024x576 video at 24 frames per second, plus one frame (I think there is always an extra frame for the initial starting image), we would get:
+# NUMBER_OF_FRAMES_PLUS_ONE = (2*24) + 1 = 48 + 1 = 49
+# HEIGHT_IN_PIXELS = 576
+# WIDTH_IN_PIXELS = 1024
+# -> This would give a bucket like this: (49, 576, 1024)
+
+SD_TRAINING_BUCKETS = [
+    (NB_FRAMES_1, SD_16_9_H, SD_16_9_W),    # 1
+    (NB_FRAMES_9, SD_16_9_H, SD_16_9_W),    # 8 + 1
+    (NB_FRAMES_17, SD_16_9_H, SD_16_9_W),   # 16 + 1
+    (NB_FRAMES_33, SD_16_9_H, SD_16_9_W),   # 32 + 1
+    (NB_FRAMES_49, SD_16_9_H, SD_16_9_W),   # 48 + 1
+    (NB_FRAMES_65, SD_16_9_H, SD_16_9_W),   # 64 + 1
+    (NB_FRAMES_73, SD_16_9_H, SD_16_9_W),   # 72 + 1
+    (NB_FRAMES_81, SD_16_9_H, SD_16_9_W),   # 80 + 1
+    (NB_FRAMES_89, SD_16_9_H, SD_16_9_W),   # 88 + 1
+    (NB_FRAMES_97, SD_16_9_H, SD_16_9_W),   # 96 + 1
+    (NB_FRAMES_105, SD_16_9_H, SD_16_9_W),  # 104 + 1
+    (NB_FRAMES_113, SD_16_9_H, SD_16_9_W),  # 112 + 1
+    (NB_FRAMES_121, SD_16_9_H, SD_16_9_W),  # 120 + 1
+    (NB_FRAMES_129, SD_16_9_H, SD_16_9_W),  # 128 + 1
+    (NB_FRAMES_137, SD_16_9_H, SD_16_9_W),  # 136 + 1
+    (NB_FRAMES_145, SD_16_9_H, SD_16_9_W),  # 144 + 1
+    (NB_FRAMES_161, SD_16_9_H, SD_16_9_W),  # 160 + 1
+    (NB_FRAMES_177, SD_16_9_H, SD_16_9_W),  # 176 + 1
+    (NB_FRAMES_193, SD_16_9_H, SD_16_9_W),  # 192 + 1
+    (NB_FRAMES_201, SD_16_9_H, SD_16_9_W),  # 200 + 1
+    (NB_FRAMES_209, SD_16_9_H, SD_16_9_W),  # 208 + 1
+    (NB_FRAMES_217, SD_16_9_H, SD_16_9_W),  # 216 + 1
+    (NB_FRAMES_225, SD_16_9_H, SD_16_9_W),  # 224 + 1
+    (NB_FRAMES_233, SD_16_9_H, SD_16_9_W),  # 232 + 1
+    (NB_FRAMES_241, SD_16_9_H, SD_16_9_W),  # 240 + 1
+    (NB_FRAMES_249, SD_16_9_H, SD_16_9_W),  # 248 + 1
+    (NB_FRAMES_257, SD_16_9_H, SD_16_9_W),  # 256 + 1
+    (NB_FRAMES_265, SD_16_9_H, SD_16_9_W),  # 264 + 1
+    (NB_FRAMES_273, SD_16_9_H, SD_16_9_W),  # 272 + 1
 ]

+# For 1280x720 images and videos (from 1 frame up to 272)
+MD_TRAINING_BUCKETS = [
+    (NB_FRAMES_1, MD_16_9_H, MD_16_9_W),    # 1
+    (NB_FRAMES_9, MD_16_9_H, MD_16_9_W),    # 8 + 1
+    (NB_FRAMES_17, MD_16_9_H, MD_16_9_W),   # 16 + 1
+    (NB_FRAMES_33, MD_16_9_H, MD_16_9_W),   # 32 + 1
+    (NB_FRAMES_49, MD_16_9_H, MD_16_9_W),   # 48 + 1
+    (NB_FRAMES_65, MD_16_9_H, MD_16_9_W),   # 64 + 1
+    (NB_FRAMES_73, MD_16_9_H, MD_16_9_W),   # 72 + 1
+    (NB_FRAMES_81, MD_16_9_H, MD_16_9_W),   # 80 + 1
+    (NB_FRAMES_89, MD_16_9_H, MD_16_9_W),   # 88 + 1
+    (NB_FRAMES_97, MD_16_9_H, MD_16_9_W),   # 96 + 1
+    (NB_FRAMES_105, MD_16_9_H, MD_16_9_W),  # 104 + 1
+    (NB_FRAMES_113, MD_16_9_H, MD_16_9_W),  # 112 + 1
+    (NB_FRAMES_121, MD_16_9_H, MD_16_9_W),  # 120 + 1
+    (NB_FRAMES_129, MD_16_9_H, MD_16_9_W),  # 128 + 1
+    (NB_FRAMES_137, MD_16_9_H, MD_16_9_W),  # 136 + 1
+    (NB_FRAMES_145, MD_16_9_H, MD_16_9_W),  # 144 + 1
+    (NB_FRAMES_161, MD_16_9_H, MD_16_9_W),  # 160 + 1
+    (NB_FRAMES_177, MD_16_9_H, MD_16_9_W),  # 176 + 1
+    (NB_FRAMES_193, MD_16_9_H, MD_16_9_W),  # 192 + 1
+    (NB_FRAMES_201, MD_16_9_H, MD_16_9_W),  # 200 + 1
+    (NB_FRAMES_209, MD_16_9_H, MD_16_9_W),  # 208 + 1
+    (NB_FRAMES_217, MD_16_9_H, MD_16_9_W),  # 216 + 1
+    (NB_FRAMES_225, MD_16_9_H, MD_16_9_W),  # 224 + 1
+    (NB_FRAMES_233, MD_16_9_H, MD_16_9_W),  # 232 + 1
+    (NB_FRAMES_241, MD_16_9_H, MD_16_9_W),  # 240 + 1
+    (NB_FRAMES_249, MD_16_9_H, MD_16_9_W),  # 248 + 1
+    (NB_FRAMES_257, MD_16_9_H, MD_16_9_W),  # 256 + 1
+    (NB_FRAMES_265, MD_16_9_H, MD_16_9_W),  # 264 + 1
+    (NB_FRAMES_273, MD_16_9_H, MD_16_9_W),  # 272 + 1
 ]

+# Model specific default parameters
+# These are used instead of the previous TRAINING_PRESETS
+
+# Resolution buckets for different models
+RESOLUTION_OPTIONS = {
+    "SD (1024x576)": "SD_TRAINING_BUCKETS",
+    "HD (1280x720)": "MD_TRAINING_BUCKETS"
+}
+
+# Default parameters for Hunyuan Video
+HUNYUAN_VIDEO_DEFAULTS = {
+    "lora": {
         "learning_rate": 2e-5,
         "flow_weighting_scheme": "none",
         "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR
     },
+    "control-lora": {
+        "learning_rate": 2e-5,
+        "flow_weighting_scheme": "none",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "control_type": "custom",
+        "train_qk_norm": True,
+        "frame_conditioning_type": "index",
+        "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True
+    }
+}
+
+# Default parameters for LTX Video
+LTX_VIDEO_DEFAULTS = {
+    "lora": {
         "learning_rate": DEFAULT_LEARNING_RATE,
+        "flow_weighting_scheme": "none",
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR
     },
+    "full-finetune": {
         "learning_rate": DEFAULT_LEARNING_RATE,
+        "flow_weighting_scheme": "logit_normal"
     },
+    "control-lora": {
         "learning_rate": DEFAULT_LEARNING_RATE,
         "flow_weighting_scheme": "logit_normal",
         "lora_rank": "128",
         "lora_alpha": "128",
         "control_type": "custom",
         "train_qk_norm": True,
         "frame_conditioning_type": "index",
         "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True
+    }
+}
+
+# Default parameters for Wan
+WAN_DEFAULTS = {
+    "lora": {
+        "learning_rate": 5e-5,
+        "flow_weighting_scheme": "logit_normal",
+        "lora_rank": "32",
+        "lora_alpha": "32"
     },
+    "control-lora": {
+        "learning_rate": 5e-5,
+        "flow_weighting_scheme": "logit_normal",
+        "lora_rank": "32",
+        "lora_alpha": "32",
         "control_type": "custom",
         "train_qk_norm": True,
         "frame_conditioning_type": "index",
         "frame_conditioning_index": 0,
+        "frame_conditioning_concatenate_mask": True
     }
 }

@@ -567,7 +555,7 @@ class TrainingConfig:
     caption_column: str = "prompts.txt"

     id_token: Optional[str] = None
+    video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SD_TRAINING_BUCKETS)
     video_reshape_mode: str = "center"
     caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
     caption_dropout_technique: str = "empty"
@@ -632,7 +620,7 @@ class TrainingConfig:
         gradient_accumulation_steps=1,
         lora_rank=DEFAULT_LORA_RANK,
         lora_alpha=DEFAULT_LORA_ALPHA,
+        video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
         caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
         flow_weighting_scheme="none",  # Hunyuan specific
         training_type="lora"
@@ -654,7 +642,7 @@ class TrainingConfig:
         gradient_accumulation_steps=4,
         lora_rank=DEFAULT_LORA_RANK,
         lora_alpha=DEFAULT_LORA_ALPHA,
+        video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
         caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
         flow_weighting_scheme="logit_normal",  # LTX specific
         training_type="lora"
@@ -674,7 +662,7 @@ class TrainingConfig:
         gradient_checkpointing=True,
         id_token=None,
         gradient_accumulation_steps=1,
+        video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
         caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
         flow_weighting_scheme="logit_normal",  # LTX specific
         training_type="full-finetune"
@@ -697,7 +685,7 @@ class TrainingConfig:
         lora_rank=32,
         lora_alpha=32,
         target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"],  # Wan-specific target modules
+        video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
         caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
         flow_weighting_scheme="logit_normal",  # Wan specific
         training_type="lora"
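The bucket arithmetic described in the "HOW BUCKETS WORK" comments above can be written as a tiny helper. This is only an illustration; make_bucket is hypothetical and not part of vms/config.py:

    from typing import Tuple

    def make_bucket(seconds: float, fps: int, height: int, width: int) -> Tuple[int, int, int]:
        """Build a (frames + 1, height, width) bucket from a clip duration, frame rate and resolution."""
        # the bucket resolution should be a multiple of 8, like the SD_/MD_ constants above
        assert height % 8 == 0 and width % 8 == 0
        return (int(seconds * fps) + 1, height, width)

    # 2 seconds at 24 fps in 1024x576 -> (49, 576, 1024), matching the worked example above
    print(make_bucket(2, 24, 576, 1024))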
vms/ui/app_ui.py
CHANGED
Removed in this change: two config import lines (replaced by the new bucket and resolution imports below), the components["training_preset"] entry from the app.load outputs list, the training-preset validation block (which fell back to the first TRAINING_PRESETS key), the training_preset value in the tuple of values pushed back to the UI, and the ui_state.get("training_preset", ...) entry in initialize_ui_from_state(). The additions are shown in the hunks below.
@@ -9,8 +9,8 @@ from typing import Any, Optional, Dict, List, Union, Tuple

 from vms.config import (
     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+    MODEL_TYPES, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS, TRAINING_TYPES, MODEL_VERSIONS,
+    RESOLUTION_OPTIONS,
     DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
     DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
     DEFAULT_LEARNING_RATE,
@@ -23,6 +23,7 @@ from vms.config import (
     DEFAULT_NB_TRAINING_STEPS,
     DEFAULT_NB_LR_WARMUP_STEPS,
     DEFAULT_AUTO_RESUME,
+    HUNYUAN_VIDEO_DEFAULTS, LTX_VIDEO_DEFAULTS, WAN_DEFAULTS,

     get_project_paths,
     generate_model_project_id,
@@ -363,7 +364,6 @@ class AppUI:
             self.project_tabs["train_tab"].components["resume_btn"],
             self.project_tabs["train_tab"].components["stop_btn"],
             self.project_tabs["train_tab"].components["delete_checkpoints_btn"],
             self.project_tabs["train_tab"].components["model_type"],
             self.project_tabs["train_tab"].components["model_version"],
             self.project_tabs["train_tab"].components["training_type"],
@@ -377,7 +377,8 @@ class AppUI:
             self.project_tabs["train_tab"].components["num_gpus"],
             self.project_tabs["train_tab"].components["precomputation_items"],
             self.project_tabs["train_tab"].components["lr_warmup_steps"],
+            self.project_tabs["train_tab"].components["auto_resume"],
+            self.project_tabs["train_tab"].components["resolution"]
         ]
     )

@@ -485,7 +486,7 @@ class AppUI:

         # Copy other parameters
         for param in ["lora_rank", "lora_alpha", "train_steps",
+                      "batch_size", "learning_rate", "save_iterations"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]

@@ -544,21 +545,22 @@ class AppUI:
             model_version_val = available_model_versions[0]
             logger.info(f"Using first available model version: {model_version_val}")

+            # IMPORTANT: Create a new list of tuples (label, value) for the dropdown choices
+            # This ensures compatibility with Gradio Dropdown component expectations
+            choices_tuples = [(str(version), str(version)) for version in available_model_versions]

             # Update the dropdown choices directly in the UI component
             try:
+                self.project_tabs["train_tab"].components["model_version"].choices = choices_tuples
+                logger.info(f"Updated model_version dropdown choices: {len(choices_tuples)} options")
             except Exception as e:
                 logger.error(f"Error updating model_version dropdown: {str(e)}")
         else:
             logger.warning(f"No versions available for model type: {model_type_val}")
+            # Set empty choices as an empty list of tuples to avoid errors
             try:
                 self.project_tabs["train_tab"].components["model_version"].choices = []
+                logger.info("Set empty model_version dropdown choices")
             except Exception as e:
                 logger.error(f"Error setting empty model_version choices: {str(e)}")

@@ -577,11 +579,8 @@ class AppUI:
             training_type_val = list(TRAINING_TYPES.keys())[0]
             logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")

+        # Get resolution value
+        resolution_val = ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])

         lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
         lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
@@ -616,7 +615,6 @@ class AppUI:
             resume_btn,
             stop_btn,
             delete_checkpoints_btn,
             model_type_val,
             model_version_val,
             training_type_val,
@@ -630,7 +628,8 @@ class AppUI:
             num_gpus_val,
             precomputation_items_val,
             lr_warmup_steps_val,
+            auto_resume_val,
+            resolution_val
         )

     def initialize_ui_from_state(self):
@@ -650,7 +649,6 @@ class AppUI:

         # Return values in order matching the outputs in app.load
         return (
             model_type,
             model_version,
             ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
@@ -659,7 +657,8 @@ class AppUI:
             ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
             ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
             ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+            ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+            ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
         )

     def update_ui_state(self, **kwargs):
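The "(label, value) tuples" comment above refers to the fact that Gradio's Dropdown accepts choices either as plain strings or as (label, value) pairs. A tiny standalone illustration (the component names here are made up):

    import gradio as gr

    # Using (label, value) tuples keeps the displayed label and the underlying value explicit.
    versions = ["1.0", "1.5"]
    choices_tuples = [(str(v), str(v)) for v in versions]
    model_version = gr.Dropdown(
        choices=choices_tuples,
        value=choices_tuples[0][1],
        label="Model version",
    )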
vms/ui/project/services/training.py
CHANGED
Removed in this change: the bare TrainingConfig import line (extended below), the preset-related entries in the default UI state and in the recovered/prepared parameter dictionaries (replaced by the new "resolution" keys below), the training-preset validation in the saved-state merge (replaced by a resolution check), the preset_name: str parameter of start_training() together with the preset_name=params.get('preset_name', ...) argument at its call site, a stray training_path statement at the top of start_training(), and the old bucket-selection code inside start_training(). The additions are shown in the hunks below.
@@ -22,7 +22,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
 from huggingface_hub import upload_folder, create_repo

 from vms.config import (
+    TrainingConfig, RESOLUTION_OPTIONS, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS,
     STORAGE_PATH, HF_API_TOKEN,
     MODEL_TYPES, TRAINING_TYPES, MODEL_VERSIONS,
     DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
@@ -228,7 +228,7 @@ class TrainingService:
             "batch_size": DEFAULT_BATCH_SIZE,
             "learning_rate": DEFAULT_LEARNING_RATE,
             "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+            "resolution": list(RESOLUTION_OPTIONS.keys())[0],
             "num_gpus": DEFAULT_NUM_GPUS,
             "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
             "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
@@ -354,10 +354,10 @@ class TrainingService:
             merged_state["training_type"] = default_state["training_type"]
             logger.warning(f"Invalid training type in saved state, using default")

+        # Validate resolution is in available choices
+        if "resolution" in merged_state and merged_state["resolution"] not in RESOLUTION_OPTIONS:
+            merged_state["resolution"] = default_state["resolution"]
+            logger.warning(f"Invalid resolution in saved state, using default")

         # Validate lora_rank is in allowed values
         if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
@@ -566,7 +566,6 @@ class TrainingService:
         learning_rate: float,
         save_iterations: int,
         repo_id: str,
         training_type: str = DEFAULT_TRAINING_TYPE,
         model_version: str = "",
         resume_from_checkpoint: Optional[str] = None,
@@ -577,7 +576,6 @@ class TrainingService:
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""

         self.clear_logs()

         if not model_type:
@@ -646,11 +644,24 @@ class TrainingService:
         #if progress:
         #    progress(0.25, desc="Creating dataset configuration")

+        # Get resolution configuration from UI state
+        ui_state = self.load_ui_state()
+        resolution_option = ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
+        training_buckets_name = RESOLUTION_OPTIONS.get(resolution_option, "SD_TRAINING_BUCKETS")
+
+        # Determine which buckets to use based on the selected resolution
+        if training_buckets_name == "SD_TRAINING_BUCKETS":
+            training_buckets = SD_TRAINING_BUCKETS
+        elif training_buckets_name == "MD_TRAINING_BUCKETS":
+            training_buckets = MD_TRAINING_BUCKETS
+        else:
+            training_buckets = SD_TRAINING_BUCKETS  # Default fallback
+
+        # Determine flow weighting scheme based on model type
+        if model_type == "hunyuan_video":
+            flow_weighting_scheme = "none"
+        else:
+            flow_weighting_scheme = "logit_normal"

         # Get the custom prompt prefix from the tabs
         custom_prompt_prefix = None
@@ -1117,7 +1128,7 @@ class TrainingService:
             "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
             "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
             "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+            "resolution": ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0]),
             "repo_id": "",  # Default empty repo ID
             "auto_resume": ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
         }
@@ -1190,7 +1201,7 @@ class TrainingService:
             "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
             "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
             "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+            "resolution": params.get('resolution', list(RESOLUTION_OPTIONS.keys())[0]),
             "auto_resume": params.get("auto_resume", DEFAULT_AUTO_RESUME)
         })

@@ -1211,7 +1222,6 @@ class TrainingService:
             save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
             model_version=params.get('model_version', ''),
             repo_id=params.get('repo_id', ''),
             training_type=training_type_internal,
             resume_from_checkpoint="latest"
         )
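The if/elif branch added above resolves the selected resolution label to a bucket list by name. As a side note on the design, the same mapping can be expressed as a single dictionary lookup; this is only a sketch of an alternative, assuming the config names introduced in this commit, not the code actually used in the service:

    from vms.config import RESOLUTION_OPTIONS, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS

    def resolve_training_buckets(resolution_option: str):
        """Map a UI resolution label (e.g. "SD (1024x576)") to its bucket list, falling back to SD."""
        bucket_sets = {
            "SD_TRAINING_BUCKETS": SD_TRAINING_BUCKETS,
            "MD_TRAINING_BUCKETS": MD_TRAINING_BUCKETS,
        }
        name = RESOLUTION_OPTIONS.get(resolution_option, "SD_TRAINING_BUCKETS")
        return bucket_sets.get(name, SD_TRAINING_BUCKETS)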
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -13,8 +13,9 @@ from pathlib import Path
|
|
13 |
from vms.utils import BaseTab
|
14 |
from vms.config import (
|
15 |
ASK_USER_TO_DUPLICATE_SPACE,
|
16 |
-
|
17 |
-
|
|
|
18 |
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
19 |
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
|
20 |
DEFAULT_LEARNING_RATE,
|
@@ -29,7 +30,8 @@ from vms.config import (
|
|
29 |
DEFAULT_AUTO_RESUME,
|
30 |
DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
|
31 |
DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
|
32 |
-
DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
|
|
|
33 |
)
|
34 |
|
35 |
logger = logging.getLogger(__name__)
|
@@ -50,15 +52,6 @@ class TrainTab(BaseTab):
|
|
50 |
with gr.Row():
|
51 |
self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
|
52 |
|
53 |
-
with gr.Row():
|
54 |
-
with gr.Column():
|
55 |
-
self.components["training_preset"] = gr.Dropdown(
|
56 |
-
choices=list(TRAINING_PRESETS.keys()),
|
57 |
-
label="Training Preset",
|
58 |
-
value=list(TRAINING_PRESETS.keys())[0]
|
59 |
-
)
|
60 |
-
self.components["preset_info"] = gr.Markdown()
|
61 |
-
|
62 |
with gr.Row():
|
63 |
with gr.Column():
|
64 |
# Get the default model type from the first preset
|
@@ -115,6 +108,15 @@ class TrainTab(BaseTab):
|
|
115 |
self.components["model_info"] = gr.Markdown(
|
116 |
value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
|
117 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# LoRA specific parameters (will show/hide based on training type)
|
120 |
with gr.Row(visible=True) as lora_params_row:
|
@@ -140,18 +142,18 @@ class TrainTab(BaseTab):
|
|
140 |
|
141 |
with gr.Accordion("What is LoRA Rank?", open=False):
|
142 |
gr.Markdown("""
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
""")
|
156 |
|
157 |
with gr.Column():
|
@@ -162,32 +164,31 @@ class TrainTab(BaseTab):
|
|
162 |
type="value",
|
163 |
info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
|
164 |
)
|
165 |
-
|
166 |
with gr.Accordion("What is LoRA Alpha?", open=False):
|
167 |
gr.Markdown("""
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
""")
|
180 |
-
|
181 |
|
182 |
# Control specific parameters (will show/hide based on training type)
|
183 |
with gr.Row(visible=False) as control_params_row:
|
184 |
self.components["control_params_row"] = control_params_row
|
185 |
with gr.Column():
|
186 |
gr.Markdown("""
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
""")
|
192 |
|
193 |
# Second row for control parameters
|
@@ -203,10 +204,10 @@ class TrainTab(BaseTab):
|
|
203 |
|
204 |
with gr.Accordion("What is Control Conditioning?", open=False):
|
205 |
gr.Markdown("""
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
""")
|
211 |
|
212 |
with gr.Column():
|
@@ -218,11 +219,11 @@ class TrainTab(BaseTab):
|
|
218 |
|
219 |
with gr.Accordion("What is QK Normalization?", open=False):
|
220 |
gr.Markdown("""
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
""")
|
227 |
|
228 |
with gr.Row(visible=False) as frame_conditioning_row:
|
@@ -237,15 +238,15 @@ class TrainTab(BaseTab):
|
|
237 |
|
238 |
with gr.Accordion("Frame Conditioning Type Explanation", open=False):
|
239 |
gr.Markdown("""
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
""")
|
250 |
|
251 |
with gr.Column():
|
@@ -267,12 +268,12 @@ class TrainTab(BaseTab):
|
|
267 |
|
268 |
with gr.Accordion("What is Frame Mask Concatenation?", open=False):
|
269 |
gr.Markdown("""
|
270 |
-
|
271 |
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
""")
|
277 |
|
278 |
with gr.Column():
|
@@ -448,7 +449,7 @@ class TrainTab(BaseTab):
|
|
448 |
return None
|
449 |
|
450 |
def handle_new_training_start(
|
451 |
-
self,
|
452 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
453 |
save_iterations, repo_id, progress=gr.Progress()
|
454 |
):
|
@@ -469,13 +470,13 @@ class TrainTab(BaseTab):
|
|
469 |
|
470 |
# Start training normally
|
471 |
return self.handle_training_start(
|
472 |
-
|
473 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
474 |
save_iterations, repo_id, progress
|
475 |
)
|
476 |
|
477 |
def handle_resume_training(
|
478 |
-
self,
|
479 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
480 |
save_iterations, repo_id, progress=gr.Progress()
|
481 |
):
|
@@ -490,7 +491,7 @@ class TrainTab(BaseTab):
|
|
490 |
|
491 |
# Start training with the checkpoint
|
492 |
return self.handle_training_start(
|
493 |
-
|
494 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
495 |
save_iterations, repo_id, progress,
|
496 |
resume_from_checkpoint="latest"
|
@@ -498,20 +499,34 @@ class TrainTab(BaseTab):
|
|
498 |
|
499 |
def connect_events(self) -> None:
|
500 |
"""Connect event handlers to UI components"""
|
501 |
-
# Model type change event - Update model version dropdown choices
|
502 |
self.components["model_type"].change(
|
503 |
fn=self.update_model_versions,
|
504 |
inputs=[self.components["model_type"]],
|
505 |
outputs=[self.components["model_version"]]
|
506 |
).then(
|
507 |
-
fn=self.update_model_type_and_version,
|
508 |
inputs=[self.components["model_type"], self.components["model_version"]],
|
509 |
outputs=[]
|
510 |
).then(
|
511 |
-
#
|
512 |
-
fn=self.
|
513 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
514 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
)
|
516 |
|
517 |
# Model version change event
|
@@ -535,7 +550,14 @@ class TrainTab(BaseTab):
|
|
535 |
self.components["batch_size"],
|
536 |
self.components["learning_rate"],
|
537 |
self.components["save_iterations"],
|
538 |
-
self.components["lora_params_row"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
]
|
540 |
)
|
541 |
|
@@ -632,50 +654,17 @@ class TrainTab(BaseTab):
|
|
632 |
outputs=[]
|
633 |
)
|
634 |
|
635 |
-
#
|
636 |
-
self.components["
|
637 |
-
fn=lambda v: self.app.update_ui_state(
|
638 |
-
inputs=[self.components["
|
639 |
outputs=[]
|
640 |
-
).then(
|
641 |
-
fn=self.update_training_params,
|
642 |
-
inputs=[self.components["training_preset"]],
|
643 |
-
outputs=[
|
644 |
-
self.components["model_type"],
|
645 |
-
self.components["training_type"],
|
646 |
-
self.components["lora_rank"],
|
647 |
-
self.components["lora_alpha"],
|
648 |
-
self.components["train_steps"],
|
649 |
-
self.components["batch_size"],
|
650 |
-
self.components["learning_rate"],
|
651 |
-
self.components["save_iterations"],
|
652 |
-
self.components["preset_info"],
|
653 |
-
self.components["lora_params_row"],
|
654 |
-
self.components["lora_settings_row"],
|
655 |
-
self.components["num_gpus"],
|
656 |
-
self.components["precomputation_items"],
|
657 |
-
self.components["lr_warmup_steps"],
|
658 |
-
# Add model_version to the outputs
|
659 |
-
self.components["model_version"],
|
660 |
-
# Control parameters rows visibility
|
661 |
-
self.components["control_params_row"],
|
662 |
-
self.components["control_settings_row"],
|
663 |
-
self.components["frame_conditioning_row"],
|
664 |
-
self.components["control_options_row"],
|
665 |
-
# Control parameter values
|
666 |
-
self.components["control_type"],
|
667 |
-
self.components["train_qk_norm"],
|
668 |
-
self.components["frame_conditioning_type"],
|
669 |
-
self.components["frame_conditioning_index"],
|
670 |
-
self.components["frame_conditioning_concatenate_mask"],
|
671 |
-
]
|
672 |
)
|
673 |
|
674 |
# Training control events
|
675 |
self.components["start_btn"].click(
|
676 |
fn=self.handle_new_training_start,
|
677 |
inputs=[
|
678 |
-
self.components["training_preset"],
|
679 |
self.components["model_type"],
|
680 |
self.components["model_version"],
|
681 |
self.components["training_type"],
|
@@ -696,7 +685,6 @@ class TrainTab(BaseTab):
|
|
696 |
self.components["resume_btn"].click(
|
697 |
fn=self.handle_resume_training,
|
698 |
inputs=[
|
699 |
-
self.components["training_preset"],
|
700 |
self.components["model_type"],
|
701 |
self.components["model_version"],
|
702 |
self.components["training_type"],
|
@@ -761,23 +749,25 @@ class TrainTab(BaseTab):
|
|
761 |
# Update UI state with proper model_type first
|
762 |
self.app.update_ui_state(model_type=model_type)
|
763 |
|
764 |
-
#
|
765 |
-
|
|
|
766 |
|
767 |
# Create a new dropdown with the updated choices
|
768 |
-
if not
|
769 |
logger.warning(f"No model versions available for {model_type}, using empty list")
|
770 |
# Return empty dropdown to avoid errors
|
771 |
return gr.Dropdown(choices=[], value=None)
|
772 |
|
773 |
# Ensure default_version is in model_versions
|
774 |
-
|
775 |
-
|
|
|
776 |
logger.info(f"Default version not in choices, using first available: {default_version}")
|
777 |
|
778 |
# Return the updated dropdown
|
779 |
-
logger.info(f"Returning dropdown with {len(
|
780 |
-
return gr.Dropdown(choices=
|
781 |
except Exception as e:
|
782 |
# Log any exceptions for debugging
|
783 |
logger.error(f"Error in update_model_versions: {str(e)}")
|
@@ -785,7 +775,7 @@ class TrainTab(BaseTab):
|
|
785 |
return gr.Dropdown(choices=[], value=None)
|
786 |
|
787 |
def handle_training_start(
|
788 |
-
self,
|
789 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
790 |
save_iterations, repo_id,
|
791 |
progress=gr.Progress(),
|
@@ -844,7 +834,6 @@ class TrainTab(BaseTab):
|
|
844 |
learning_rate,
|
845 |
save_iterations,
|
846 |
repo_id,
|
847 |
-
preset_name=preset,
|
848 |
training_type=training_internal_type,
|
849 |
model_version=model_version,
|
850 |
resume_from_checkpoint=resume_from,
|
@@ -898,14 +887,14 @@ class TrainTab(BaseTab):
|
|
898 |
# Add general information about the selected training type
|
899 |
if training_type == "Full Finetune":
|
900 |
finetune_info = """
|
901 |
-
|
902 |
|
903 |
-
|
904 |
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
"""
|
910 |
model_info = finetune_info + "\n\n" + model_info
|
911 |
|
@@ -925,6 +914,8 @@ class TrainTab(BaseTab):
|
|
925 |
self.components["batch_size"]: params["batch_size"],
|
926 |
self.components["learning_rate"]: params["learning_rate"],
|
927 |
self.components["save_iterations"]: params["save_iterations"],
|
|
|
|
|
928 |
self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
|
929 |
self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
|
930 |
self.components["control_params_row"]: gr.Row(visible=show_control_params),
|
@@ -936,11 +927,11 @@ class TrainTab(BaseTab):
|
|
936 |
def get_model_info(self, model_type: str, training_type: str) -> str:
|
937 |
"""Get information about the selected model type and training method"""
|
938 |
if model_type == "HunyuanVideo":
|
939 |
-
base_info = """
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
|
945 |
if training_type == "LoRA Finetune":
|
946 |
return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
@@ -952,10 +943,10 @@ class TrainTab(BaseTab):
|
|
952 |
return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
|
953 |
|
954 |
elif model_type == "LTX-Video":
|
955 |
-
base_info = """
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
|
960 |
if training_type == "LoRA Finetune":
|
961 |
return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
@@ -967,10 +958,10 @@ class TrainTab(BaseTab):
|
|
967 |
return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
|
968 |
|
969 |
elif model_type == "Wan":
|
970 |
-
base_info = """
|
971 |
-
|
972 |
-
|
973 |
-
|
974 |
|
975 |
if training_type == "LoRA Finetune":
|
976 |
return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
|
@@ -986,168 +977,30 @@ class TrainTab(BaseTab):
 986
 987      def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
 988          """Get default training parameters for model type"""
 989  -       #
 990  -
 991  -           preset for preset_name, preset in TRAINING_PRESETS.items()
 992  -           if preset["model_type"] == model_type and preset["training_type"] == training_type
 993  -       ]
 994  -
 995  -       if matching_presets:
 996  -           # Use the first matching preset
 997  -           preset = matching_presets[0]
 998  -           return {
 999  -               "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
1000  -               "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
1001  -               "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
1002  -               "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1003  -               "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
1004  -               "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
1005  -           }
1006  -
1007  -       # Default fallbacks
1008  -       if model_type == "hunyuan_video":
1009  -           return {
1010  -               "train_steps": DEFAULT_NB_TRAINING_STEPS,
1011  -               "batch_size": DEFAULT_BATCH_SIZE,
1012  -               "learning_rate": 2e-5,
1013  -               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1014  -               "lora_rank": DEFAULT_LORA_RANK_STR,
1015  -               "lora_alpha": DEFAULT_LORA_ALPHA_STR
1016  -           }
1017  -       elif model_type == "ltx_video":
1018  -           return {
1019  -               "train_steps": DEFAULT_NB_TRAINING_STEPS,
1020  -               "batch_size": DEFAULT_BATCH_SIZE,
1021  -               "learning_rate": DEFAULT_LEARNING_RATE,
1022  -               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1023  -               "lora_rank": DEFAULT_LORA_RANK_STR,
1024  -               "lora_alpha": DEFAULT_LORA_ALPHA_STR
1025  -           }
1026  -       elif model_type == "wan":
1027  -           return {
1028  -               "train_steps": DEFAULT_NB_TRAINING_STEPS,
1029  -               "batch_size": DEFAULT_BATCH_SIZE,
1030  -               "learning_rate": 5e-5,
1031  -               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1032  -               "lora_rank": "32",
1033  -               "lora_alpha": "32"
1034  -           }
1035  -       else:
1036  -           # Generic defaults
1037  -           return {
1038  -               "train_steps": DEFAULT_NB_TRAINING_STEPS,
1039  -               "batch_size": DEFAULT_BATCH_SIZE,
1040  -               "learning_rate": DEFAULT_LEARNING_RATE,
1041  -               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1042  -               "lora_rank": DEFAULT_LORA_RANK_STR,
1043  -               "lora_alpha": DEFAULT_LORA_ALPHA_STR
1044  -           }
1045  -
1046  -   def update_training_params(self, preset_name: str) -> Tuple:
1047  -       """Update UI components based on selected preset while preserving custom settings"""
1048  -       preset = TRAINING_PRESETS[preset_name]
1049  -
1050  -       # Load current UI state to check if user has customized values
1051  -       current_state = self.app.load_ui_values()
1052  -
1053  -       # Find the display name that maps to our model type
1054  -       model_display_name = next(
1055  -           key for key, value in MODEL_TYPES.items()
1056  -           if value == preset["model_type"]
1057  -       )
1058  -
1059  -       # Find the display name that maps to our training type
1060  -       training_display_name = next(
1061  -           key for key, value in TRAINING_TYPES.items()
1062  -           if value == preset["training_type"]
1063  -       )
1064  -
1065  -       # Get preset description for display
1066  -       description = preset.get("description", "")
1067  -
1068  -       # Get max values from buckets
1069  -       buckets = preset["training_buckets"]
1070  -       max_frames = max(frames for frames, _, _ in buckets)
1071  -       max_height = max(height for _, height, _ in buckets)
1072  -       max_width = max(width for _, _, width in buckets)
1073  -       bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
1074  -
1075  -       info_text = f"{description}{bucket_info}"
1076
1077  -
1078  -
1079  -
1080  -
1081  -
1082  -
1083
1084  -       #
1085  -
1086  -
1087  -
1088  -
1089  -
1090  -
1091  -
1092  -
1093  -       lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
1094
1095  -       #
1096  -
1097  -       train_qk_norm_val = current_state.get("train_qk_norm") if current_state.get("train_qk_norm") != preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM) else preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
1098  -       frame_conditioning_type_val = current_state.get("frame_conditioning_type") if current_state.get("frame_conditioning_type") != preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE) else preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
1099  -       frame_conditioning_index_val = current_state.get("frame_conditioning_index") if current_state.get("frame_conditioning_index") != preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX) else preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
1100  -       frame_conditioning_concatenate_mask_val = current_state.get("frame_conditioning_concatenate_mask") if current_state.get("frame_conditioning_concatenate_mask") != preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK) else preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
1101
1102  -
1103  -       model_versions = self.get_model_version_choices(model_display_name)
1104  -       default_model_version = self.get_default_model_version(model_display_name)
1105  -
1106  -       # Ensure we have valid choices and values
1107  -       if not model_versions:
1108  -           logger.warning(f"No versions found for {model_display_name}, using empty list")
1109  -           model_versions = []
1110  -           default_model_version = None
1111  -       elif default_model_version not in model_versions and model_versions:
1112  -           default_model_version = model_versions[0]
1113  -           logger.info(f"Reset default version to first available: {default_model_version}")
1114  -
1115  -       # Ensure model_versions is a simple list of strings
1116  -       model_versions = [str(version) for version in model_versions]
1117  -
1118  -       # Create the model version dropdown update
1119  -       model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
1120  -
1121  -       # Return values in the same order as the output components listed in line 644
1122  -       # Make sure we return exactly 24 values to match what's expected
1123  -       return (
1124  -           model_display_name,  # model_type
1125  -           training_display_name,  # training_type
1126  -           lora_rank_val,  # lora_rank
1127  -           lora_alpha_val,  # lora_alpha
1128  -           train_steps_val,  # train_steps
1129  -           batch_size_val,  # batch_size
1130  -           learning_rate_val,  # learning_rate
1131  -           save_iterations_val,  # save_iterations
1132  -           info_text,  # preset_info
1133  -           gr.Row(visible=show_lora_params),  # lora_params_row
1134  -           gr.Row(visible=show_lora_params),  # lora_settings_row (added missing row)
1135  -           num_gpus_val,  # num_gpus
1136  -           precomputation_items_val,  # precomputation_items
1137  -           lr_warmup_steps_val,  # lr_warmup_steps
1138  -           model_version_update,  # model_version
1139  -           # Control parameters rows visibility
1140  -           gr.Row(visible=show_control_params),  # control_params_row
1141  -           gr.Row(visible=show_control_params),  # control_settings_row
1142  -           gr.Row(visible=show_control_params),  # frame_conditioning_row
1143  -           gr.Row(visible=show_control_params),  # control_options_row
1144  -           # Control parameter values
1145  -           control_type_val,  # control_type
1146  -           train_qk_norm_val,  # train_qk_norm
1147  -           frame_conditioning_type_val,  # frame_conditioning_type
1148  -           frame_conditioning_index_val,  # frame_conditioning_index
1149  -           frame_conditioning_concatenate_mask_val,  # frame_conditioning_concatenate_mask
1150  -       )
1151
1152
1153      def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:

  13      from vms.utils import BaseTab
  14      from vms.config import (
  15          ASK_USER_TO_DUPLICATE_SPACE,
  16  +       SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS,
  17  +       RESOLUTION_OPTIONS,
  18  +       TRAINING_TYPES, MODEL_TYPES, MODEL_VERSIONS,
  19          DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
  20          DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
  21          DEFAULT_LEARNING_RATE,

  30          DEFAULT_AUTO_RESUME,
  31          DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
  32          DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
  33  +       DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
  34  +       HUNYUAN_VIDEO_DEFAULTS, LTX_VIDEO_DEFAULTS, WAN_DEFAULTS
  35      )
  36
  37      logger = logging.getLogger(__name__)

  52              with gr.Row():
  53                  self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
  54

  55              with gr.Row():
  56                  with gr.Column():
  57                      # Get the default model type from the first preset

 108                      self.components["model_info"] = gr.Markdown(
 109                          value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
 110                      )
 111  +
 112  +           with gr.Row():
 113  +               with gr.Column():
 114  +                   self.components["resolution"] = gr.Dropdown(
 115  +                       choices=list(RESOLUTION_OPTIONS.keys()),
 116  +                       label="Resolution",
 117  +                       value=list(RESOLUTION_OPTIONS.keys())[0],
 118  +                       info="Select the resolution for training videos"
 119  +                   )
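The dropdown above is driven entirely by the keys of RESOLUTION_OPTIONS imported from vms/config.py; that mapping is not shown in this diff, so the snippet below is only a hypothetical sketch of the shape such a mapping could take and of how a handler might resolve the selected label (all names here are assumptions, not code from the repository).

    # Hypothetical illustration only - not the actual vms/config.py definition.
    RESOLUTION_OPTIONS_EXAMPLE = {
        "SD (faster, lower VRAM)": "SD_TRAINING_BUCKETS",
        "MD (higher resolution)": "MD_TRAINING_BUCKETS",
    }

    def resolve_training_buckets(selected_label: str) -> str:
        # Mirror the UI default: fall back to the first option when the label is unknown.
        first_key = next(iter(RESOLUTION_OPTIONS_EXAMPLE))
        return RESOLUTION_OPTIONS_EXAMPLE.get(selected_label, RESOLUTION_OPTIONS_EXAMPLE[first_key])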
 120
 121              # LoRA specific parameters (will show/hide based on training type)
 122              with gr.Row(visible=True) as lora_params_row:

 142
 143                      with gr.Accordion("What is LoRA Rank?", open=False):
 144                          gr.Markdown("""
 145  +                       **LoRA Rank** determines the complexity of the LoRA adapters:
 146  +
 147  +                       - **Lower rank (16-32)**: Smaller file size, faster training, but less expressive
 148  +                       - **Medium rank (64-128)**: Good balance between quality and file size
 149  +                       - **Higher rank (256-1024)**: More expressive adapters and better quality, but larger file size
 150  +
 151  +                       Think of rank as the "capacity" of your adapter. Higher ranks can learn more complex modifications to the base model, but they require more VRAM during training and result in larger files.
 152  +
 153  +                       **Quick guide:**
 154  +                       - For Wan models: use 32-64 (Wan models work well with lower ranks)
 155  +                       - For LTX-Video: use 128-256
 156  +                       - For HunyuanVideo: use 128
 157                          """)
 158
 159                  with gr.Column():

 164                          type="value",
 165                          info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
 166                      )

 167                      with gr.Accordion("What is LoRA Alpha?", open=False):
 168                          gr.Markdown("""
 169  +                       **LoRA Alpha** controls the effective scale of the LoRA updates:
 170  +
 171  +                       - The actual scaling factor is calculated as `alpha ÷ rank`
 172  +                       - Usually set to match the rank value (alpha = rank)
 173  +                       - Higher alpha = stronger effect from the adapters
 174  +                       - Lower alpha = more subtle adapter influence
 175  +
 176  +                       **Best practice:**
 177  +                       - For most cases, set alpha equal to rank
 178  +                       - For more aggressive training, set alpha higher than rank
 179  +                       - For more conservative training, set alpha lower than rank
 180                          """)
 181  +
 182
 183              # Control specific parameters (will show/hide based on training type)
 184              with gr.Row(visible=False) as control_params_row:
 185                  self.components["control_params_row"] = control_params_row
 186                  with gr.Column():
 187                      gr.Markdown("""
 188  +                   ## 🖼️ Control Training Settings
 189  +
 190  +                   Control training enables **image-to-video generation** by teaching the model how to use an image as a guide for video creation.
 191  +                   This is ideal for turning still images into dynamic videos while preserving composition, style, and content.
 192                      """)
 193
 194              # Second row for control parameters

 204
 205                      with gr.Accordion("What is Control Conditioning?", open=False):
 206                          gr.Markdown("""
 207  +                       **Control Conditioning** allows the model to be guided by an input image, adapting the video generation based on the image content. This is used for image-to-video generation where you want to turn an image into a moving video while maintaining its style, composition, or content.
 208  +
 209  +                       - **canny**: Uses edge detection to extract outlines from images for structure-preserving video generation
 210  +                       - **custom**: Direct image conditioning without preprocessing, preserving more image details
 211                          """)
 212
 213                      with gr.Column():

 219
 220                      with gr.Accordion("What is QK Normalization?", open=False):
 221                          gr.Markdown("""
 222  +                       **QK Normalization** normalizes the query (Q) and key (K) vectors in the attention mechanism of transformers.
 223  +
 224  +                       - When enabled, allows the model to better integrate control signals with content generation
 225  +                       - Improves training stability for control models
 226  +                       - Generally recommended for control training, especially with image conditioning
 227                          """)
 228
 229              with gr.Row(visible=False) as frame_conditioning_row:

 238
 239                      with gr.Accordion("Frame Conditioning Type Explanation", open=False):
 240                          gr.Markdown("""
 241  +                       **Frame Conditioning Types** determine which frames in the video receive image conditioning:
 242  +
 243  +                       - **index**: Only applies conditioning to a single frame at the specified index
 244  +                       - **prefix**: Applies conditioning to all frames before a certain point
 245  +                       - **random**: Randomly selects frames to receive conditioning during training
 246  +                       - **first_and_last**: Only applies conditioning to the first and last frames
 247  +                       - **full**: Applies conditioning to all frames in the video
 248  +
 249  +                       For image-to-video tasks, 'index' (usually with index 0) is most common as it conditions only the first frame.
 250                          """)
 251
 252                      with gr.Column():

 268
 269                      with gr.Accordion("What is Frame Mask Concatenation?", open=False):
 270                          gr.Markdown("""
 271  +                       **Frame Mask Concatenation** adds an additional channel to the control signal that indicates which frames are being conditioned:
 272
 273  +                       - Creates a binary mask (0/1) indicating which frames receive conditioning
 274  +                       - Helps the model distinguish between conditioned and unconditioned frames
 275  +                       - Particularly useful for 'index' conditioning where only select frames are conditioned
 276  +                       - Generally improves temporal consistency between conditioned and unconditioned frames
 277                          """)
 278
 279                      with gr.Column():

 449              return None
 450
 451      def handle_new_training_start(
 452  +       self, model_type, model_version, training_type,
 453          lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 454          save_iterations, repo_id, progress=gr.Progress()
 455      ):

 470
 471          # Start training normally
 472          return self.handle_training_start(
 473  +           model_type, model_version, training_type,
 474              lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 475              save_iterations, repo_id, progress
 476          )
 477
 478      def handle_resume_training(
 479  +       self, model_type, model_version, training_type,
 480          lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 481          save_iterations, repo_id, progress=gr.Progress()
 482      ):

 491
 492          # Start training with the checkpoint
 493          return self.handle_training_start(
 494  +           model_type, model_version, training_type,
 495              lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 496              save_iterations, repo_id, progress,
 497              resume_from_checkpoint="latest"

 499
 500      def connect_events(self) -> None:
 501          """Connect event handlers to UI components"""
 502  +       # Model type change event - Update model version dropdown choices and default parameters
 503          self.components["model_type"].change(
 504              fn=self.update_model_versions,
 505              inputs=[self.components["model_type"]],
 506              outputs=[self.components["model_version"]]
 507          ).then(
 508  +           fn=self.update_model_type_and_version,
 509              inputs=[self.components["model_type"], self.components["model_version"]],
 510              outputs=[]
 511          ).then(
 512  +           # Update model info and recommended default values based on model and training type
 513  +           fn=self.update_model_info,
 514              inputs=[self.components["model_type"], self.components["training_type"]],
 515  +           outputs=[
 516  +               self.components["model_info"],
 517  +               self.components["train_steps"],
 518  +               self.components["batch_size"],
 519  +               self.components["learning_rate"],
 520  +               self.components["save_iterations"],
 521  +               self.components["lora_params_row"],
 522  +               self.components["lora_settings_row"],
 523  +               self.components["control_params_row"],
 524  +               self.components["control_settings_row"],
 525  +               self.components["frame_conditioning_row"],
 526  +               self.components["control_options_row"],
 527  +               self.components["lora_rank"],
 528  +               self.components["lora_alpha"]
 529  +           ]
 530          )
 531
 532          # Model version change event

 550                  self.components["batch_size"],
 551                  self.components["learning_rate"],
 552                  self.components["save_iterations"],
 553  +               self.components["lora_params_row"],
 554  +               self.components["lora_settings_row"],
 555  +               self.components["control_params_row"],
 556  +               self.components["control_settings_row"],
 557  +               self.components["frame_conditioning_row"],
 558  +               self.components["control_options_row"],
 559  +               self.components["lora_rank"],
 560  +               self.components["lora_alpha"]
 561              ]
 562          )
 563

 654              outputs=[]
 655          )
 656
 657  +       # Resolution change event
 658  +       self.components["resolution"].change(
 659  +           fn=lambda v: self.app.update_ui_state(resolution=v),
 660  +           inputs=[self.components["resolution"]],
 661              outputs=[]
 662          )
 663
 664          # Training control events
 665          self.components["start_btn"].click(
 666              fn=self.handle_new_training_start,
 667              inputs=[
 668                  self.components["model_type"],
 669                  self.components["model_version"],
 670                  self.components["training_type"],

 685          self.components["resume_btn"].click(
 686              fn=self.handle_resume_training,
 687              inputs=[
 688                  self.components["model_type"],
 689                  self.components["model_version"],
 690                  self.components["training_type"],

 749              # Update UI state with proper model_type first
 750              self.app.update_ui_state(model_type=model_type)
 751
 752  +           # Create a list of tuples (label, value) for the dropdown choices
 753  +           # This ensures compatibility with Gradio Dropdown component expectations
 754  +           choices_tuples = [(str(version), str(version)) for version in model_versions]
 755
 756              # Create a new dropdown with the updated choices
 757  +           if not choices_tuples:
 758                  logger.warning(f"No model versions available for {model_type}, using empty list")
 759                  # Return empty dropdown to avoid errors
 760                  return gr.Dropdown(choices=[], value=None)
 761
 762              # Ensure default_version is in model_versions
 763  +           string_versions = [str(v) for v in model_versions]
 764  +           if default_version not in string_versions and string_versions:
 765  +               default_version = string_versions[0]
 766                  logger.info(f"Default version not in choices, using first available: {default_version}")
 767
 768              # Return the updated dropdown
 769  +           logger.info(f"Returning dropdown with {len(choices_tuples)} choices")
 770  +           return gr.Dropdown(choices=choices_tuples, value=default_version)
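For context, recent Gradio versions accept either plain strings or (label, value) tuples as Dropdown choices, which is what the conversion above relies on; a minimal standalone example (version strings are made up):

    import gradio as gr

    versions = ["1.3B", "14B"]                       # assumed example values
    choices_tuples = [(v, v) for v in versions]      # label shown == value returned

    with gr.Blocks() as demo:
        gr.Dropdown(choices=choices_tuples, value=versions[0], label="Model version")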
 771          except Exception as e:
 772              # Log any exceptions for debugging
 773              logger.error(f"Error in update_model_versions: {str(e)}")

 775              return gr.Dropdown(choices=[], value=None)
 776
 777      def handle_training_start(
 778  +       self, model_type, model_version, training_type,
 779          lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 780          save_iterations, repo_id,
 781          progress=gr.Progress(),

 834              learning_rate,
 835              save_iterations,
 836              repo_id,
 837              training_type=training_internal_type,
 838              model_version=model_version,
 839              resume_from_checkpoint=resume_from,

 887          # Add general information about the selected training type
 888          if training_type == "Full Finetune":
 889              finetune_info = """
 890  +           ## 🧠 Full Finetune Mode
 891
 892  +           Full finetune mode trains all parameters of the model, requiring more VRAM but potentially enabling higher quality results.
 893
 894  +           - Requires 20-50GB+ VRAM depending on the model
 895  +           - Creates a complete standalone model (~8GB+ file size)
 896  +           - Recommended only for high-end GPUs (A100, H100, etc.)
 897  +           - Not recommended for larger models like Hunyuan Video on consumer hardware
 898              """
 899              model_info = finetune_info + "\n\n" + model_info
 900

 914              self.components["batch_size"]: params["batch_size"],
 915              self.components["learning_rate"]: params["learning_rate"],
 916              self.components["save_iterations"]: params["save_iterations"],
 917  +           self.components["lora_rank"]: params["lora_rank"],
 918  +           self.components["lora_alpha"]: params["lora_alpha"],
 919              self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
 920              self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
 921              self.components["control_params_row"]: gr.Row(visible=show_control_params),

 927      def get_model_info(self, model_type: str, training_type: str) -> str:
 928          """Get information about the selected model type and training method"""
 929          if model_type == "HunyuanVideo":
 930  +           base_info = """## HunyuanVideo Training
 931  +           - Required VRAM: ~48GB minimum
 932  +           - Recommended batch size: 1-2
 933  +           - Typical training time: 2-4 hours
 934  +           - Default resolution: 49x512x768"""
 935
 936              if training_type == "LoRA Finetune":
 937                  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"

 943                  return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 944
 945          elif model_type == "LTX-Video":
 946  +           base_info = """## LTX-Video Training
 947  +           - Recommended batch size: 1-4
 948  +           - Typical training time: 1-3 hours
 949  +           - Default resolution: 49x512x768"""
 950
 951              if training_type == "LoRA Finetune":
 952                  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"

 958                  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 959
 960          elif model_type == "Wan":
 961  +           base_info = """## Wan2.1 Training
 962  +           - Recommended batch size: 1-4
 963  +           - Typical training time: 1-3 hours
 964  +           - Default resolution: 49x512x768"""
 965
 966              if training_type == "LoRA Finetune":
 967                  return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"

 977
 978      def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
 979          """Get default training parameters for model type"""
 980  +       # Use model-specific defaults based on model_type and training_type
 981  +       model_defaults = {}
 982
 983  +       if model_type == "hunyuan_video" and training_type in HUNYUAN_VIDEO_DEFAULTS:
 984  +           model_defaults = HUNYUAN_VIDEO_DEFAULTS[training_type]
 985  +       elif model_type == "ltx_video" and training_type in LTX_VIDEO_DEFAULTS:
 986  +           model_defaults = LTX_VIDEO_DEFAULTS[training_type]
 987  +       elif model_type == "wan" and training_type in WAN_DEFAULTS:
 988  +           model_defaults = WAN_DEFAULTS[training_type]
 989
 990  +       # Build the complete params dict with defaults plus model-specific overrides
 991  +       params = {
 992  +           "train_steps": DEFAULT_NB_TRAINING_STEPS,
 993  +           "batch_size": DEFAULT_BATCH_SIZE,
 994  +           "learning_rate": DEFAULT_LEARNING_RATE,
 995  +           "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
 996  +           "lora_rank": DEFAULT_LORA_RANK_STR,
 997  +           "lora_alpha": DEFAULT_LORA_ALPHA_STR
 998  +       }
 999
1000  +       # Override with model-specific values
1001  +       params.update(model_defaults)
1002
1003  +       return params
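The new implementation is a plain defaults-plus-overrides merge; the same pattern in isolation (the values below are made up for illustration):

    defaults = {"train_steps": 500, "batch_size": 1, "learning_rate": 2e-5}
    overrides = {"learning_rate": 5e-5, "lora_rank": "32"}   # hypothetical model-specific entry

    params = dict(defaults)
    params.update(overrides)
    # -> {'train_steps': 500, 'batch_size': 1, 'learning_rate': 5e-05, 'lora_rank': '32'}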
1004
1005
1006      def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]: