File size: 7,155 Bytes
83782c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import requests,re,base64,io,numpy as np
from PIL import Image,ImageOps
import torch,gradio as gr

# Custom CSS for gallery styling.
# The `#custom-gallery` rules target the element whose elem_id is
# 'custom-gallery' (the results gallery built in create_booru_interface);
# this string is handed to gr.Blocks(css=...).
# NOTE(review): the final rule repeats the `#custom-gallery .thumbnail-item img`
# selector from earlier in the sheet; being later with equal specificity, its
# `object-fit: initial` and `width: fit-content` declarations take precedence.
css = """
#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease,  box-shadow 0.2s ease;}
#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
"""

# Helper to load image from URL
def loadImageFromUrl(url):
    """Download an image and return it as a float tensor in the range [0, 1].

    The image is fetched with a 10 second timeout, rotated according to its
    EXIF orientation, and flattened onto a black RGB background using its
    alpha channel as the paste mask.

    Args:
        url: Address of the image to download.

    Returns:
        A ``torch`` tensor built from the RGB pixel array with a leading
        batch axis added, holding ``float32`` values scaled by 1/255.

    Raises:
        Exception: If the server answers with a status code other than 200.
    """
    response = requests.get(url, timeout=10)
    if response.status_code != 200:
        raise Exception(f"Failed to load image from {url}")

    source = ImageOps.exif_transpose(Image.open(io.BytesIO(response.content)))
    rgba = source if source.mode == "RGBA" else source.convert("RGBA")

    # Composite onto black so transparent regions do not keep stray colour.
    canvas = Image.new("RGB", rgba.size, (0, 0, 0))
    canvas.paste(rgba, mask=rgba.split()[-1])

    pixels = np.array(canvas).astype(np.float32) / 255.0
    return torch.from_numpy(pixels)[None,]

# Fetch data from Xbooru platform only
def fetch_booru_images(Tags, exclude_tags, score, count, Safe, Questionable, Explicit):
    """Query the Xbooru JSON API and collect image URLs, tags and post URLs.

    Args:
        Tags: Comma-separated tags to search for. May be empty or ``None``.
        exclude_tags: Comma-separated tags to exclude. May be empty or ``None``.
        score: Lower bound used in the ``score:>`` search term. ``None``
            omits the term entirely.
        count: Value sent as the ``limit`` query parameter.
        Safe: When false, ``rating:safe`` and ``rating:general`` are excluded.
        Questionable: When false, ``rating:questionable`` and
            ``rating:sensitive`` are excluded.
        Explicit: When false, ``rating:explicit`` is excluded.

    Returns:
        A tuple ``(image_urls, tags_list, post_urls)`` of three lists with the
        same length and order; only posts that carry a ``file_url`` are
        included. All three are empty when the response holds no usable posts.

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP
            error status.
    """
    def clean_tag_list(tags):
        # Split on commas, drop blanks, and use underscores inside a tag.
        return [item.strip().replace(' ', '_') for item in (tags or '').split(',') if item.strip()]

    def prompt_style(raw_tags):
        # Space-separated booru tags -> comma-separated text with underscores
        # turned into spaces and parentheses backslash-escaped.
        return (
            (raw_tags or "")
            .replace(" ", ", ")
            .replace("_", " ")
            .replace("(", "\\(")
            .replace(")", "\\)")
            .strip()
        )

    excluded_ratings = []
    if not Safe:
        excluded_ratings.extend(["rating:safe", "rating:general"])
    if not Questionable:
        excluded_ratings.extend(["rating:questionable", "rating:sensitive"])
    if not Explicit:
        excluded_ratings.append("rating:explicit")

    # A numeric UI field may hand over a whole number as a float; render it
    # as an integer so the term reads "score:>5" rather than "score:>5.0".
    if isinstance(score, float) and score.is_integer():
        score = int(score)

    # Build query from non-empty terms only, so no stray "+" separators appear.
    terms = ["sort:random"]
    terms.extend(clean_tag_list(Tags))
    terms.extend(f"-{tag}" for tag in clean_tag_list(exclude_tags))
    if score is not None:
        terms.append(f"score:>{score}")
    terms.extend(f"-{rating}" for rating in excluded_ratings)
    tag_expression = re.sub(r"\++", "+", "+".join(terms))
    base_query = f"tags={tag_expression}&limit={count}&json=1"

    # Fetch data from Xbooru only atm
    url = f"https://xbooru.com/index.php?page=dapi&s=post&q=index&{base_query}"
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    try:
        posts = response.json()
    except ValueError:
        # NOTE(review): the API appears to answer with an empty/non-JSON body
        # when nothing matches; treat that as "no results" instead of crashing.
        posts = []
    if not isinstance(posts, list):
        posts = []

    # Keep the three lists index-aligned: one entry per post with a file URL.
    image_urls = []
    tags_list = []
    post_urls = []
    for post in posts:
        file_url = post.get("file_url")
        if not file_url:
            continue
        post_id = post.get("id", "")
        image_urls.append(file_url)
        tags_list.append(prompt_style(post.get("tags", "")))
        post_urls.append(f"https://xbooru.com/index.php?page=post&s=view&id={post_id}")

    return image_urls, tags_list, post_urls

# Main function to fetch and return processed images
def booru_gradio(Tags, exclude_tags, score, count, Safe, Questionable, Explicit):
    """Fetch matching posts and download their images for the gallery.

    The arguments are forwarded unchanged to ``fetch_booru_images``.

    Returns:
        A tuple ``(images, tags, post_urls, image_urls)`` of four lists that
        share the same length and order. Posts whose image could not be
        downloaded or decoded are left out of all four lists, so an index
        into the image list is valid for the other three. Four empty lists
        are returned when the search yields nothing.
    """
    image_urls, tags_list, post_urls = fetch_booru_images(Tags, exclude_tags, score, count, Safe, Questionable, Explicit)

    if not image_urls:
        return [], [], [], []

    image_data = []
    kept_tags = []
    kept_post_urls = []
    kept_image_urls = []
    for url, tags, post_url in zip(image_urls, tags_list, post_urls):
        try:
            image = loadImageFromUrl(url)
            # Round before the uint8 cast: astype truncates, and a value
            # scaled by 1/255 and back can land just below its integer.
            image = (image * 255).round().clamp(0, 255).cpu().numpy().astype(np.uint8)[0]
            image = Image.fromarray(image)
        except Exception as e:
            # Best effort: report the failure and move on to the next post.
            print(f"Error loading image from {url}: {e}")
            continue
        # Record metadata only for images that actually loaded so that
        # gallery positions keep matching their tags and URLs.
        image_data.append(image)
        kept_tags.append(tags)
        kept_post_urls.append(post_url)
        kept_image_urls.append(url)

    return image_data, kept_tags, kept_post_urls, kept_image_urls

# Update UI on image click
def on_select(evt: gr.SelectData, tags_list, post_url_list, image_url_list):
    idx = evt.index
    if idx < len(tags_list):
        return tags_list[idx], post_url_list[idx], image_url_list[idx]
    return "No tags", "", ""

def create_booru_interface():
    """Build the Gradio Blocks UI for the Xbooru search and return it.

    The layout is a single row with two columns: search parameters on the
    left and the results gallery plus detail text boxes on the right.
    Clicking the button runs ``booru_gradio``; selecting a gallery item runs
    ``on_select`` to fill the detail boxes from the stored state.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(css=css, fill_width=True) as demo:
        with gr.Row():
            # Search form.
            with gr.Column():
                gr.Markdown("### ⚙️ Search Parameters")
                tags_box = gr.Textbox(
                    label="Tags (comma-separated)",
                    placeholder="e.g. solo,  1girl,  1boy,  artist name,  character,  black hair,  granblue fantasy,  ...",
                    lines=3,
                )
                exclude_box = gr.Textbox(
                    label="Exclude Tags (comma-separated)",
                    placeholder="e.g. animated,  watermark,  username,  ...",
                    lines=3,
                )
                min_score = gr.Number(label="Minimum Score", value=0)
                image_count = gr.Slider(label="Number of Images", minimum=1, maximum=10, step=1, value=1)
                allow_safe = gr.Checkbox(label="Include Safe", value=True)
                allow_questionable = gr.Checkbox(label="Include Questionable", value=True)
                allow_explicit = gr.Checkbox(label="Include Explicit (18+)", value=False)
                fetch_button = gr.Button("Fetch Images", variant="primary")

            # Results gallery and per-image details.
            with gr.Column():
                gr.Markdown("### 📄 Results")
                gallery = gr.Gallery(
                    columns=2,
                    show_share_button=False,
                    interactive=True,
                    height='auto',
                    label='Grid of images',
                    preview=False,
                    elem_id='custom-gallery',
                )
                tags_view = gr.Textbox(
                    label="Tags",
                    placeholder="Select an image to display tags",
                    lines=6,
                    show_copy_button=True,
                )
                post_url_view = gr.Textbox(label="Post URL", lines=2, show_copy_button=True)
                image_url_view = gr.Textbox(label="Image URL", lines=2, show_copy_button=True)

            # State to store tags, URLs between the fetch and a later selection.
            tags_store = gr.State([])
            post_url_store = gr.State([])
            image_url_store = gr.State([])

            search_inputs = [
                tags_box,
                exclude_box,
                min_score,
                image_count,
                allow_safe,
                allow_questionable,
                allow_explicit,
            ]
            stored_metadata = [tags_store, post_url_store, image_url_store]

            fetch_button.click(
                fn=booru_gradio,
                inputs=search_inputs,
                outputs=[gallery] + stored_metadata,
            )
            gallery.select(
                fn=on_select,
                inputs=stored_metadata,
                outputs=[tags_view, post_url_view, image_url_view],
            )

    return demo