File size: 3,846 Bytes
2e82449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr

from modules import scripts
from backend.misc.image_resize import adaptive_resize


class PatchModelAddDownscale:
    """Kohya HRFix-style UNet patch: shrink one input block, restore size on the way out.

    The input-block patch downscales the hidden state at a chosen input block while
    the current sigma lies inside a window derived from ``start_percent`` /
    ``end_percent``. The output-block patch resizes the hidden state back to the
    spatial size of the skip tensor ``hsp`` whenever the two disagree.
    """

    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
        """Register the downscale/upscale patches on a clone of ``model``.

        Args:
            model: Model patcher. It must provide ``clone()``,
                ``model.predictor.percent_to_sigma`` and the
                ``set_model_*_block_patch*`` registration methods used below.
            block_number: Value compared against ``transformer_options["block"][1]``
                to select the input block to downscale.
            downscale_factor: Spatial sizes are multiplied by ``1 / downscale_factor``.
            start_percent: Converted to ``sigma_start`` (upper bound of the window).
            end_percent: Converted to ``sigma_end`` (lower bound of the window).
            downscale_after_skip: If true, register the input patch with
                ``set_model_input_block_patch_after_skip``; otherwise with
                ``set_model_input_block_patch``.
            downscale_method: Mode name passed to ``adaptive_resize`` when shrinking.
            upscale_method: Mode name passed to ``adaptive_resize`` when restoring size.

        Returns:
            A 1-tuple holding the clone the patches were registered on.
            ``model`` itself receives no patch registrations.
        """
        predictor = model.model.predictor
        sigma_start = predictor.percent_to_sigma(start_percent)
        sigma_end = predictor.percent_to_sigma(end_percent)

        def input_block_patch(h, transformer_options):
            """Downscale ``h`` at the selected block while sigma is inside the window."""
            if transformer_options["block"][1] != block_number:
                return h
            sigma = transformer_options["sigmas"][0].item()
            if sigma_end <= sigma <= sigma_start:
                scale = 1.0 / downscale_factor
                # Clamp to 1 so a large factor on a small latent never asks for a 0-sized output.
                width = max(1, round(h.shape[-1] * scale))
                height = max(1, round(h.shape[-2] * scale))
                h = adaptive_resize(h, width, height, downscale_method, "disabled")
            return h

        def output_block_patch(h, hsp, transformer_options):
            """Resize ``h`` to the skip tensor's spatial size if they differ."""
            # Compare height AND width: with non-integer factors the rounded sizes may
            # differ in width only, which a height-only check would let through.
            if h.shape[-2:] != hsp.shape[-2:]:
                h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return (m,)


# Module-level instance used by KohyaHRFixForForge.process_before_every_sampling.
# PatchModelAddDownscale defines no __init__ and stores nothing on self, so a single
# shared instance can be reused for every sampling call.
opPatchModelAddDownscale = PatchModelAddDownscale()


class KohyaHRFixForForge(scripts.Script):
    """Always-visible script that applies the Kohya HRFix downscale patch to the UNet."""

    sorting_priority = 14

    def title(self):
        """Return the script name, also used as the accordion label."""
        return "Kohya HRFix Integrated"

    def show(self, is_img2img):
        """Display the script in every tab, regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Create the settings controls.

        Returns:
            The controls, in the exact order ``process_before_every_sampling``
            unpacks them from ``script_args``.
        """
        resize_modes = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
        default_mode = resize_modes[0]
        with gr.Accordion(open=False, label=self.title()):
            # Tuple elements are evaluated left to right, so the widgets are laid out
            # in this order inside the accordion.
            controls = (
                gr.Checkbox(label='Enabled', value=False),
                gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1),
                gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001),
                gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001),
                gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001),
                gr.Checkbox(label='Downscale After Skip', value=True),
                gr.Radio(label='Downscale Method', choices=resize_modes, value=default_mode),
                gr.Radio(label='Upscale Method', choices=resize_modes, value=default_mode),
            )
        return controls

    def process_before_every_sampling(self, p, *script_args, **kwargs):
        """Patch ``p``'s UNet and record the settings in the generation parameters.

        Does nothing when the script is disabled.
        """
        (enabled, block_number, downscale_factor, start_percent, end_percent,
         downscale_after_skip, downscale_method, upscale_method) = script_args
        # Slider values may arrive as floats; the block comparison needs an int.
        block_number = int(block_number)

        if not enabled:
            return

        patched = opPatchModelAddDownscale.patch(
            p.sd_model.forge_objects.unet,
            block_number=block_number,
            downscale_factor=downscale_factor,
            start_percent=start_percent,
            end_percent=end_percent,
            downscale_after_skip=downscale_after_skip,
            downscale_method=downscale_method,
            upscale_method=upscale_method,
        )
        p.sd_model.forge_objects.unet = patched[0]

        p.extra_generation_params.update({
            "kohya_hrfix_enabled": enabled,
            "kohya_hrfix_block_number": block_number,
            "kohya_hrfix_downscale_factor": downscale_factor,
            "kohya_hrfix_start_percent": start_percent,
            "kohya_hrfix_end_percent": end_percent,
            "kohya_hrfix_downscale_after_skip": downscale_after_skip,
            "kohya_hrfix_downscale_method": downscale_method,
            "kohya_hrfix_upscale_method": upscale_method,
        })