import gradio as gr
import numpy as np
import random
import torch
import spaces
from diffusers import DiffusionPipeline

from tags_straight import TAGS_STRAIGHT
from tags_lesbian import TAGS_LESBIAN
from tags_gay import TAGS_GAY
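# Each tags module is expected to map category names to {checkbox label: prompt tag}
# dictionaries, based on how infer() consumes them. Illustrative shape only:
# TAGS_STRAIGHT = {
#     "Characters": {"1girl": "1girl", "1boy": "1boy"},
#     ...
# }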

# Quality/style prefixes prepended to every generation. The tag-based tabs keep a
# trailing ", " because infer() joins the selected tags onto the prefix with a space.
PROMPT_PREFIXES = {
    "Prompt Input": "score_9, score_8_up, score_7_up, source_anime",
    "Straight": "score_9, score_8_up, score_7_up, source_anime, ",
    "Lesbian": "score_9, score_8_up, score_7_up, source_anime, ",
    "Gay": "score_9, score_8_up, score_7_up, source_anime, yaoi, "
}

# Run on GPU with fp16 when available; fall back to fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Alternative checkpoints, kept for easy switching:
# model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v140-sdxl"
# model_repo_id = "John6666/pony-realism-v23-ultra-sdxl"

pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)

# Bounds for the seed slider and the output resolution sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

def create_checkboxes(tag_dict, suffix):
    """Build one CheckboxGroup per tag category; returns (components, category names)."""
    categories = list(tag_dict.keys())
    checkboxes = [
        gr.CheckboxGroup(choices=list(tag_dict[cat].keys()), label=f"{cat} Tags ({suffix})")
        for cat in categories
    ]
    return checkboxes, categories

# Build the checkbox components up front; they are render()ed inside the tab
# layout below, so their creation order matches the click() inputs list.
straight_checkboxes, _ = create_checkboxes(TAGS_STRAIGHT, "Straight")
lesbian_checkboxes, _ = create_checkboxes(TAGS_LESBIAN, "Lesbian")
gay_checkboxes, _ = create_checkboxes(TAGS_GAY, "Gay")

# Allocates a GPU for this call when running on Hugging Face Spaces (ZeroGPU).
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
          guidance_scale, num_inference_steps, active_tab, *tag_selections,
          progress=gr.Progress(track_tqdm=True)):
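    """Assemble the final prompt from the active tab's inputs and run the diffusion pipeline."""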

    prefix = PROMPT_PREFIXES.get(active_tab, "score_9, score_8_up, score_7_up, source_anime")

    if active_tab == "Prompt Input":
        final_prompt = f"{prefix}, {prompt}"
    else:
        combined_tags = []

        # tag_selections is the flat list of every CheckboxGroup's value, ordered
        # Straight -> Lesbian -> Gay, matching the inputs list of run_button.click().
        straight_len = len(TAGS_STRAIGHT)
        lesbian_len = len(TAGS_LESBIAN)
        gay_len = len(TAGS_GAY)

        if active_tab == "Straight":
            tag_source = TAGS_STRAIGHT
            selections = tag_selections[:straight_len]
        elif active_tab == "Lesbian":
            offset = straight_len
            tag_source = TAGS_LESBIAN
            selections = tag_selections[offset:offset + lesbian_len]
        elif active_tab == "Gay":
            offset = straight_len + lesbian_len
            tag_source = TAGS_GAY
            selections = tag_selections[offset:offset + gay_len]
        else:
            tag_source, selections = {}, []

        # Map each selected checkbox label back to its underlying prompt tag.
        for (_, tag_dict), selected in zip(tag_source.items(), selections):
            combined_tags.extend(tag_dict[tag] for tag in selected)

        tag_string = ", ".join(combined_tags)
        final_prompt = f"{prefix} {tag_string}"

    negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
    # Append the user's negative prompt only when one was provided, to avoid a trailing comma.
    full_negative_prompt = f"{negative_base}, {negative_prompt}" if negative_prompt else negative_base

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Seed a generator so results are reproducible for a given seed value.
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=final_prompt,
        negative_prompt=full_negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]

    return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"


css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}

#left-column {
    width: 50%;
    display: inline-block;
    padding: 20px;
    vertical-align: top;
}

#right-column {
    width: 50%;
    display: inline-block;
    vertical-align: top;
    padding: 20px;
    margin-top: 53px;
}

#run-button {
    width: 100%;
    margin-top: 10px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(elem_id="left-column"):
            gr.Markdown("# Rainbow Media X")

            result = gr.Image(label="Result", show_label=False)
            prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)

            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative prompt", placeholder="Enter negative prompt")
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

                with gr.Row():
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=10, step=0.1, value=7)
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=35)

            run_button = gr.Button("Run", elem_id="run-button")

        with gr.Column(elem_id="right-column"):
            # Tracks which tab is active; each tab's .select() handler below updates it.
            active_tab = gr.State("Prompt Input")

            with gr.Tabs() as tabs:
                with gr.TabItem("Prompt Input") as prompt_tab:
                    prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt")
                    prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)

                with gr.TabItem("Straight") as straight_tab:
                    for cb in straight_checkboxes:
                        cb.render()
                    straight_tab.select(lambda: "Straight", outputs=active_tab)

                with gr.TabItem("Lesbian") as lesbian_tab:
                    for cb in lesbian_checkboxes:
                        cb.render()
                    lesbian_tab.select(lambda: "Lesbian", outputs=active_tab)

                with gr.TabItem("Gay") as gay_tab:
                    for cb in gay_checkboxes:
                        cb.render()
                    gay_tab.select(lambda: "Gay", outputs=active_tab)

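    # The order *straight_checkboxes, *lesbian_checkboxes, *gay_checkboxes here
    # must match the slicing offsets infer() uses for tag_selections.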
    run_button.click(
        fn=infer,
        inputs=[
            prompt, negative_prompt, seed, randomize_seed,
            width, height, guidance_scale, num_inference_steps,
            active_tab,
            *straight_checkboxes,
            *lesbian_checkboxes,
            *gay_checkboxes
        ],
        outputs=[result, seed, prompt_info]
    )

demo.queue().launch()