jbilcke-hf HF Staff committed on
Commit
48d6121
·
1 Parent(s): 6fff6df

time to test image conditioning

docs/gradio/external_plugin--gradio_modal.md ADDED
@@ -0,0 +1,108 @@
Project description
-------------------

`gradio_modal`
==============

[![PyPI - Version](https://pypi-camo.freetls.fastly.net/19d01702f9691477566e07fbd3c8eb08188e6eae/68747470733a2f2f696d672e736869656c64732e696f2f707970692f762f67726164696f5f6d6f64616c)](https://pypi.org/project/gradio_modal/)

A popup modal component

Installation
------------

    pip install gradio_modal

Usage
-----

```python
import gradio as gr
from gradio_modal import Modal

with gr.Blocks() as demo:
    with gr.Tab("Tab 1"):
        text_1 = gr.Textbox(label="Input 1")
        text_2 = gr.Textbox(label="Input 2")
        text_1.submit(lambda x: x, text_1, text_2)
        show_btn = gr.Button("Show Modal")
        show_btn2 = gr.Button("Show Modal 2")
        gr.Examples(
            [["Text 1", "Text 2"], ["Text 3", "Text 4"]],
            inputs=[text_1, text_2],
        )
    with gr.Tab("Tab 2"):
        gr.Markdown("This is tab 2")
    with Modal(visible=False) as modal:
        for i in range(5):
            gr.Markdown("Hello world!")
    with Modal(visible=False) as modal2:
        for i in range(100):
            gr.Markdown("Hello world!")
    show_btn.click(lambda: Modal(visible=True), None, modal)
    show_btn2.click(lambda: Modal(visible=True), None, modal2)

if __name__ == "__main__":
    demo.launch()
```

`Modal`
-------

### Initialization

| name | type | default | description |
| --- | --- | --- | --- |
| `visible` | bool | `False` | If False, the modal will be hidden. |
| `elem_id` | str \| None | `None` | An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. |
| `elem_classes` | list[str] \| str \| None | `None` | An optional string or list of strings that are assigned as the class of this component in the HTML DOM. Can be used for targeting CSS styles. |
| `allow_user_close` | bool | `True` | If True, the user can close the modal (by clicking outside, clicking the X, or pressing the escape key). |
| `render` | bool | `True` | If False, the component will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later. |

### Events

| name | description |
| --- | --- |
| `blur` | This listener is triggered when the Modal is unfocused/blurred. |
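To make the `allow_user_close` option and the `blur` event concrete, here is a small illustrative sketch (not part of the added file); it assumes the `blur` listener accepts the usual Gradio `fn`/`inputs`/`outputs` arguments:

```python
import gradio as gr
from gradio_modal import Modal

with gr.Blocks() as demo:
    status = gr.Markdown("The modal has not been opened yet.")
    open_btn = gr.Button("Open Modal")

    # allow_user_close=True (the default) lets the user dismiss the modal
    # by clicking outside it, clicking the X, or pressing Escape.
    with Modal(visible=False, allow_user_close=True) as modal:
        gr.Markdown("Close this modal to trigger the `blur` event.")

    open_btn.click(lambda: Modal(visible=True), None, modal)

    # `blur` fires when the modal loses focus, i.e. when the user closes it.
    modal.blur(lambda: "The modal was closed.", None, status)

if __name__ == "__main__":
    demo.launch()
```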
vms/config.py CHANGED
@@ -304,39 +304,71 @@ DEFAULT_VALIDATION_WIDTH = 768
304
  DEFAULT_VALIDATION_NB_FRAMES = 49
305
  DEFAULT_VALIDATION_FRAMERATE = 8
306
 
307
- # it is best to use resolutions that are powers of 8
308
- # The resolution should be divisible by 32
309
- # so we cannot use 1080, 540 etc as they are not divisible by 32
310
- MEDIUM_19_9_RATIO_WIDTH = 768 # 32 * 24
311
- MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
312
-
313
- # 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
314
- # 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
315
- # 1056 = 32 * 33 (divided by 2: 544 = 17 * 32)
316
- # 1024 = 32 * 32 (divided by 2: 512 = 16 * 32)
317
  # it is important that the resolution buckets properly cover the training dataset,
318
  # or else that we exclude from the dataset videos that are out of this range
319
  # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
320
 
321
- NB_FRAMES_1 = 1 # 1
322
- NB_FRAMES_9 = 8 + 1 # 8 + 1
323
- NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
324
- NB_FRAMES_33 = 8 * 4 + 1 # 32 + 1
325
- NB_FRAMES_49 = 8 * 6 + 1 # 48 + 1
326
- NB_FRAMES_65 = 8 * 8 + 1 # 64 + 1
327
- NB_FRAMES_81 = 8 * 10 + 1 # 80 + 1
328
- NB_FRAMES_97 = 8 * 12 + 1 # 96 + 1
329
  NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
 
330
  NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
 
331
  NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
332
- NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
333
  NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
334
  NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
335
  NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
336
  NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
337
- # 256 isn't a lot by the way, especially with 60 FPS videos..
338
- # can we crank it and put more frames in here?
339
-
340
  NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
341
  NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
342
  NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
@@ -347,199 +379,155 @@ NB_FRAMES_369 = 8 * 46 + 1 # 368 + 1
347
  NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
348
  NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1
349
 
350
- SMALL_TRAINING_BUCKETS = [
351
- (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
352
- (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
353
- (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
354
- (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
355
- (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
356
- (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
357
- (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
358
- (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
359
- (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
360
- (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
361
- (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
362
- (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
363
- (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
364
- (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
365
- (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
366
- (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
367
  ]
368
 
369
- MEDIUM_19_9_RATIO_WIDTH = 928 # 32 * 29
370
- MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
371
-
372
- MEDIUM_19_9_RATIO_BUCKETS = [
373
- (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
374
- (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
375
- (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
376
- (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
377
- (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
378
- (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
379
- (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
380
- (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
381
- (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
382
- (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
383
- (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
384
- (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
385
- (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
386
- (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
387
- (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
388
- (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
389
  ]
390
 
391
- # Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune
392
- TRAINING_PRESETS = {
393
- "HunyuanVideo (normal)": {
394
- "model_type": "hunyuan_video",
395
- "training_type": "lora",
396
- "lora_rank": DEFAULT_LORA_RANK_STR,
397
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
398
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
399
- "batch_size": DEFAULT_BATCH_SIZE,
400
  "learning_rate": 2e-5,
401
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
402
- "training_buckets": SMALL_TRAINING_BUCKETS,
403
  "flow_weighting_scheme": "none",
404
- "num_gpus": DEFAULT_NUM_GPUS,
405
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
406
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
407
- },
408
- "LTX-Video (normal)": {
409
- "model_type": "ltx_video",
410
- "training_type": "lora",
411
  "lora_rank": DEFAULT_LORA_RANK_STR,
412
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
413
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
414
- "batch_size": DEFAULT_BATCH_SIZE,
415
- "learning_rate": DEFAULT_LEARNING_RATE,
416
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
417
- "training_buckets": SMALL_TRAINING_BUCKETS,
418
- "flow_weighting_scheme": "none",
419
- "num_gpus": DEFAULT_NUM_GPUS,
420
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
421
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
422
  },
423
- "LTX-Video (16:9, HQ)": {
424
- "model_type": "ltx_video",
425
- "training_type": "lora",
426
- "lora_rank": "256",
427
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
428
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
429
- "batch_size": DEFAULT_BATCH_SIZE,
430
  "learning_rate": DEFAULT_LEARNING_RATE,
431
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
432
- "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
433
- "flow_weighting_scheme": "logit_normal",
434
- "num_gpus": DEFAULT_NUM_GPUS,
435
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
436
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
437
  },
438
- "LTX-Video (Full Finetune)": {
439
- "model_type": "ltx_video",
440
- "training_type": "full-finetune",
441
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
442
- "batch_size": DEFAULT_BATCH_SIZE,
443
  "learning_rate": DEFAULT_LEARNING_RATE,
444
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
445
- "training_buckets": SMALL_TRAINING_BUCKETS,
446
- "flow_weighting_scheme": "logit_normal",
447
- "num_gpus": DEFAULT_NUM_GPUS,
448
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
449
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
450
- },
451
- "Wan-2.1-T2V (normal)": {
452
- "model_type": "wan",
453
- "training_type": "lora",
454
- "lora_rank": "32",
455
- "lora_alpha": "32",
456
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
457
- "batch_size": DEFAULT_BATCH_SIZE,
458
- "learning_rate": 5e-5,
459
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
460
- "training_buckets": SMALL_TRAINING_BUCKETS,
461
- "flow_weighting_scheme": "logit_normal",
462
- "num_gpus": DEFAULT_NUM_GPUS,
463
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
464
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
465
  },
466
- "Wan-2.1-T2V (HQ)": {
467
- "model_type": "wan",
468
- "training_type": "lora",
469
- "lora_rank": "64",
470
- "lora_alpha": "64",
471
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
472
- "batch_size": DEFAULT_BATCH_SIZE,
473
  "learning_rate": DEFAULT_LEARNING_RATE,
474
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
475
- "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
476
- "flow_weighting_scheme": "logit_normal",
477
- "num_gpus": DEFAULT_NUM_GPUS,
478
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
479
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
480
- },
481
- "Wan-2.1-I2V (Control LoRA)": {
482
- "model_type": "wan",
483
- "training_type": "control-lora",
484
- "lora_rank": "32",
485
- "lora_alpha": "32",
486
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
487
- "batch_size": DEFAULT_BATCH_SIZE,
488
- "learning_rate": 5e-5,
489
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
490
- "training_buckets": SMALL_TRAINING_BUCKETS,
491
  "flow_weighting_scheme": "logit_normal",
492
- "num_gpus": DEFAULT_NUM_GPUS,
493
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
494
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
495
- "control_type": "custom",
496
- "train_qk_norm": True,
497
- "frame_conditioning_type": "index",
498
- "frame_conditioning_index": 0,
499
- "frame_conditioning_concatenate_mask": True,
500
- "description": "Image-conditioned video generation with LoRA adapters"
501
- },
502
- "LTX-Video (Control LoRA)": {
503
- "model_type": "ltx_video",
504
- "training_type": "control-lora",
505
  "lora_rank": "128",
506
  "lora_alpha": "128",
507
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
508
- "batch_size": DEFAULT_BATCH_SIZE,
509
- "learning_rate": DEFAULT_LEARNING_RATE,
510
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
511
- "training_buckets": SMALL_TRAINING_BUCKETS,
512
- "flow_weighting_scheme": "logit_normal",
513
- "num_gpus": DEFAULT_NUM_GPUS,
514
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
515
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
516
  "control_type": "custom",
517
  "train_qk_norm": True,
518
  "frame_conditioning_type": "index",
519
  "frame_conditioning_index": 0,
520
- "frame_conditioning_concatenate_mask": True,
521
- "description": "Image-conditioned video generation with LoRA adapters"
522
  },
523
- "HunyuanVideo (Control LoRA)": {
524
- "model_type": "hunyuan_video",
525
- "training_type": "control-lora",
526
- "lora_rank": "128",
527
- "lora_alpha": "128",
528
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
529
- "batch_size": DEFAULT_BATCH_SIZE,
530
- "learning_rate": 2e-5,
531
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
532
- "training_buckets": SMALL_TRAINING_BUCKETS,
533
- "flow_weighting_scheme": "none",
534
- "num_gpus": DEFAULT_NUM_GPUS,
535
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
536
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
537
  "control_type": "custom",
538
  "train_qk_norm": True,
539
  "frame_conditioning_type": "index",
540
  "frame_conditioning_index": 0,
541
- "frame_conditioning_concatenate_mask": True,
542
- "description": "Image-conditioned video generation with HunyuanVideo and LoRA adapters"
543
  }
544
  }
545
 
@@ -567,7 +555,7 @@ class TrainingConfig:
567
  caption_column: str = "prompts.txt"
568
 
569
  id_token: Optional[str] = None
570
- video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SMALL_TRAINING_BUCKETS)
571
  video_reshape_mode: str = "center"
572
  caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
573
  caption_dropout_technique: str = "empty"
@@ -632,7 +620,7 @@ class TrainingConfig:
632
  gradient_accumulation_steps=1,
633
  lora_rank=DEFAULT_LORA_RANK,
634
  lora_alpha=DEFAULT_LORA_ALPHA,
635
- video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
636
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
637
  flow_weighting_scheme="none", # Hunyuan specific
638
  training_type="lora"
@@ -654,7 +642,7 @@ class TrainingConfig:
654
  gradient_accumulation_steps=4,
655
  lora_rank=DEFAULT_LORA_RANK,
656
  lora_alpha=DEFAULT_LORA_ALPHA,
657
- video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
658
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
659
  flow_weighting_scheme="logit_normal", # LTX specific
660
  training_type="lora"
@@ -674,7 +662,7 @@ class TrainingConfig:
674
  gradient_checkpointing=True,
675
  id_token=None,
676
  gradient_accumulation_steps=1,
677
- video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
678
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
679
  flow_weighting_scheme="logit_normal", # LTX specific
680
  training_type="full-finetune"
@@ -697,7 +685,7 @@ class TrainingConfig:
697
  lora_rank=32,
698
  lora_alpha=32,
699
  target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
700
- video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
701
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
702
  flow_weighting_scheme="logit_normal", # Wan specific
703
  training_type="lora"
 
304
  DEFAULT_VALIDATION_NB_FRAMES = 49
305
  DEFAULT_VALIDATION_FRAMERATE = 8
306
 
307
+ # you should use resolutions whose width and height are multiples of 8
308
+ # using a 16:9 aspect ratio is also strongly recommended
309
+
310
+ # SD
311
+ SD_16_9_W = 1024 # 8*128
312
+ SD_16_9_H = 576 # 8*72
313
+ SD_9_16_W = 576 # 8*72
314
+ SD_9_16_H = 1024 # 8*128
315
+
316
+ # MD (720p)
317
+ MD_16_9_W = 1280 # 8*160
318
+ MD_16_9_H = 720 # 8*90
319
+ MD_9_16_W = 720 # 8*90
320
+ MD_9_16_H = 1280 # 8*160
321
+
322
+ # HD (1080p)
323
+ HD_16_9_W = 1920 # 8*240
324
+ HD_16_9_H = 1080 # 8*135
325
+ HD_9_16_W = 1080 # 8*135
326
+ HD_9_16_H = 1920 # 8*240
327
+
328
+ # QHD (2K)
329
+ QHD_16_9_W = 2560 # 8*320
330
+ QHD_16_9_H = 1440 # 8*180
331
+ QHD_9_16_W = 1440 # 8*180
332
+ QHD_9_16_H = 2560 # 8*320
333
+
334
+ # UHD (4K)
335
+ UHD_16_9_W = 3840 # 8*480
336
+ UHD_16_9_H = 2160 # 8*270
337
+ UHD_9_16_W = 2160 # 8*270
338
+ UHD_9_16_H = 3840 # 8*480
339
+
340
  # it is important that the resolution buckets properly cover the training dataset,
341
  # or else that we exclude from the dataset videos that are out of this range
342
  # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
343
 
344
+ NB_FRAMES_1 = 1 # 1
345
+ NB_FRAMES_9 = 8 + 1 # 8 + 1
346
+ NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
347
+ NB_FRAMES_33 = 8 * 4 + 1 # 32 + 1
348
+ NB_FRAMES_49 = 8 * 6 + 1 # 48 + 1
349
+ NB_FRAMES_65 = 8 * 8 + 1 # 64 + 1
350
+ NB_FRAMES_73 = 8 * 9 + 1 # 72 + 1
351
+ NB_FRAMES_81 = 8 * 10 + 1 # 80 + 1
352
+ NB_FRAMES_89 = 8 * 11 + 1 # 88 + 1
353
+ NB_FRAMES_97 = 8 * 12 + 1 # 96 + 1
354
+ NB_FRAMES_105 = 8 * 13 + 1 # 104 + 1
355
  NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
356
+ NB_FRAMES_121 = 8 * 15 + 1 # 120 + 1
357
  NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
358
+ NB_FRAMES_137 = 8 * 17 + 1 # 136 + 1
359
  NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
360
+ NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
361
  NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
362
  NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
363
+ NB_FRAMES_201 = 8 * 25 + 1 # 200 + 1
364
+ NB_FRAMES_209 = 8 * 26 + 1 # 208 + 1
365
+ NB_FRAMES_217 = 8 * 27 + 1 # 216 + 1
366
  NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
367
+ NB_FRAMES_233 = 8 * 29 + 1 # 232 + 1
368
+ NB_FRAMES_241 = 8 * 30 + 1 # 240 + 1
369
+ NB_FRAMES_249 = 8 * 31 + 1 # 248 + 1
370
  NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
371
+ NB_FRAMES_265 = 8 * 33 + 1 # 264 + 1
 
 
372
  NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
373
  NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
374
  NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
 
379
  NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
380
  NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1
381
 
382
+ # ------ HOW BUCKETS WORK:----------
383
+ # Basically, to train or fine-tune a video model with Finetrainers, we need to specify all of the accepted video length AND size combinations (buckets), in the form: (BUCKET_CONFIGURATION_1, BUCKET_CONFIGURATION_2, ..., BUCKET_CONFIGURATION_N)
384
+ # Where a bucket is: (NUMBER_OF_FRAMES_PLUS_ONE, HEIGHT_IN_PIXELS, WIDTH_IN_PIXELS)
385
+ # For instance, for 2 seconds of a 1024x576 video at 24 frames per second, plus one frame (I think there is always an extra frame for the initial starting image), we would get:
386
+ # NUMBER_OF_FRAMES_PLUS_ONE = (2*24) + 1 = 48 + 1 = 49
387
+ # HEIGHT_IN_PIXELS = 576
388
+ # WIDTH_IN_PIXELS = 1024
389
+ # -> This would give a bucket like this: (49, 576, 1024)
390
+ #
391
+
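As a side note (not part of this commit), the bucket formula described above can be written as a small helper; the function name and signature here are illustrative only:

```python
def make_bucket(duration_s: float, fps: int, height: int, width: int) -> tuple[int, int, int]:
    """Build a (number_of_frames + 1, height, width) bucket as described in the comment above."""
    assert height % 8 == 0 and width % 8 == 0, "dimensions should be multiples of 8"
    nb_frames_plus_one = int(duration_s * fps) + 1  # one extra frame for the initial starting image
    return (nb_frames_plus_one, height, width)

# 2 seconds of 1024x576 video at 24 fps -> (49, 576, 1024)
print(make_bucket(2, 24, 576, 1024))
```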
392
+ SD_TRAINING_BUCKETS = [
393
+ (NB_FRAMES_1, SD_16_9_H, SD_16_9_W), # 1
394
+ (NB_FRAMES_9, SD_16_9_H, SD_16_9_W), # 8 + 1
395
+ (NB_FRAMES_17, SD_16_9_H, SD_16_9_W), # 16 + 1
396
+ (NB_FRAMES_33, SD_16_9_H, SD_16_9_W), # 32 + 1
397
+ (NB_FRAMES_49, SD_16_9_H, SD_16_9_W), # 48 + 1
398
+ (NB_FRAMES_65, SD_16_9_H, SD_16_9_W), # 64 + 1
399
+ (NB_FRAMES_73, SD_16_9_H, SD_16_9_W), # 72 + 1
400
+ (NB_FRAMES_81, SD_16_9_H, SD_16_9_W), # 80 + 1
401
+ (NB_FRAMES_89, SD_16_9_H, SD_16_9_W), # 88 + 1
402
+ (NB_FRAMES_97, SD_16_9_H, SD_16_9_W), # 96 + 1
403
+ (NB_FRAMES_105, SD_16_9_H, SD_16_9_W), # 104 + 1
404
+ (NB_FRAMES_113, SD_16_9_H, SD_16_9_W), # 112 + 1
405
+ (NB_FRAMES_121, SD_16_9_H, SD_16_9_W), # 120 + 1
406
+ (NB_FRAMES_129, SD_16_9_H, SD_16_9_W), # 128 + 1
407
+ (NB_FRAMES_137, SD_16_9_H, SD_16_9_W), # 136 + 1
408
+ (NB_FRAMES_145, SD_16_9_H, SD_16_9_W), # 144 + 1
409
+ (NB_FRAMES_161, SD_16_9_H, SD_16_9_W), # 160 + 1
410
+ (NB_FRAMES_177, SD_16_9_H, SD_16_9_W), # 176 + 1
411
+ (NB_FRAMES_193, SD_16_9_H, SD_16_9_W), # 192 + 1
412
+ (NB_FRAMES_201, SD_16_9_H, SD_16_9_W), # 200 + 1
413
+ (NB_FRAMES_209, SD_16_9_H, SD_16_9_W), # 208 + 1
414
+ (NB_FRAMES_217, SD_16_9_H, SD_16_9_W), # 216 + 1
415
+ (NB_FRAMES_225, SD_16_9_H, SD_16_9_W), # 224 + 1
416
+ (NB_FRAMES_233, SD_16_9_H, SD_16_9_W), # 232 + 1
417
+ (NB_FRAMES_241, SD_16_9_H, SD_16_9_W), # 240 + 1
418
+ (NB_FRAMES_249, SD_16_9_H, SD_16_9_W), # 248 + 1
419
+ (NB_FRAMES_257, SD_16_9_H, SD_16_9_W), # 256 + 1
420
+ (NB_FRAMES_265, SD_16_9_H, SD_16_9_W), # 264 + 1
421
+ (NB_FRAMES_273, SD_16_9_H, SD_16_9_W), # 272 + 1
422
  ]
423
 
424
+ # For 1280x720 images and videos (from 1 frame up to 273 frames)
425
+ MD_TRAINING_BUCKETS = [
426
+ (NB_FRAMES_1, MD_16_9_H, MD_16_9_W), # 1
427
+ (NB_FRAMES_9, MD_16_9_H, MD_16_9_W), # 8 + 1
428
+ (NB_FRAMES_17, MD_16_9_H, MD_16_9_W), # 16 + 1
429
+ (NB_FRAMES_33, MD_16_9_H, MD_16_9_W), # 32 + 1
430
+ (NB_FRAMES_49, MD_16_9_H, MD_16_9_W), # 48 + 1
431
+ (NB_FRAMES_65, MD_16_9_H, MD_16_9_W), # 64 + 1
432
+ (NB_FRAMES_73, MD_16_9_H, MD_16_9_W), # 72 + 1
433
+ (NB_FRAMES_81, MD_16_9_H, MD_16_9_W), # 80 + 1
434
+ (NB_FRAMES_89, MD_16_9_H, MD_16_9_W), # 88 + 1
435
+ (NB_FRAMES_97, MD_16_9_H, MD_16_9_W), # 96 + 1
436
+ (NB_FRAMES_105, MD_16_9_H, MD_16_9_W), # 104 + 1
437
+ (NB_FRAMES_113, MD_16_9_H, MD_16_9_W), # 112 + 1
438
+ (NB_FRAMES_121, MD_16_9_H, MD_16_9_W), # 120 + 1
439
+ (NB_FRAMES_129, MD_16_9_H, MD_16_9_W), # 128 + 1
440
+ (NB_FRAMES_137, MD_16_9_H, MD_16_9_W), # 136 + 1
441
+ (NB_FRAMES_145, MD_16_9_H, MD_16_9_W), # 144 + 1
442
+ (NB_FRAMES_161, MD_16_9_H, MD_16_9_W), # 160 + 1
443
+ (NB_FRAMES_177, MD_16_9_H, MD_16_9_W), # 176 + 1
444
+ (NB_FRAMES_193, MD_16_9_H, MD_16_9_W), # 192 + 1
445
+ (NB_FRAMES_201, MD_16_9_H, MD_16_9_W), # 200 + 1
446
+ (NB_FRAMES_209, MD_16_9_H, MD_16_9_W), # 208 + 1
447
+ (NB_FRAMES_217, MD_16_9_H, MD_16_9_W), # 216 + 1
448
+ (NB_FRAMES_225, MD_16_9_H, MD_16_9_W), # 224 + 1
449
+ (NB_FRAMES_233, MD_16_9_H, MD_16_9_W), # 232 + 1
450
+ (NB_FRAMES_241, MD_16_9_H, MD_16_9_W), # 240 + 1
451
+ (NB_FRAMES_249, MD_16_9_H, MD_16_9_W), # 248 + 1
452
+ (NB_FRAMES_257, MD_16_9_H, MD_16_9_W), # 256 + 1
453
+ (NB_FRAMES_265, MD_16_9_H, MD_16_9_W), # 264 + 1
454
+ (NB_FRAMES_273, MD_16_9_H, MD_16_9_W), # 272 + 1
455
  ]
456
 
457
+
458
+ # Model specific default parameters
459
+ # These are used instead of the previous TRAINING_PRESETS
460
+
461
+ # Resolution buckets for different models
462
+ RESOLUTION_OPTIONS = {
463
+ "SD (1024x576)": "SD_TRAINING_BUCKETS",
464
+ "HD (1280x720)": "MD_TRAINING_BUCKETS"
465
+ }
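Note that the values in `RESOLUTION_OPTIONS` are the *names* of the bucket lists, not the lists themselves, so callers have to resolve them. A minimal sketch (illustrative only, mirroring the lookup done in the training service later in this commit):

```python
def resolve_buckets(resolution_label: str) -> list:
    """Map a UI resolution label to its bucket list, falling back to the SD buckets."""
    bucket_name = RESOLUTION_OPTIONS.get(resolution_label, "SD_TRAINING_BUCKETS")
    return MD_TRAINING_BUCKETS if bucket_name == "MD_TRAINING_BUCKETS" else SD_TRAINING_BUCKETS

buckets = resolve_buckets("HD (1280x720)")  # -> MD_TRAINING_BUCKETS
```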
466
+
467
+ # Default parameters for Hunyuan Video
468
+ HUNYUAN_VIDEO_DEFAULTS = {
469
+ "lora": {
470
  "learning_rate": 2e-5,
 
 
471
  "flow_weighting_scheme": "none",
472
  "lora_rank": DEFAULT_LORA_RANK_STR,
473
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR
474
  },
475
+ "control-lora": {
476
+ "learning_rate": 2e-5,
477
+ "flow_weighting_scheme": "none",
478
+ "lora_rank": "128",
479
+ "lora_alpha": "128",
480
+ "control_type": "custom",
481
+ "train_qk_norm": True,
482
+ "frame_conditioning_type": "index",
483
+ "frame_conditioning_index": 0,
484
+ "frame_conditioning_concatenate_mask": True
485
+ }
486
+ }
487
+
488
+ # Default parameters for LTX Video
489
+ LTX_VIDEO_DEFAULTS = {
490
+ "lora": {
491
  "learning_rate": DEFAULT_LEARNING_RATE,
492
+ "flow_weighting_scheme": "none",
493
+ "lora_rank": DEFAULT_LORA_RANK_STR,
494
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR
495
  },
496
+ "full-finetune": {
497
  "learning_rate": DEFAULT_LEARNING_RATE,
498
+ "flow_weighting_scheme": "logit_normal"
499
  },
500
+ "control-lora": {
501
  "learning_rate": DEFAULT_LEARNING_RATE,
502
  "flow_weighting_scheme": "logit_normal",
503
  "lora_rank": "128",
504
  "lora_alpha": "128",
505
  "control_type": "custom",
506
  "train_qk_norm": True,
507
  "frame_conditioning_type": "index",
508
  "frame_conditioning_index": 0,
509
+ "frame_conditioning_concatenate_mask": True
510
+ }
511
+ }
512
+
513
+ # Default parameters for Wan
514
+ WAN_DEFAULTS = {
515
+ "lora": {
516
+ "learning_rate": 5e-5,
517
+ "flow_weighting_scheme": "logit_normal",
518
+ "lora_rank": "32",
519
+ "lora_alpha": "32"
520
  },
521
+ "control-lora": {
522
+ "learning_rate": 5e-5,
523
+ "flow_weighting_scheme": "logit_normal",
524
+ "lora_rank": "32",
525
+ "lora_alpha": "32",
526
  "control_type": "custom",
527
  "train_qk_norm": True,
528
  "frame_conditioning_type": "index",
529
  "frame_conditioning_index": 0,
530
+ "frame_conditioning_concatenate_mask": True
 
531
  }
532
  }
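Since these per-model dictionaries replace the former `TRAINING_PRESETS`, defaults are now looked up per model type and training type. A minimal sketch of such a lookup (illustrative only; the dictionary and helper names below are hypothetical, the keys are the ones defined above):

```python
# Hypothetical helper mapping internal model type names to their defaults dictionaries.
MODEL_DEFAULTS = {
    "hunyuan_video": HUNYUAN_VIDEO_DEFAULTS,
    "ltx_video": LTX_VIDEO_DEFAULTS,
    "wan": WAN_DEFAULTS,
}

def get_defaults(model_type: str, training_type: str) -> dict:
    """Return default hyper-parameters for a (model type, training type) pair, or {} if unknown."""
    return MODEL_DEFAULTS.get(model_type, {}).get(training_type, {})

# e.g. get_defaults("wan", "control-lora")["lora_rank"] == "32"
```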
533
 
 
555
  caption_column: str = "prompts.txt"
556
 
557
  id_token: Optional[str] = None
558
+ video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SD_TRAINING_BUCKETS)
559
  video_reshape_mode: str = "center"
560
  caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
561
  caption_dropout_technique: str = "empty"
 
620
  gradient_accumulation_steps=1,
621
  lora_rank=DEFAULT_LORA_RANK,
622
  lora_alpha=DEFAULT_LORA_ALPHA,
623
+ video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
624
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
625
  flow_weighting_scheme="none", # Hunyuan specific
626
  training_type="lora"
 
642
  gradient_accumulation_steps=4,
643
  lora_rank=DEFAULT_LORA_RANK,
644
  lora_alpha=DEFAULT_LORA_ALPHA,
645
+ video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
646
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
647
  flow_weighting_scheme="logit_normal", # LTX specific
648
  training_type="lora"
 
662
  gradient_checkpointing=True,
663
  id_token=None,
664
  gradient_accumulation_steps=1,
665
+ video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
666
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
667
  flow_weighting_scheme="logit_normal", # LTX specific
668
  training_type="full-finetune"
 
685
  lora_rank=32,
686
  lora_alpha=32,
687
  target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
688
+ video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
689
  caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
690
  flow_weighting_scheme="logit_normal", # Wan specific
691
  training_type="lora"
vms/ui/app_ui.py CHANGED
@@ -9,8 +9,8 @@ from typing import Any, Optional, Dict, List, Union, Tuple
9
 
10
  from vms.config import (
11
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
12
- TRAINING_PRESETS,
13
- MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES, MODEL_VERSIONS,
14
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
15
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
  DEFAULT_LEARNING_RATE,
@@ -23,6 +23,7 @@ from vms.config import (
23
  DEFAULT_NB_TRAINING_STEPS,
24
  DEFAULT_NB_LR_WARMUP_STEPS,
25
  DEFAULT_AUTO_RESUME,
 
26
 
27
  get_project_paths,
28
  generate_model_project_id,
@@ -363,7 +364,6 @@ class AppUI:
363
  self.project_tabs["train_tab"].components["resume_btn"],
364
  self.project_tabs["train_tab"].components["stop_btn"],
365
  self.project_tabs["train_tab"].components["delete_checkpoints_btn"],
366
- self.project_tabs["train_tab"].components["training_preset"],
367
  self.project_tabs["train_tab"].components["model_type"],
368
  self.project_tabs["train_tab"].components["model_version"],
369
  self.project_tabs["train_tab"].components["training_type"],
@@ -377,7 +377,8 @@ class AppUI:
377
  self.project_tabs["train_tab"].components["num_gpus"],
378
  self.project_tabs["train_tab"].components["precomputation_items"],
379
  self.project_tabs["train_tab"].components["lr_warmup_steps"],
380
- self.project_tabs["train_tab"].components["auto_resume"]
 
381
  ]
382
  )
383
 
@@ -485,7 +486,7 @@ class AppUI:
485
 
486
  # Copy other parameters
487
  for param in ["lora_rank", "lora_alpha", "train_steps",
488
- "batch_size", "learning_rate", "save_iterations", "training_preset"]:
489
  if param in recovery_ui:
490
  ui_state[param] = recovery_ui[param]
491
 
@@ -544,21 +545,22 @@ class AppUI:
544
  model_version_val = available_model_versions[0]
545
  logger.info(f"Using first available model version: {model_version_val}")
546
 
547
- # IMPORTANT: Create a new list of simple strings for the dropdown choices
548
- # This ensures each choice is a single string, not a tuple or other structure
549
- simple_choices = [str(version) for version in available_model_versions]
550
 
551
  # Update the dropdown choices directly in the UI component
552
  try:
553
- self.project_tabs["train_tab"].components["model_version"].choices = simple_choices
554
- logger.info(f"Updated model_version dropdown choices: {len(simple_choices)} options")
555
  except Exception as e:
556
  logger.error(f"Error updating model_version dropdown: {str(e)}")
557
  else:
558
  logger.warning(f"No versions available for model type: {model_type_val}")
559
- # Set empty choices to avoid errors
560
  try:
561
  self.project_tabs["train_tab"].components["model_version"].choices = []
 
562
  except Exception as e:
563
  logger.error(f"Error setting empty model_version choices: {str(e)}")
564
 
@@ -577,11 +579,8 @@ class AppUI:
577
  training_type_val = list(TRAINING_TYPES.keys())[0]
578
  logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")
579
 
580
- # Validate training preset
581
- training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
582
- if training_preset not in TRAINING_PRESETS:
583
- training_preset = list(TRAINING_PRESETS.keys())[0]
584
- logger.warning(f"Invalid training preset '{training_preset}', using default: {training_preset}")
585
 
586
  lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
587
  lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
@@ -616,7 +615,6 @@ class AppUI:
616
  resume_btn,
617
  stop_btn,
618
  delete_checkpoints_btn,
619
- training_preset,
620
  model_type_val,
621
  model_version_val,
622
  training_type_val,
@@ -630,7 +628,8 @@ class AppUI:
630
  num_gpus_val,
631
  precomputation_items_val,
632
  lr_warmup_steps_val,
633
- auto_resume_val
 
634
  )
635
 
636
  def initialize_ui_from_state(self):
@@ -650,7 +649,6 @@ class AppUI:
650
 
651
  # Return values in order matching the outputs in app.load
652
  return (
653
- ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
654
  model_type,
655
  model_version,
656
  ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
@@ -659,7 +657,8 @@ class AppUI:
659
  ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
660
  ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
661
  ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
662
- ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
 
663
  )
664
 
665
  def update_ui_state(self, **kwargs):
 
9
 
10
  from vms.config import (
11
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
12
+ MODEL_TYPES, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS, TRAINING_TYPES, MODEL_VERSIONS,
13
+ RESOLUTION_OPTIONS,
14
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
15
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
  DEFAULT_LEARNING_RATE,
 
23
  DEFAULT_NB_TRAINING_STEPS,
24
  DEFAULT_NB_LR_WARMUP_STEPS,
25
  DEFAULT_AUTO_RESUME,
26
+ HUNYUAN_VIDEO_DEFAULTS, LTX_VIDEO_DEFAULTS, WAN_DEFAULTS,
27
 
28
  get_project_paths,
29
  generate_model_project_id,
 
364
  self.project_tabs["train_tab"].components["resume_btn"],
365
  self.project_tabs["train_tab"].components["stop_btn"],
366
  self.project_tabs["train_tab"].components["delete_checkpoints_btn"],
 
367
  self.project_tabs["train_tab"].components["model_type"],
368
  self.project_tabs["train_tab"].components["model_version"],
369
  self.project_tabs["train_tab"].components["training_type"],
 
377
  self.project_tabs["train_tab"].components["num_gpus"],
378
  self.project_tabs["train_tab"].components["precomputation_items"],
379
  self.project_tabs["train_tab"].components["lr_warmup_steps"],
380
+ self.project_tabs["train_tab"].components["auto_resume"],
381
+ self.project_tabs["train_tab"].components["resolution"]
382
  ]
383
  )
384
 
 
486
 
487
  # Copy other parameters
488
  for param in ["lora_rank", "lora_alpha", "train_steps",
489
+ "batch_size", "learning_rate", "save_iterations"]:
490
  if param in recovery_ui:
491
  ui_state[param] = recovery_ui[param]
492
 
 
545
  model_version_val = available_model_versions[0]
546
  logger.info(f"Using first available model version: {model_version_val}")
547
 
548
+ # IMPORTANT: Create a new list of tuples (label, value) for the dropdown choices
549
+ # This ensures compatibility with Gradio Dropdown component expectations
550
+ choices_tuples = [(str(version), str(version)) for version in available_model_versions]
551
 
552
  # Update the dropdown choices directly in the UI component
553
  try:
554
+ self.project_tabs["train_tab"].components["model_version"].choices = choices_tuples
555
+ logger.info(f"Updated model_version dropdown choices: {len(choices_tuples)} options")
556
  except Exception as e:
557
  logger.error(f"Error updating model_version dropdown: {str(e)}")
558
  else:
559
  logger.warning(f"No versions available for model type: {model_type_val}")
560
+ # Set empty choices as an empty list of tuples to avoid errors
561
  try:
562
  self.project_tabs["train_tab"].components["model_version"].choices = []
563
+ logger.info("Set empty model_version dropdown choices")
564
  except Exception as e:
565
  logger.error(f"Error setting empty model_version choices: {str(e)}")
566
 
 
579
  training_type_val = list(TRAINING_TYPES.keys())[0]
580
  logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")
581
 
582
+ # Get resolution value
583
+ resolution_val = ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
584
 
585
  lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
586
  lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
 
615
  resume_btn,
616
  stop_btn,
617
  delete_checkpoints_btn,
 
618
  model_type_val,
619
  model_version_val,
620
  training_type_val,
 
628
  num_gpus_val,
629
  precomputation_items_val,
630
  lr_warmup_steps_val,
631
+ auto_resume_val,
632
+ resolution_val
633
  )
634
 
635
  def initialize_ui_from_state(self):
 
649
 
650
  # Return values in order matching the outputs in app.load
651
  return (
 
652
  model_type,
653
  model_version,
654
  ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
 
657
  ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
658
  ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
659
  ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
660
+ ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
661
+ ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
662
  )
663
 
664
  def update_ui_state(self, **kwargs):
vms/ui/project/services/training.py CHANGED
@@ -22,7 +22,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
22
  from huggingface_hub import upload_folder, create_repo
23
 
24
  from vms.config import (
25
- TrainingConfig, TRAINING_PRESETS,
26
  STORAGE_PATH, HF_API_TOKEN,
27
  MODEL_TYPES, TRAINING_TYPES, MODEL_VERSIONS,
28
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
@@ -228,7 +228,7 @@ class TrainingService:
228
  "batch_size": DEFAULT_BATCH_SIZE,
229
  "learning_rate": DEFAULT_LEARNING_RATE,
230
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
231
- "training_preset": list(TRAINING_PRESETS.keys())[0],
232
  "num_gpus": DEFAULT_NUM_GPUS,
233
  "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
234
  "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
@@ -354,10 +354,10 @@ class TrainingService:
354
  merged_state["training_type"] = default_state["training_type"]
355
  logger.warning(f"Invalid training type in saved state, using default")
356
 
357
- # Validate training_preset is in available choices
358
- if merged_state["training_preset"] not in TRAINING_PRESETS:
359
- merged_state["training_preset"] = default_state["training_preset"]
360
- logger.warning(f"Invalid training preset in saved state, using default")
361
 
362
  # Validate lora_rank is in allowed values
363
  if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
@@ -566,7 +566,6 @@ class TrainingService:
566
  learning_rate: float,
567
  save_iterations: int,
568
  repo_id: str,
569
- preset_name: str,
570
  training_type: str = DEFAULT_TRAINING_TYPE,
571
  model_version: str = "",
572
  resume_from_checkpoint: Optional[str] = None,
@@ -577,7 +576,6 @@ class TrainingService:
577
  ) -> Tuple[str, str]:
578
  """Start training with finetrainers"""
579
 
580
- training_path
581
  self.clear_logs()
582
 
583
  if not model_type:
@@ -646,11 +644,24 @@ class TrainingService:
646
  #if progress:
647
  # progress(0.25, desc="Creating dataset configuration")
648
 
649
- # Get preset configuration
650
- preset = TRAINING_PRESETS[preset_name]
651
- training_buckets = preset["training_buckets"]
652
- flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
653
- preset_training_type = preset.get("training_type", "lora")
 
654
 
655
  # Get the custom prompt prefix from the tabs
656
  custom_prompt_prefix = None
@@ -1117,7 +1128,7 @@ class TrainingService:
1117
  "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
1118
  "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
1119
  "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1120
- "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
1121
  "repo_id": "", # Default empty repo ID,
1122
  "auto_resume": ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
1123
  }
@@ -1190,7 +1201,7 @@ class TrainingService:
1190
  "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
1191
  "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
1192
  "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1193
- "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
1194
  "auto_resume": params.get("auto_resume", DEFAULT_AUTO_RESUME)
1195
  })
1196
 
@@ -1211,7 +1222,6 @@ class TrainingService:
1211
  save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1212
  model_version=params.get('model_version', ''),
1213
  repo_id=params.get('repo_id', ''),
1214
- preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
1215
  training_type=training_type_internal,
1216
  resume_from_checkpoint="latest"
1217
  )
 
22
  from huggingface_hub import upload_folder, create_repo
23
 
24
  from vms.config import (
25
+ TrainingConfig, RESOLUTION_OPTIONS, SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS,
26
  STORAGE_PATH, HF_API_TOKEN,
27
  MODEL_TYPES, TRAINING_TYPES, MODEL_VERSIONS,
28
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
 
228
  "batch_size": DEFAULT_BATCH_SIZE,
229
  "learning_rate": DEFAULT_LEARNING_RATE,
230
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
231
+ "resolution": list(RESOLUTION_OPTIONS.keys())[0],
232
  "num_gpus": DEFAULT_NUM_GPUS,
233
  "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
234
  "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
 
354
  merged_state["training_type"] = default_state["training_type"]
355
  logger.warning(f"Invalid training type in saved state, using default")
356
 
357
+ # Validate resolution is in available choices
358
+ if "resolution" in merged_state and merged_state["resolution"] not in RESOLUTION_OPTIONS:
359
+ merged_state["resolution"] = default_state["resolution"]
360
+ logger.warning(f"Invalid resolution in saved state, using default")
361
 
362
  # Validate lora_rank is in allowed values
363
  if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
 
566
  learning_rate: float,
567
  save_iterations: int,
568
  repo_id: str,
 
569
  training_type: str = DEFAULT_TRAINING_TYPE,
570
  model_version: str = "",
571
  resume_from_checkpoint: Optional[str] = None,
 
576
  ) -> Tuple[str, str]:
577
  """Start training with finetrainers"""
578
 
 
579
  self.clear_logs()
580
 
581
  if not model_type:
 
644
  #if progress:
645
  # progress(0.25, desc="Creating dataset configuration")
646
 
647
+ # Get resolution configuration from UI state
648
+ ui_state = self.load_ui_state()
649
+ resolution_option = ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0])
650
+ training_buckets_name = RESOLUTION_OPTIONS.get(resolution_option, "SD_TRAINING_BUCKETS")
651
+
652
+ # Determine which buckets to use based on the selected resolution
653
+ if training_buckets_name == "SD_TRAINING_BUCKETS":
654
+ training_buckets = SD_TRAINING_BUCKETS
655
+ elif training_buckets_name == "MD_TRAINING_BUCKETS":
656
+ training_buckets = MD_TRAINING_BUCKETS
657
+ else:
658
+ training_buckets = SD_TRAINING_BUCKETS # Default fallback
659
+
660
+ # Determine flow weighting scheme based on model type
661
+ if model_type == "hunyuan_video":
662
+ flow_weighting_scheme = "none"
663
+ else:
664
+ flow_weighting_scheme = "logit_normal"
665
 
666
  # Get the custom prompt prefix from the tabs
667
  custom_prompt_prefix = None
 
1128
  "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
1129
  "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
1130
  "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1131
+ "resolution": ui_state.get("resolution", list(RESOLUTION_OPTIONS.keys())[0]),
1132
  "repo_id": "", # Default empty repo ID,
1133
  "auto_resume": ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)
1134
  }
 
1201
  "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
1202
  "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
1203
  "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1204
+ "resolution": params.get('resolution', list(RESOLUTION_OPTIONS.keys())[0]),
1205
  "auto_resume": params.get("auto_resume", DEFAULT_AUTO_RESUME)
1206
  })
1207
 
 
1222
  save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1223
  model_version=params.get('model_version', ''),
1224
  repo_id=params.get('repo_id', ''),
 
1225
  training_type=training_type_internal,
1226
  resume_from_checkpoint="latest"
1227
  )
vms/ui/project/tabs/train_tab.py CHANGED
@@ -13,8 +13,9 @@ from pathlib import Path
13
  from vms.utils import BaseTab
14
  from vms.config import (
15
  ASK_USER_TO_DUPLICATE_SPACE,
16
- SMALL_TRAINING_BUCKETS,
17
- TRAINING_PRESETS, TRAINING_TYPES, MODEL_TYPES, MODEL_VERSIONS,
 
18
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
19
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
20
  DEFAULT_LEARNING_RATE,
@@ -29,7 +30,8 @@ from vms.config import (
29
  DEFAULT_AUTO_RESUME,
30
  DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
31
  DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
32
- DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK
 
33
  )
34
 
35
  logger = logging.getLogger(__name__)
@@ -50,15 +52,6 @@ class TrainTab(BaseTab):
50
  with gr.Row():
51
  self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
52
 
53
- with gr.Row():
54
- with gr.Column():
55
- self.components["training_preset"] = gr.Dropdown(
56
- choices=list(TRAINING_PRESETS.keys()),
57
- label="Training Preset",
58
- value=list(TRAINING_PRESETS.keys())[0]
59
- )
60
- self.components["preset_info"] = gr.Markdown()
61
-
62
  with gr.Row():
63
  with gr.Column():
64
  # Get the default model type from the first preset
@@ -115,6 +108,15 @@ class TrainTab(BaseTab):
115
  self.components["model_info"] = gr.Markdown(
116
  value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
117
  )
118
 
119
  # LoRA specific parameters (will show/hide based on training type)
120
  with gr.Row(visible=True) as lora_params_row:
@@ -140,18 +142,18 @@ class TrainTab(BaseTab):
140
 
141
  with gr.Accordion("What is LoRA Rank?", open=False):
142
  gr.Markdown("""
143
- **LoRA Rank** determines the complexity of the LoRA adapters:
144
-
145
- - **Lower rank (16-32)**: Smaller file size, faster training, but less expressive
146
- - **Medium rank (64-128)**: Good balance between quality and file size
147
- - **Higher rank (256-1024)**: More expressive adapters, better quality but larger file size
148
-
149
- Think of rank as the "capacity" of your adapter. Higher ranks can learn more complex modifications to the base model but require more VRAM during training and result in larger files.
150
-
151
- **Quick guide:**
152
- - For Wan models: Use 32-64 (Wan models work well with lower ranks)
153
- - For LTX-Video: Use 128-256
154
- - For Hunyuan Video: Use 128
155
  """)
156
 
157
  with gr.Column():
@@ -162,32 +164,31 @@ class TrainTab(BaseTab):
162
  type="value",
163
  info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
164
  )
165
-
166
  with gr.Accordion("What is LoRA Alpha?", open=False):
167
  gr.Markdown("""
168
- **LoRA Alpha** controls the effective scale of the LoRA updates:
169
-
170
- - The actual scaling factor is calculated as `alpha ÷ rank`
171
- - Usually set to match the rank value (alpha = rank)
172
- - Higher alpha = stronger effect from the adapters
173
- - Lower alpha = more subtle adapter influence
174
-
175
- **Best practice:**
176
- - For most cases, set alpha equal to rank
177
- - For more aggressive training, set alpha higher than rank
178
- - For more conservative training, set alpha lower than rank
179
  """)
180
-
181
 
182
  # Control specific parameters (will show/hide based on training type)
183
  with gr.Row(visible=False) as control_params_row:
184
  self.components["control_params_row"] = control_params_row
185
  with gr.Column():
186
  gr.Markdown("""
187
- ## 🖼️ Control Training Settings
188
-
189
- Control training enables **image-to-video generation** by teaching the model how to use an image as a guide for video creation.
190
- This is ideal for turning still images into dynamic videos while preserving composition, style, and content.
191
  """)
192
 
193
  # Second row for control parameters
@@ -203,10 +204,10 @@ class TrainTab(BaseTab):
203
 
204
  with gr.Accordion("What is Control Conditioning?", open=False):
205
  gr.Markdown("""
206
- **Control Conditioning** allows the model to be guided by an input image, adapting the video generation based on the image content. This is used for image-to-video generation where you want to turn an image into a moving video while maintaining its style, composition or content.
207
-
208
- - **canny**: Uses edge detection to extract outlines from images for structure-preserving video generation
209
- - **custom**: Direct image conditioning without preprocessing, preserving more image details
210
  """)
211
 
212
  with gr.Column():
@@ -218,11 +219,11 @@ class TrainTab(BaseTab):
218
 
219
  with gr.Accordion("What is QK Normalization?", open=False):
220
  gr.Markdown("""
221
- **QK Normalization** refers to normalizing the query and key values in the attention mechanism of transformers.
222
-
223
- - When enabled, allows the model to better integrate control signals with content generation
224
- - Improves training stability for control models
225
- - Generally recommended for control training, especially with image conditioning
226
  """)
227
 
228
  with gr.Row(visible=False) as frame_conditioning_row:
@@ -237,15 +238,15 @@ class TrainTab(BaseTab):
237
 
238
  with gr.Accordion("Frame Conditioning Type Explanation", open=False):
239
  gr.Markdown("""
240
- **Frame Conditioning Types** determine which frames in the video receive image conditioning:
241
-
242
- - **index**: Only applies conditioning to a single frame at the specified index
243
- - **prefix**: Applies conditioning to all frames before a certain point
244
- - **random**: Randomly selects frames to receive conditioning during training
245
- - **first_and_last**: Only applies conditioning to the first and last frames
246
- - **full**: Applies conditioning to all frames in the video
247
-
248
- For image-to-video tasks, 'index' (usually with index 0) is most common as it conditions only the first frame.
249
  """)
250
 
251
  with gr.Column():
@@ -267,12 +268,12 @@ class TrainTab(BaseTab):
267
 
268
  with gr.Accordion("What is Frame Mask Concatenation?", open=False):
269
  gr.Markdown("""
270
- **Frame Mask Concatenation** adds an additional channel to the control signal that indicates which frames are being conditioned:
271
 
272
- - Creates a binary mask (0/1) indicating which frames receive conditioning
273
- - Helps the model distinguish between conditioned and unconditioned frames
274
- - Particularly useful for 'index' conditioning where only select frames are conditioned
275
- - Generally improves temporal consistency between conditioned and unconditioned frames
276
  """)
277
 
278
  with gr.Column():
@@ -448,7 +449,7 @@ class TrainTab(BaseTab):
448
  return None
449
 
450
  def handle_new_training_start(
451
- self, preset, model_type, model_version, training_type,
452
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
453
  save_iterations, repo_id, progress=gr.Progress()
454
  ):
@@ -469,13 +470,13 @@ class TrainTab(BaseTab):
469
 
470
  # Start training normally
471
  return self.handle_training_start(
472
- preset, model_type, model_version, training_type,
473
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
474
  save_iterations, repo_id, progress
475
  )
476
 
477
  def handle_resume_training(
478
- self, preset, model_type, model_version, training_type,
479
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
480
  save_iterations, repo_id, progress=gr.Progress()
481
  ):
@@ -490,7 +491,7 @@ class TrainTab(BaseTab):
490
 
491
  # Start training with the checkpoint
492
  return self.handle_training_start(
493
- preset, model_type, model_version, training_type,
494
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
495
  save_iterations, repo_id, progress,
496
  resume_from_checkpoint="latest"
@@ -498,20 +499,34 @@ class TrainTab(BaseTab):
498
 
499
  def connect_events(self) -> None:
500
  """Connect event handlers to UI components"""
501
- # Model type change event - Update model version dropdown choices
502
  self.components["model_type"].change(
503
  fn=self.update_model_versions,
504
  inputs=[self.components["model_type"]],
505
  outputs=[self.components["model_version"]]
506
  ).then(
507
- fn=self.update_model_type_and_version, # Add this new function
508
  inputs=[self.components["model_type"], self.components["model_version"]],
509
  outputs=[]
510
  ).then(
511
- # Use get_model_info instead of update_model_info
512
- fn=self.get_model_info,
513
  inputs=[self.components["model_type"], self.components["training_type"]],
514
- outputs=[self.components["model_info"]]
515
  )
516
 
517
  # Model version change event
@@ -535,7 +550,14 @@ class TrainTab(BaseTab):
535
  self.components["batch_size"],
536
  self.components["learning_rate"],
537
  self.components["save_iterations"],
538
- self.components["lora_params_row"]
539
  ]
540
  )
541
 
@@ -632,50 +654,17 @@ class TrainTab(BaseTab):
632
  outputs=[]
633
  )
634
 
635
- # Training preset change event
636
- self.components["training_preset"].change(
637
- fn=lambda v: self.app.update_ui_state(training_preset=v),
638
- inputs=[self.components["training_preset"]],
639
  outputs=[]
640
- ).then(
641
- fn=self.update_training_params,
642
- inputs=[self.components["training_preset"]],
643
- outputs=[
644
- self.components["model_type"],
645
- self.components["training_type"],
646
- self.components["lora_rank"],
647
- self.components["lora_alpha"],
648
- self.components["train_steps"],
649
- self.components["batch_size"],
650
- self.components["learning_rate"],
651
- self.components["save_iterations"],
652
- self.components["preset_info"],
653
- self.components["lora_params_row"],
654
- self.components["lora_settings_row"],
655
- self.components["num_gpus"],
656
- self.components["precomputation_items"],
657
- self.components["lr_warmup_steps"],
658
- # Add model_version to the outputs
659
- self.components["model_version"],
660
- # Control parameters rows visibility
661
- self.components["control_params_row"],
662
- self.components["control_settings_row"],
663
- self.components["frame_conditioning_row"],
664
- self.components["control_options_row"],
665
- # Control parameter values
666
- self.components["control_type"],
667
- self.components["train_qk_norm"],
668
- self.components["frame_conditioning_type"],
669
- self.components["frame_conditioning_index"],
670
- self.components["frame_conditioning_concatenate_mask"],
671
- ]
672
  )
673
 
674
  # Training control events
675
  self.components["start_btn"].click(
676
  fn=self.handle_new_training_start,
677
  inputs=[
678
- self.components["training_preset"],
679
  self.components["model_type"],
680
  self.components["model_version"],
681
  self.components["training_type"],
@@ -696,7 +685,6 @@ class TrainTab(BaseTab):
696
  self.components["resume_btn"].click(
697
  fn=self.handle_resume_training,
698
  inputs=[
699
- self.components["training_preset"],
700
  self.components["model_type"],
701
  self.components["model_version"],
702
  self.components["training_type"],
@@ -761,23 +749,25 @@ class TrainTab(BaseTab):
761
  # Update UI state with proper model_type first
762
  self.app.update_ui_state(model_type=model_type)
763
 
764
- # Ensure model_versions is a simple list of strings
765
- model_versions = [str(version) for version in model_versions]
 
766
 
767
  # Create a new dropdown with the updated choices
768
- if not model_versions:
769
  logger.warning(f"No model versions available for {model_type}, using empty list")
770
  # Return empty dropdown to avoid errors
771
  return gr.Dropdown(choices=[], value=None)
772
 
773
  # Ensure default_version is in model_versions
774
- if default_version not in model_versions and model_versions:
775
- default_version = model_versions[0]
 
776
  logger.info(f"Default version not in choices, using first available: {default_version}")
777
 
778
  # Return the updated dropdown
779
- logger.info(f"Returning dropdown with {len(model_versions)} choices")
780
- return gr.Dropdown(choices=model_versions, value=default_version)
781
  except Exception as e:
782
  # Log any exceptions for debugging
783
  logger.error(f"Error in update_model_versions: {str(e)}")
@@ -785,7 +775,7 @@ class TrainTab(BaseTab):
785
  return gr.Dropdown(choices=[], value=None)
786
 
787
  def handle_training_start(
788
- self, preset, model_type, model_version, training_type,
789
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
790
  save_iterations, repo_id,
791
  progress=gr.Progress(),
@@ -844,7 +834,6 @@ class TrainTab(BaseTab):
844
  learning_rate,
845
  save_iterations,
846
  repo_id,
847
- preset_name=preset,
848
  training_type=training_internal_type,
849
  model_version=model_version,
850
  resume_from_checkpoint=resume_from,
@@ -898,14 +887,14 @@ class TrainTab(BaseTab):
898
  # Add general information about the selected training type
899
  if training_type == "Full Finetune":
900
  finetune_info = """
901
- ## 🧠 Full Finetune Mode
902
 
903
- Full finetune mode trains all parameters of the model, requiring more VRAM but potentially enabling higher quality results.
904
 
905
- - Requires 20-50GB+ VRAM depending on model
906
- - Creates a complete standalone model (~8GB+ file size)
907
- - Recommended only for high-end GPUs (A100, H100, etc.)
908
- - Not recommended for the larger models like Hunyuan Video on consumer hardware
909
  """
910
  model_info = finetune_info + "\n\n" + model_info
911
 
@@ -925,6 +914,8 @@ class TrainTab(BaseTab):
925
  self.components["batch_size"]: params["batch_size"],
926
  self.components["learning_rate"]: params["learning_rate"],
927
  self.components["save_iterations"]: params["save_iterations"],
 
 
928
  self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
929
  self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
930
  self.components["control_params_row"]: gr.Row(visible=show_control_params),
@@ -936,11 +927,11 @@ class TrainTab(BaseTab):
936
  def get_model_info(self, model_type: str, training_type: str) -> str:
937
  """Get information about the selected model type and training method"""
938
  if model_type == "HunyuanVideo":
939
- base_info = """### HunyuanVideo
940
- - Required VRAM: ~48GB minimum
941
- - Recommended batch size: 1-2
942
- - Typical training time: 2-4 hours
943
- - Default resolution: 49x512x768"""
944
 
945
  if training_type == "LoRA Finetune":
946
  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
@@ -952,10 +943,10 @@ class TrainTab(BaseTab):
952
  return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
953
 
954
  elif model_type == "LTX-Video":
955
- base_info = """### LTX-Video
956
- - Recommended batch size: 1-4
957
- - Typical training time: 1-3 hours
958
- - Default resolution: 49x512x768"""
959
 
960
  if training_type == "LoRA Finetune":
961
  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
@@ -967,10 +958,10 @@ class TrainTab(BaseTab):
967
  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
968
 
969
  elif model_type == "Wan":
970
- base_info = """### Wan
971
- - Recommended batch size: 1-4
972
- - Typical training time: 1-3 hours
973
- - Default resolution: 49x512x768"""
974
 
975
  if training_type == "LoRA Finetune":
976
  return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
@@ -986,168 +977,30 @@ class TrainTab(BaseTab):
986
 
987
  def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
988
  """Get default training parameters for model type"""
989
- # Find preset that matches model type and training type
990
- matching_presets = [
991
- preset for preset_name, preset in TRAINING_PRESETS.items()
992
- if preset["model_type"] == model_type and preset["training_type"] == training_type
993
- ]
994
-
995
- if matching_presets:
996
- # Use the first matching preset
997
- preset = matching_presets[0]
998
- return {
999
- "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
1000
- "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
1001
- "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
1002
- "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
1003
- "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
1004
- "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
1005
- }
1006
-
1007
- # Default fallbacks
1008
- if model_type == "hunyuan_video":
1009
- return {
1010
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
1011
- "batch_size": DEFAULT_BATCH_SIZE,
1012
- "learning_rate": 2e-5,
1013
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1014
- "lora_rank": DEFAULT_LORA_RANK_STR,
1015
- "lora_alpha": DEFAULT_LORA_ALPHA_STR
1016
- }
1017
- elif model_type == "ltx_video":
1018
- return {
1019
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
1020
- "batch_size": DEFAULT_BATCH_SIZE,
1021
- "learning_rate": DEFAULT_LEARNING_RATE,
1022
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1023
- "lora_rank": DEFAULT_LORA_RANK_STR,
1024
- "lora_alpha": DEFAULT_LORA_ALPHA_STR
1025
- }
1026
- elif model_type == "wan":
1027
- return {
1028
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
1029
- "batch_size": DEFAULT_BATCH_SIZE,
1030
- "learning_rate": 5e-5,
1031
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1032
- "lora_rank": "32",
1033
- "lora_alpha": "32"
1034
- }
1035
- else:
1036
- # Generic defaults
1037
- return {
1038
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
1039
- "batch_size": DEFAULT_BATCH_SIZE,
1040
- "learning_rate": DEFAULT_LEARNING_RATE,
1041
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
1042
- "lora_rank": DEFAULT_LORA_RANK_STR,
1043
- "lora_alpha": DEFAULT_LORA_ALPHA_STR
1044
- }
1045
-
1046
- def update_training_params(self, preset_name: str) -> Tuple:
1047
- """Update UI components based on selected preset while preserving custom settings"""
1048
- preset = TRAINING_PRESETS[preset_name]
1049
-
1050
- # Load current UI state to check if user has customized values
1051
- current_state = self.app.load_ui_values()
1052
-
1053
- # Find the display name that maps to our model type
1054
- model_display_name = next(
1055
- key for key, value in MODEL_TYPES.items()
1056
- if value == preset["model_type"]
1057
- )
1058
-
1059
- # Find the display name that maps to our training type
1060
- training_display_name = next(
1061
- key for key, value in TRAINING_TYPES.items()
1062
- if value == preset["training_type"]
1063
- )
1064
-
1065
- # Get preset description for display
1066
- description = preset.get("description", "")
1067
-
1068
- # Get max values from buckets
1069
- buckets = preset["training_buckets"]
1070
- max_frames = max(frames for frames, _, _ in buckets)
1071
- max_height = max(height for _, height, _ in buckets)
1072
- max_width = max(width for _, _, width in buckets)
1073
- bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
1074
-
1075
- info_text = f"{description}{bucket_info}"
1076
 
1077
- # Check if LoRA params should be visible
1078
- training_type_internal = preset["training_type"]
1079
- show_lora_params = training_type_internal == "lora" or training_type_internal == "control-lora"
1080
-
1081
- # Check if Control params should be visible
1082
- show_control_params = training_type_internal == "control-lora" or training_type_internal == "control-full-finetune"
1083
 
1084
- # Use preset defaults but preserve user-modified values if they exist
1085
- lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
1086
- lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
1087
- train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
1088
- batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
1089
- learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
1090
- save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
1091
- num_gpus_val = current_state.get("num_gpus") if current_state.get("num_gpus") != preset.get("num_gpus", DEFAULT_NUM_GPUS) else preset.get("num_gpus", DEFAULT_NUM_GPUS)
1092
- precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
1093
- lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
1094
 
1095
- # Control parameters
1096
- control_type_val = current_state.get("control_type") if current_state.get("control_type") != preset.get("control_type", DEFAULT_CONTROL_TYPE) else preset.get("control_type", DEFAULT_CONTROL_TYPE)
1097
- train_qk_norm_val = current_state.get("train_qk_norm") if current_state.get("train_qk_norm") != preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM) else preset.get("train_qk_norm", DEFAULT_TRAIN_QK_NORM)
1098
- frame_conditioning_type_val = current_state.get("frame_conditioning_type") if current_state.get("frame_conditioning_type") != preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE) else preset.get("frame_conditioning_type", DEFAULT_FRAME_CONDITIONING_TYPE)
1099
- frame_conditioning_index_val = current_state.get("frame_conditioning_index") if current_state.get("frame_conditioning_index") != preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX) else preset.get("frame_conditioning_index", DEFAULT_FRAME_CONDITIONING_INDEX)
1100
- frame_conditioning_concatenate_mask_val = current_state.get("frame_conditioning_concatenate_mask") if current_state.get("frame_conditioning_concatenate_mask") != preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK) else preset.get("frame_conditioning_concatenate_mask", DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK)
1101
 
1102
- # Get the appropriate model version for the selected model type
1103
- model_versions = self.get_model_version_choices(model_display_name)
1104
- default_model_version = self.get_default_model_version(model_display_name)
1105
-
1106
- # Ensure we have valid choices and values
1107
- if not model_versions:
1108
- logger.warning(f"No versions found for {model_display_name}, using empty list")
1109
- model_versions = []
1110
- default_model_version = None
1111
- elif default_model_version not in model_versions and model_versions:
1112
- default_model_version = model_versions[0]
1113
- logger.info(f"Reset default version to first available: {default_model_version}")
1114
-
1115
- # Ensure model_versions is a simple list of strings
1116
- model_versions = [str(version) for version in model_versions]
1117
-
1118
- # Create the model version dropdown update
1119
- model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
1120
-
1121
- # Return values in the same order as the output components listed in line 644
1122
- # Make sure we return exactly 24 values to match what's expected
1123
- return (
1124
- model_display_name, # model_type
1125
- training_display_name, # training_type
1126
- lora_rank_val, # lora_rank
1127
- lora_alpha_val, # lora_alpha
1128
- train_steps_val, # train_steps
1129
- batch_size_val, # batch_size
1130
- learning_rate_val, # learning_rate
1131
- save_iterations_val, # save_iterations
1132
- info_text, # preset_info
1133
- gr.Row(visible=show_lora_params), # lora_params_row
1134
- gr.Row(visible=show_lora_params), # lora_settings_row (added missing row)
1135
- num_gpus_val, # num_gpus
1136
- precomputation_items_val, # precomputation_items
1137
- lr_warmup_steps_val, # lr_warmup_steps
1138
- model_version_update, # model_version
1139
- # Control parameters rows visibility
1140
- gr.Row(visible=show_control_params), # control_params_row
1141
- gr.Row(visible=show_control_params), # control_settings_row
1142
- gr.Row(visible=show_control_params), # frame_conditioning_row
1143
- gr.Row(visible=show_control_params), # control_options_row
1144
- # Control parameter values
1145
- control_type_val, # control_type
1146
- train_qk_norm_val, # train_qk_norm
1147
- frame_conditioning_type_val, # frame_conditioning_type
1148
- frame_conditioning_index_val, # frame_conditioning_index
1149
- frame_conditioning_concatenate_mask_val, # frame_conditioning_concatenate_mask
1150
- )
1151
 
1152
 
1153
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
 
13
  from vms.utils import BaseTab
14
  from vms.config import (
15
  ASK_USER_TO_DUPLICATE_SPACE,
16
+ SD_TRAINING_BUCKETS, MD_TRAINING_BUCKETS,
17
+ RESOLUTION_OPTIONS,
18
+ TRAINING_TYPES, MODEL_TYPES, MODEL_VERSIONS,
19
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
20
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
21
  DEFAULT_LEARNING_RATE,
 
30
  DEFAULT_AUTO_RESUME,
31
  DEFAULT_CONTROL_TYPE, DEFAULT_TRAIN_QK_NORM,
32
  DEFAULT_FRAME_CONDITIONING_TYPE, DEFAULT_FRAME_CONDITIONING_INDEX,
33
+ DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK,
34
+ HUNYUAN_VIDEO_DEFAULTS, LTX_VIDEO_DEFAULTS, WAN_DEFAULTS
35
  )
36
 
37
  logger = logging.getLogger(__name__)
 
52
  with gr.Row():
53
  self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
54
 
55
  with gr.Row():
56
  with gr.Column():
57
  # Get the default model type from the first preset
 
108
  self.components["model_info"] = gr.Markdown(
109
  value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0])
110
  )
111
+
112
+ with gr.Row():
113
+ with gr.Column():
114
+ self.components["resolution"] = gr.Dropdown(
115
+ choices=list(RESOLUTION_OPTIONS.keys()),
116
+ label="Resolution",
117
+ value=list(RESOLUTION_OPTIONS.keys())[0],
118
+ info="Select the resolution for training videos"
119
+ )
120
 
121
  # LoRA specific parameters (will show/hide based on training type)
122
  with gr.Row(visible=True) as lora_params_row:
 
142
 
143
  with gr.Accordion("What is LoRA Rank?", open=False):
144
  gr.Markdown("""
145
+ **LoRA Rank** determines the complexity of the LoRA adapters:
146
+
147
+ - **Lower rank (16-32)**: Smaller file size, faster training, but less expressive
148
+ - **Medium rank (64-128)**: Good balance between quality and file size
149
+ - **Higher rank (256-1024)**: More expressive adapters, better quality but larger file size
150
+
151
+ Think of rank as the "capacity" of your adapter. Higher ranks can learn more complex modifications to the base model but require more VRAM during training and result in larger files.
152
+
153
+ **Quick guide:**
154
+ - For Wan models: Use 32-64 (Wan models work well with lower ranks)
155
+ - For LTX-Video: Use 128-256
156
+ - For Hunyuan Video: Use 128
157
  """)
158
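To see why rank drives adapter size, note that a LoRA adapter stores two low-rank matrices per targeted weight, so its parameter count grows linearly with rank; a back-of-the-envelope sketch with made-up layer sizes:

```python
def lora_param_count(rank: int, in_features: int, out_features: int) -> int:
    # A LoRA adapter for one weight matrix stores A (rank x in) and B (out x rank)
    return rank * in_features + out_features * rank

# Example: a single 3072x3072 attention projection (illustrative size only)
for rank in (32, 128, 256):
    n = lora_param_count(rank, 3072, 3072)
    # bf16 weights -> 2 bytes per parameter
    print(f"rank={rank}: {n:,} params (~{n * 2 / 1e6:.1f} MB per adapted layer)")
```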
 
159
  with gr.Column():
 
164
  type="value",
165
  info="Controls the effective learning rate scaling of LoRA adapters. Usually set to same value as rank"
166
  )
 
167
  with gr.Accordion("What is LoRA Alpha?", open=False):
168
  gr.Markdown("""
169
+ **LoRA Alpha** controls the effective scale of the LoRA updates:
170
+
171
+ - The actual scaling factor is calculated as `alpha ÷ rank`
172
+ - Usually set to match the rank value (alpha = rank)
173
+ - Higher alpha = stronger effect from the adapters
174
+ - Lower alpha = more subtle adapter influence
175
+
176
+ **Best practice:**
177
+ - For most cases, set alpha equal to rank
178
+ - For more aggressive training, set alpha higher than rank
179
+ - For more conservative training, set alpha lower than rank
180
  """)
181
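The `alpha ÷ rank` scaling described above shows up when the LoRA delta is merged into a frozen weight; a minimal sketch (not the trainer's actual code):

```python
import torch

def apply_lora(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
               rank: int, alpha: float) -> torch.Tensor:
    """Return the effective weight W + (alpha / rank) * B @ A."""
    scaling = alpha / rank          # alpha == rank -> scaling of 1.0
    return W + scaling * (B @ A)

# Shapes: W is (out, in), A is (rank, in), B is (out, rank)
W = torch.zeros(64, 64)
A = torch.randn(8, 64) * 0.01
B = torch.randn(64, 8) * 0.01
W_eff = apply_lora(W, A, B, rank=8, alpha=8.0)
```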
+
182
 
183
  # Control specific parameters (will show/hide based on training type)
184
  with gr.Row(visible=False) as control_params_row:
185
  self.components["control_params_row"] = control_params_row
186
  with gr.Column():
187
  gr.Markdown("""
188
+ ## 🖼️ Control Training Settings
189
+
190
+ Control training enables **image-to-video generation** by teaching the model how to use an image as a guide for video creation.
191
+ This is ideal for turning still images into dynamic videos while preserving composition, style, and content.
192
  """)
193
 
194
  # Second row for control parameters
 
204
 
205
  with gr.Accordion("What is Control Conditioning?", open=False):
206
  gr.Markdown("""
207
+ **Control Conditioning** allows the model to be guided by an input image, adapting the video generation to the image content. This is used for image-to-video generation, where you want to turn a still image into a moving video while preserving its style, composition, and content.
208
+
209
+ - **canny**: Uses edge detection to extract outlines from images for structure-preserving video generation
210
+ - **custom**: Direct image conditioning without preprocessing, preserving more image details
211
  """)
212
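As an illustration of the `canny` option above, a conditioning image can be reduced to an edge map before it is fed to the model; a minimal sketch using OpenCV (the thresholds are arbitrary examples, not values taken from this project):

```python
import cv2
import numpy as np

def canny_control_image(image_path: str) -> np.ndarray:
    """Turn an input image into a 3-channel edge map usable as a control signal."""
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    edges = cv2.Canny(image, threshold1=100, threshold2=200)
    # Replicate the single edge channel to match the RGB layout of video frames
    return np.stack([edges, edges, edges], axis=-1)
```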
 
213
  with gr.Column():
 
219
 
220
  with gr.Accordion("What is QK Normalization?", open=False):
221
  gr.Markdown("""
222
+ **QK Normalization** refers to normalizing the query (Q) and key (K) projections in the attention mechanism of transformers.
223
+
224
+ - When enabled, allows the model to better integrate control signals with content generation
225
+ - Improves training stability for control models
226
+ - Generally recommended for control training, especially with image conditioning
227
  """)
228
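In practice this usually means applying a norm such as RMSNorm to the query and key projections before computing attention scores; a simplified sketch, not the model's exact code:

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def attention_with_qk_norm(q, k, v, train_qk_norm: bool = True):
    # Normalizing Q and K keeps their dot products in a stable range,
    # which helps when extra control channels are injected during training.
    if train_qk_norm:
        q, k = rms_norm(q), rms_norm(k)
    return F.scaled_dot_product_attention(q, k, v)
```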
 
229
  with gr.Row(visible=False) as frame_conditioning_row:
 
238
 
239
  with gr.Accordion("Frame Conditioning Type Explanation", open=False):
240
  gr.Markdown("""
241
+ **Frame Conditioning Types** determine which frames in the video receive image conditioning:
242
+
243
+ - **index**: Only applies conditioning to a single frame at the specified index
244
+ - **prefix**: Applies conditioning to all frames before a certain point
245
+ - **random**: Randomly selects frames to receive conditioning during training
246
+ - **first_and_last**: Only applies conditioning to the first and last frames
247
+ - **full**: Applies conditioning to all frames in the video
248
+
249
+ For image-to-video tasks, 'index' (usually with index 0) is most common as it conditions only the first frame.
250
  """)
251
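To make the five modes above concrete, here is a hedged sketch of how each type could be mapped to a per-frame selection mask; it is illustrative only, not the trainer's actual implementation:

```python
import random

def conditioning_frame_mask(num_frames: int, mode: str, index: int = 0) -> list[int]:
    """Return a 0/1 flag per frame marking which frames receive image conditioning."""
    if mode == "index":
        return [1 if i == index else 0 for i in range(num_frames)]
    if mode == "prefix":
        return [1 if i <= index else 0 for i in range(num_frames)]
    if mode == "random":
        chosen = set(random.sample(range(num_frames), k=max(1, num_frames // 8)))
        return [1 if i in chosen else 0 for i in range(num_frames)]
    if mode == "first_and_last":
        return [1 if i in (0, num_frames - 1) else 0 for i in range(num_frames)]
    if mode == "full":
        return [1] * num_frames
    raise ValueError(f"Unknown frame conditioning type: {mode}")
```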
 
252
  with gr.Column():
 
268
 
269
  with gr.Accordion("What is Frame Mask Concatenation?", open=False):
270
  gr.Markdown("""
271
+ **Frame Mask Concatenation** adds an additional channel to the control signal that indicates which frames are being conditioned:
272
 
273
+ - Creates a binary mask (0/1) indicating which frames receive conditioning
274
+ - Helps the model distinguish between conditioned and unconditioned frames
275
+ - Particularly useful for 'index' conditioning where only select frames are conditioned
276
+ - Generally improves temporal consistency between conditioned and unconditioned frames
277
  """)
278
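Building on the mask idea, a minimal sketch (with illustrative tensor shapes) of concatenating the per-frame flag as an extra channel:

```python
import torch

def concatenate_frame_mask(control_latents: torch.Tensor, frame_mask: torch.Tensor) -> torch.Tensor:
    """control_latents: (frames, channels, height, width); frame_mask: (frames,) of 0/1."""
    f, _, h, w = control_latents.shape
    mask_channel = frame_mask.view(f, 1, 1, 1).expand(f, 1, h, w).to(control_latents.dtype)
    # The model now sees an explicit per-frame flag alongside the control channels
    return torch.cat([control_latents, mask_channel], dim=1)
```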
 
279
  with gr.Column():
 
449
  return None
450
 
451
  def handle_new_training_start(
452
+ self, model_type, model_version, training_type,
453
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
454
  save_iterations, repo_id, progress=gr.Progress()
455
  ):
 
470
 
471
  # Start training normally
472
  return self.handle_training_start(
473
+ model_type, model_version, training_type,
474
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
475
  save_iterations, repo_id, progress
476
  )
477
 
478
  def handle_resume_training(
479
+ self, model_type, model_version, training_type,
480
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
481
  save_iterations, repo_id, progress=gr.Progress()
482
  ):
 
491
 
492
  # Start training with the checkpoint
493
  return self.handle_training_start(
494
+ model_type, model_version, training_type,
495
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
496
  save_iterations, repo_id, progress,
497
  resume_from_checkpoint="latest"
 
499
 
500
  def connect_events(self) -> None:
501
  """Connect event handlers to UI components"""
502
+ # Model type change event - Update model version dropdown choices and default parameters
503
  self.components["model_type"].change(
504
  fn=self.update_model_versions,
505
  inputs=[self.components["model_type"]],
506
  outputs=[self.components["model_version"]]
507
  ).then(
508
+ fn=self.update_model_type_and_version,
509
  inputs=[self.components["model_type"], self.components["model_version"]],
510
  outputs=[]
511
  ).then(
512
+ # Update model info and recommended default values based on model and training type
513
+ fn=self.update_model_info,
514
  inputs=[self.components["model_type"], self.components["training_type"]],
515
+ outputs=[
516
+ self.components["model_info"],
517
+ self.components["train_steps"],
518
+ self.components["batch_size"],
519
+ self.components["learning_rate"],
520
+ self.components["save_iterations"],
521
+ self.components["lora_params_row"],
522
+ self.components["lora_settings_row"],
523
+ self.components["control_params_row"],
524
+ self.components["control_settings_row"],
525
+ self.components["frame_conditioning_row"],
526
+ self.components["control_options_row"],
527
+ self.components["lora_rank"],
528
+ self.components["lora_alpha"]
529
+ ]
530
  )
531
 
532
  # Model version change event
 
550
  self.components["batch_size"],
551
  self.components["learning_rate"],
552
  self.components["save_iterations"],
553
+ self.components["lora_params_row"],
554
+ self.components["lora_settings_row"],
555
+ self.components["control_params_row"],
556
+ self.components["control_settings_row"],
557
+ self.components["frame_conditioning_row"],
558
+ self.components["control_options_row"],
559
+ self.components["lora_rank"],
560
+ self.components["lora_alpha"]
561
  ]
562
  )
563
 
 
654
  outputs=[]
655
  )
656
 
657
+ # Resolution change event
658
+ self.components["resolution"].change(
659
+ fn=lambda v: self.app.update_ui_state(resolution=v),
660
+ inputs=[self.components["resolution"]],
661
  outputs=[]
662
  )
663
 
664
  # Training control events
665
  self.components["start_btn"].click(
666
  fn=self.handle_new_training_start,
667
  inputs=[
 
668
  self.components["model_type"],
669
  self.components["model_version"],
670
  self.components["training_type"],
 
685
  self.components["resume_btn"].click(
686
  fn=self.handle_resume_training,
687
  inputs=[
 
688
  self.components["model_type"],
689
  self.components["model_version"],
690
  self.components["training_type"],
 
749
  # Update UI state with proper model_type first
750
  self.app.update_ui_state(model_type=model_type)
751
 
752
+ # Create a list of tuples (label, value) for the dropdown choices
753
+ # This ensures compatibility with Gradio Dropdown component expectations
754
+ choices_tuples = [(str(version), str(version)) for version in model_versions]
755
 
756
  # Create a new dropdown with the updated choices
757
+ if not choices_tuples:
758
  logger.warning(f"No model versions available for {model_type}, using empty list")
759
  # Return empty dropdown to avoid errors
760
  return gr.Dropdown(choices=[], value=None)
761
 
762
  # Ensure default_version is in model_versions
763
+ string_versions = [str(v) for v in model_versions]
764
+ if default_version not in string_versions and string_versions:
765
+ default_version = string_versions[0]
766
  logger.info(f"Default version not in choices, using first available: {default_version}")
767
 
768
  # Return the updated dropdown
769
+ logger.info(f"Returning dropdown with {len(choices_tuples)} choices")
770
+ return gr.Dropdown(choices=choices_tuples, value=default_version)
771
  except Exception as e:
772
  # Log any exceptions for debugging
773
  logger.error(f"Error in update_model_versions: {str(e)}")
 
775
  return gr.Dropdown(choices=[], value=None)
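For context on the change above: Gradio's `Dropdown` accepts `(label, value)` tuples for `choices`, and returning a new `gr.Dropdown(...)` from a handler updates the component in place. A self-contained sketch of the same pattern (the version strings are placeholders):

```python
import gradio as gr

VERSIONS = {
    "LTX-Video": ["Lightricks/LTX-Video"],            # placeholder entries
    "Wan": ["Wan-AI/Wan2.1-T2V-1.3B-Diffusers"],
}

def update_versions(model_type: str):
    choices = [(v, v) for v in VERSIONS.get(model_type, [])]
    default = choices[0][1] if choices else None
    return gr.Dropdown(choices=choices, value=default)

with gr.Blocks() as demo:
    model_type = gr.Dropdown(choices=list(VERSIONS.keys()), value="LTX-Video", label="Model type")
    model_version = gr.Dropdown(choices=[], label="Model version")
    model_type.change(update_versions, inputs=[model_type], outputs=[model_version])

if __name__ == "__main__":
    demo.launch()
```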
776
 
777
  def handle_training_start(
778
+ self, model_type, model_version, training_type,
779
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
780
  save_iterations, repo_id,
781
  progress=gr.Progress(),
 
834
  learning_rate,
835
  save_iterations,
836
  repo_id,
 
837
  training_type=training_internal_type,
838
  model_version=model_version,
839
  resume_from_checkpoint=resume_from,
 
887
  # Add general information about the selected training type
888
  if training_type == "Full Finetune":
889
  finetune_info = """
890
+ ## 🧠 Full Finetune Mode
891
 
892
+ Full finetune mode trains all parameters of the model, requiring more VRAM but potentially enabling higher quality results.
893
 
894
+ - Requires 20-50GB+ VRAM depending on model
895
+ - Creates a complete standalone model (~8GB+ file size)
896
+ - Recommended only for high-end GPUs (A100, H100, etc.)
897
+ - Not recommended for larger models such as Hunyuan Video on consumer hardware
898
  """
899
  model_info = finetune_info + "\n\n" + model_info
900
 
 
914
  self.components["batch_size"]: params["batch_size"],
915
  self.components["learning_rate"]: params["learning_rate"],
916
  self.components["save_iterations"]: params["save_iterations"],
917
+ self.components["lora_rank"]: params["lora_rank"],
918
+ self.components["lora_alpha"]: params["lora_alpha"],
919
  self.components["lora_params_row"]: gr.Row(visible=show_lora_params),
920
  self.components["lora_settings_row"]: gr.Row(visible=show_lora_params),
921
  self.components["control_params_row"]: gr.Row(visible=show_control_params),
 
927
  def get_model_info(self, model_type: str, training_type: str) -> str:
928
  """Get information about the selected model type and training method"""
929
  if model_type == "HunyuanVideo":
930
+ base_info = """## HunyuanVideo Training
931
+ - Required VRAM: ~48GB minimum
932
+ - Recommended batch size: 1-2
933
+ - Typical training time: 2-4 hours
934
+ - Default resolution: 49x512x768"""
935
 
936
  if training_type == "LoRA Finetune":
937
  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
 
943
  return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
944
 
945
  elif model_type == "LTX-Video":
946
+ base_info = """## LTX-Video Training
947
+ - Recommended batch size: 1-4
948
+ - Typical training time: 1-3 hours
949
+ - Default resolution: 49x512x768"""
950
 
951
  if training_type == "LoRA Finetune":
952
  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
 
958
  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
959
 
960
  elif model_type == "Wan":
961
+ base_info = """## Wan2.1 Training
962
+ - Recommended batch size: 1-4
963
+ - Typical training time: 1-3 hours
964
+ - Default resolution: 49x512x768"""
965
 
966
  if training_type == "LoRA Finetune":
967
  return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
 
977
 
978
  def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
979
  """Get default training parameters for model type"""
980
+ # Use model-specific defaults based on model_type and training_type
981
982
 
983
+ if model_type == "hunyuan_video" and training_type in HUNYUAN_VIDEO_DEFAULTS:
984
+ model_defaults = HUNYUAN_VIDEO_DEFAULTS[training_type]
985
+ elif model_type == "ltx_video" and training_type in LTX_VIDEO_DEFAULTS:
986
+ model_defaults = LTX_VIDEO_DEFAULTS[training_type]
987
+ elif model_type == "wan" and training_type in WAN_DEFAULTS:
988
+ model_defaults = WAN_DEFAULTS[training_type]
989
 
990
+ # Build the complete params dict with defaults plus model-specific overrides
991
+ params = {
992
+ "train_steps": DEFAULT_NB_TRAINING_STEPS,
993
+ "batch_size": DEFAULT_BATCH_SIZE,
994
+ "learning_rate": DEFAULT_LEARNING_RATE,
995
+ "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
996
+ "lora_rank": DEFAULT_LORA_RANK_STR,
997
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR
998
+ }
 
999
 
1000
+ # Override with model-specific values
1001
+ params.update(model_defaults)
 
 
 
 
1002
 
1003
+ return params
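With this structure, the per-model dictionaries only need to override the values that differ from the global defaults. A hypothetical sketch of the merge (the real `WAN_DEFAULTS` lives in vms/config.py and its keys and values may differ):

```python
# Hypothetical illustration of the merge performed above; placeholder values only.
DEFAULTS = {"train_steps": 1000, "learning_rate": 3e-5, "lora_rank": "128"}
WAN_DEFAULTS = {"lora": {"learning_rate": 5e-5, "lora_rank": "32"}}

params = dict(DEFAULTS)
params.update(WAN_DEFAULTS["lora"])  # model-specific values win where present
print(params)  # {'train_steps': 1000, 'learning_rate': 5e-05, 'lora_rank': '32'}
```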
1004
 
1005
 
1006
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]: