import json
import os
from collections import defaultdict
from typing import Dict, List

import faiss
import gradio as gr
import numpy as np
from cheesechaser.datapool import (
    AnimePicturesWebpDataPool,
    DanbooruNewestWebpDataPool,
    GelbooruWebpDataPool,
    KonachanWebpDataPool,
    Rule34WebpDataPool,
    YandeWebpDataPool,
    ZerochanWebpDataPool,
)
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.utils import TemporaryDirectory
from imgutils.generic import siglip
from imgutils.utils import ts_lru_cache
from PIL import Image

from pools import quick_webp_pool

# HuggingFace repos: SigLIP encoder weights, and the prebuilt FAISS indices.
_SIGLIP_REPO_ID = "deepghs/siglip_beta"
_INDEX_REPO_ID = 'deepghs/anime_sites_indices'

hf_fs = get_hf_fs()
hf_client = get_hf_client()

_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_8005009_4GB'
# Every directory in the index repo that ships a `knn.index` file is a usable index.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _INDEX_REPO_ID))
    for path in hf_fs.glob(f'{_INDEX_REPO_ID}/*/knn.index')
]

# Known site name -> datapool class. Sites not listed here fall back to
# quick_webp_pool in _get_from_ids.
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}


def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download the given numeric post ids from one site and return PIL images.

    Files land in a temporary directory, are fully loaded into memory (so the
    directory can be cleaned up), and are keyed by their numeric id. Ids whose
    download produced no file are simply absent from the result.
    """
    with TemporaryDirectory() as td:
        site_cls = _SITE_CLS.get(site_name) or quick_webp_pool(site_name, 3)
        datapool = site_cls()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            id_ = int(os.path.splitext(file)[0])
            image = Image.open(os.path.join(td, file))
            # Force pixel data into memory now; the backing file is deleted
            # together with the temporary directory.
            image.load()
            retval[id_] = image

        return retval


def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Resolve composite ids of the form ``'<site>_<numeric_id>'`` to images.

    Ids are grouped per site so that each site is hit with a single batched
    download; the returned mapping is keyed by the original composite id.
    """
    _sites = defaultdict(list)
    for id_ in ids:
        # rsplit: site names themselves may contain underscores (e.g. anime_pictures).
        site_name, num_id = id_.rsplit('_', maxsplit=1)
        _sites[site_name].append(int(num_id))

    _retval = {}
    for site_name, site_ids in _sites.items():
        _retval.update({
            f'{site_name}_{id_}': image
            for id_, image in _get_from_ids(site_name, site_ids).items()
        })
    return _retval
@ts_lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Download and load one prebuilt FAISS index from ``repo_id``.

    :param repo_id: HuggingFace model repo holding the indices.
    :param model_name: subdirectory of the repo containing ``ids.npy``,
        ``knn.index`` and ``infos.json``.
    :return: ``(image_ids, knn_index)`` where ``image_ids`` is a numpy array
        mapping FAISS row number -> composite image id, and ``knn_index`` is
        the FAISS index with its recommended search-time parameters
        (``index_param`` from ``infos.json``) already applied.

    Cached (up to 3 entries) so switching between indices in the UI does not
    re-read them from disk every time.
    """
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))

    # Fix: the original used `json.loads(open(...).read())`, leaking the file
    # handle; a context manager closes it deterministically.
    with open(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/infos.json',
    )) as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)

    return image_ids, knn_index


def search(model_name: str, img_input, str_input: str, n_neighbours: int):
    """Return the ``n_neighbours`` nearest images from the selected index.

    :param model_name: which index in ``_INDEX_REPO_ID`` to query.
    :param img_input: PIL image query (used only when ``str_input`` is empty).
    :param str_input: text query; non-empty text takes precedence over the image.
    :param n_neighbours: number of neighbours to retrieve.
    :return: list of ``(PIL.Image, caption)`` tuples for a ``gr.Gallery``,
        captioned ``'<image_id>/<distance>'``. Ids whose image could not be
        downloaded are skipped.
    """
    images_ids, knn_index = _get_index_info(_INDEX_REPO_ID, model_name)

    if str_input == "":
        embeddings = siglip.siglip_image_encode(
            img_input,
            repo_id=_SIGLIP_REPO_ID,
            model_name="smilingwolf/siglip_swinv2_base_2025_02_22_18h56m54s",
            fmt="embeddings",
        )
    else:
        embeddings = siglip.siglip_text_encode(
            str_input,
            repo_id=_SIGLIP_REPO_ID,
            model_name="smilingwolf/siglip_swinv2_base_2025_02_22_18h56m54s",
            fmt="embeddings",
        )

    # In the model, the "embeddings" output node is already normalized.
    # Ask for the "encodings" output if you want the raw logits
    dists, indexes = knn_index.search(embeddings, k=n_neighbours)
    neighbours_ids = images_ids[indexes][0]

    captions = []
    images = []
    ids_to_images = _get_from_raw_ids(neighbours_ids)
    for image_id, dist in zip(neighbours_ids, dists[0]):
        # Downloads are best-effort; only show ids we actually got an image for.
        if image_id in ids_to_images:
            images.append(ids_to_images[image_id])
            captions.append(f"{image_id}/{dist:.2f}")

    return list(zip(images, captions))


if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", image_mode="RGBA", label="Image input")
                str_input = gr.Textbox(label="Text input (leave empty to use image input)")
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                str_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()