|
import os |
|
import torch |
|
from PIL import Image |
|
import streamlit as st |
|
from pathlib import Path |
|
import io |
|
|
|
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    """Check that each given file has one of the accepted suffixes.

    Args:
        file: A single path (str or Path) or a list/tuple of paths. A falsy
            value disables the check.
        suffix: One accepted suffix (e.g. '.pt') or an iterable of them.
            A falsy value disables the check. Matching is case-insensitive.
        msg: Text prepended to the error message.

    Raises:
        AssertionError: if a file has a non-empty suffix that is not accepted.
            Files without any suffix are not checked.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    # The file suffix is lower-cased below, so the accepted suffixes must be
    # lower-cased too; otherwise a mixed-case entry like '.PT' never matches.
    accepted = {x.lower() for x in suffix}
    for f in file if isinstance(file, (list, tuple)) else [file]:
        s = Path(f).suffix.lower()
        # Raise explicitly: a bare `assert` is removed under `python -O`,
        # which would silently disable this validation.
        if s and s not in accepted:
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}")
|
|
|
|
|
def read_pretrain(path):
    """Load a custom-trained YOLOv5 model through ``torch.hub``.

    Args:
        path: Weights file, forwarded as the ``path`` keyword argument of the
            ``custom`` entry point of the ``ultralytics/yolov5`` hub repo.

    Returns:
        The object ``torch.hub.load`` returns for that entry point (the caller
        sets attributes such as ``conf``/``iou`` on it and calls it on images).
    """
    repo, entrypoint = 'ultralytics/yolov5', 'custom'
    # NOTE(review): a 'owner/repo' string is resolved from GitHub (or the local
    # hub cache) by torch.hub, so this may need network access on first use.
    return torch.hub.load(repo, entrypoint, path=path)
|
|
|
|
|
st.title("Hololive Waifu Classification")

# Image picker; only the listed file extensions are accepted by the widget.
uploaded_file = st.file_uploader(
    "画像ファイルをアップロードしてください (対応形式: JPG, JPEG, PNG, WEBP)",
    type=["jpg", "jpeg", "png", "webp"]
)

# Weights file handed to read_pretrain() when the user clicks Execute.
pretrained = st.selectbox(
    'Select pre-trained',
    ('2022.11.04-YOLOv5x6_1280-Hololive_Waifu_Classification.pt',
     '2022.11.01-YOLOv5x6_1280-Hololive_Waifu_Classification.pt')
)

# Inference settings, copied onto the model right before prediction.
# Both integer inputs are bounded below by 1: a zero/negative inference size or
# detection limit is not meaningful, and previously nothing prevented entering one.
imgsz = st.number_input(label='Image Size', min_value=1, max_value=None, value=1280, step=1)
conf = st.slider(label='Confidence threshold', min_value=0.0, max_value=1.0, value=0.25, step=0.01)
iou = st.slider(label='IoU threshold', min_value=0.0, max_value=1.0, value=0.45, step=0.01)
multi_label = st.selectbox('Multiple labels per box', (False, True))
agnostic = st.selectbox('Class-agnostic', (False, True))
amp = st.selectbox('Automatic Mixed Precision inference', (False, True))
max_det = st.number_input(label='Maximum number of detections per image', min_value=1, max_value=None, value=1000, step=1)
clicked = st.button('Execute')

if clicked and uploaded_file is not None:
    with st.spinner('Loading the image...'):
        input_image = Image.open(uploaded_file)
        if input_image.format == 'WEBP':
            # Re-encode WEBP uploads as an in-memory PNG before inference.
            st.info("WEBP形式の画像をPNG形式に変換しています...")
            buffer = io.BytesIO()
            input_image.save(buffer, format="PNG")
            buffer.seek(0)
            input_image = Image.open(buffer)

    # NOTE(review): the model is loaded again on every click; consider caching
    # it across runs if load time becomes a problem.
    with st.spinner('Loading the model...'):
        model = read_pretrain(pretrained)

    with st.spinner('Updating configuration...'):
        model.conf = float(conf)
        model.max_det = int(max_det)
        model.iou = float(iou)
        model.agnostic = agnostic
        model.multi_label = multi_label
        model.amp = amp

    with st.spinner('Predicting...'):
        results = model(input_image, size=int(imgsz))

    # Show each annotated image, then the detection table for the first image.
    for img in results.render():
        st.image(img)
    st.write(results.pandas().xyxy[0])
elif not uploaded_file:
    # Prompt for an upload whenever no file has been provided yet.
    st.warning("画像ファイルをアップロードしてください。")
|
|