Spaces:
Runtime error
Runtime error
| from transformers import pipeline | |
| from imgutils.data import rgb_encode, load_image | |
| from onnx_ import _open_onnx_model | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import requests | |
| import torch | |
| import json | |
def _img_encode(image, size=(384,384), normalize=(0.5,0.5)):
    """Resize an image and encode it as a float32 array for model input.

    Args:
        image: Object exposing ``resize(size, resample)``; callers in this
            file pass the result of ``load_image(..., mode='RGB')``.
        size: Target size handed to ``image.resize``.
        normalize: ``(mean, std)`` pair applied as ``(data - mean) / std``,
            or ``None`` to skip normalization.

    Returns:
        The encoded (and optionally normalized) data cast to ``np.float32``.
    """
    resized = image.resize(size, Image.BILINEAR)
    # NOTE(review): rgb_encode(order_='CHW') presumably yields a channel-first
    # array; confirm against the imgutils documentation.
    encoded = rgb_encode(resized, order_='CHW')

    # Guard clause: nothing to normalize, only the dtype cast remains.
    if normalize is None:
        return encoded.astype(np.float32)

    mean_value, std_value = normalize
    # Reshape to (-1, 1, 1) so mean/std broadcast over the trailing two axes.
    mean = np.asarray([mean_value]).reshape((-1, 1, 1))
    std = np.asarray([std_value]).reshape((-1, 1, 1))
    return ((encoded - mean) / std).astype(np.float32)
# Remote directory holding the ONNX rating model and its metadata.
_RATING_BASE_URL = "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus"
# Local file name -> remote file name that provides it.
_RATING_FILES = {
    "timm.onnx": "model.onnx",
    "timmcfg.json": "meta.json",
}


def _download(url, path):
    """Download ``url`` and store the response body at ``path``.

    The body is written to ``path + ".part"`` first and renamed afterwards,
    so an interrupted write never leaves a truncated file that the
    existence check below would mistake for a complete download.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    # The timeout bounds the connect and per-read waits, not the whole
    # transfer, so large model files are still allowed to finish.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTML error page as the model file.
    response.raise_for_status()
    partial_path = path + ".part"
    with open(partial_path, "wb") as out_file:
        out_file.write(response.content)
    os.replace(partial_path, path)


# Image-classification pipeline used as the second opinion in ``launch``.
nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

# Check every required file on its own: previously only "timm.onnx" was
# tested, so a missing "timmcfg.json" was never fetched again.
_missing_files = [name for name in _RATING_FILES if not os.path.exists(name)]
if _missing_files:
    for _name in _missing_files:
        _download(f"{_RATING_BASE_URL}/{_RATING_FILES[_name]}", _name)
else:
    print("Model already exists, skipping redownload")

# Metadata for the ONNX model; ``launch`` reads its "labels" entry.
with open("timmcfg.json") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model("timm.onnx")
# Score contributed by the top label of the ONNX rating model.
_TM_WEIGHTS = {"safe": -1, "r15": 2, "r18": 2}
# Score contributed by the top label of the transformers pipeline.
_TF_WEIGHTS = {"safe": -1, "suggestive": 1, "r18": 2}


def launch(img):
    """Combine two classifiers into a single boolean verdict for an image.

    The top label of the ONNX model and the top label of the transformers
    pipeline are each mapped to a score (labels not listed in the tables
    contribute 0); the verdict is ``True`` when the summed score is positive.

    Args:
        img: Image supplied by the Gradio ``"pil"`` input.

    Returns:
        bool: ``True`` if the combined weight is greater than zero.

    Raises:
        gr.Error: if no image was provided.
    """
    # Gradio passes None when the form is submitted without an image; report
    # that to the user instead of failing with an AttributeError.
    if img is None:
        raise gr.Error("No image provided")

    img = img.convert('RGB')
    tm_image = load_image(img, mode='RGB')
    # [None, ...] prepends a batch axis before feeding the ONNX session.
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
    tm_items, = nsfw_tm.run(['output'], {'input': tm_input_})

    # Rank (label, score) pairs once, best first; reused for both the
    # decision and the log line below.
    ranking = sorted(
        zip(tm_cfg["labels"], (score.item() for score in tm_items[0])),
        key=lambda pair: pair[1],
        reverse=True,
    )
    tm_output = ranking[0][0]

    tf_output = nsfw_tf(img)[0]["label"]

    weight = _TM_WEIGHTS.get(tm_output, 0) + _TF_WEIGHTS.get(tf_output, 0)
    print(ranking, tf_output)
    return weight > 0
# Minimal Gradio UI: one PIL image in, the verdict of ``launch`` shown as text.
app = gr.Interface(
    fn=launch,
    inputs="pil",
    outputs="text",
)
app.launch()