File size: 2,924 Bytes
e9e5323
c1115fd
c9ce858
 
 
5d2367f
2ccd650
94239c6
 
 
 
 
 
 
 
 
 
97b86b5
688e29a
 
dc89f9d
 
c9ce858
 
5d2367f
 
 
 
dc89f9d
5d2367f
 
 
 
 
0492776
5e3b678
 
e46d69b
 
 
dc89f9d
 
 
 
0492776
5d2367f
dc89f9d
5d2367f
 
 
 
 
 
 
b2bee1d
 
0492776
 
 
 
 
 
dc89f9d
0492776
 
c9ce858
 
6b95018
dc89f9d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
from PIL import Image
import streamlit as st
from pathlib import Path
import io  # メモリ内バッファを扱うため

def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    """Check that every given file has an acceptable suffix.

    Args:
        file: A single path, or a list/tuple of paths. A falsy value skips
            the check entirely.
        suffix: One acceptable suffix string (e.g. ``'.pt'``) or a sequence
            of them. Matching is case-insensitive on both sides. A falsy
            value skips the check entirely.
        msg: Text prepended to the error message.

    Raises:
        AssertionError: if a file has a non-empty suffix that is not in
            *suffix*. It is raised explicitly (not via ``assert``) so the
            check still runs under ``python -O``.

    Files with no suffix at all are not checked.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    # Lowercase the allowed set too; otherwise suffix='.PT' could never match,
    # because the file's own suffix is lowercased below.
    allowed = {s.lower() for s in suffix}
    for f in file if isinstance(file, (list, tuple)) else [file]:
        s = Path(f).suffix.lower()  # file suffix
        if s and s not in allowed:
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}")


def read_pretrain(path):
    """Load a custom YOLOv5 checkpoint via ``torch.hub``.

    Uses the ``custom`` entrypoint of the ``ultralytics/yolov5`` hub
    repository and returns whatever model object that entrypoint builds.

    Args:
        path: Location of the weights file to load (this app passes one of
            the ``.pt`` file names offered in the selectbox).
    """
    repo, entrypoint = 'ultralytics/yolov5', 'custom'
    model = torch.hub.load(repo, entrypoint, path=path)
    return model


# Streamlit UI: collect an image and inference settings, then run the
# selected YOLOv5 checkpoint and show the rendered detections and a table.
st.title("Hololive Waifu Classification")

uploaded_file = st.file_uploader(
    "画像ファイルをアップロードしてください (対応形式: JPG, JPEG, PNG, WEBP)", 
    type=["jpg", "jpeg", "png", "webp"]
)

pretrained = st.selectbox(
    'Select pre-trained', 
    ('2022.11.04-YOLOv5x6_1280-Hololive_Waifu_Classification.pt', 
     '2022.11.01-YOLOv5x6_1280-Hololive_Waifu_Classification.pt')
)
# A lower bound of 1 stops the user from submitting a zero or negative image
# size / detection cap, neither of which is meaningful; upper bounds stay open.
imgsz = st.number_input(label='Image Size', min_value=1, max_value=None, value=1280, step=1)
conf = st.slider(label='Confidence threshold', min_value=0.0, max_value=1.0, value=0.25, step=0.01)
iou = st.slider(label='IoU threshold', min_value=0.0, max_value=1.0, value=0.45, step=0.01)
multi_label = st.selectbox('Multiple labels per box', (False, True))
agnostic = st.selectbox('Class-agnostic', (False, True))
amp = st.selectbox('Automatic Mixed Precision inference', (False, True))
max_det = st.number_input(label='Maximum number of detections per image', min_value=1, max_value=None, value=1000, step=1)
clicked = st.button('Execute')

if clicked and uploaded_file is not None:
    with st.spinner('Loading the image...'):
        try:
            # Open the upload; WEBP files are re-encoded to PNG in memory.
            input_image = Image.open(uploaded_file)
            if input_image.format == 'WEBP':
                st.info("WEBP形式の画像をPNG形式に変換しています...")
                buffer = io.BytesIO()  # in-memory buffer
                input_image.save(buffer, format="PNG")  # re-encode as PNG
                buffer.seek(0)
                input_image = Image.open(buffer)  # reopen from the PNG bytes
            # Normalise to 3-channel RGB so palette ("P"), grayscale+alpha
            # ("LA"), CMYK and RGBA uploads all reach the model in one mode.
            # convert() also forces the lazy decode, so a corrupt file fails here.
            input_image = input_image.convert("RGB")
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # undecodable data and OSError for truncated files.
            input_image = None

    if input_image is None:
        st.error("画像を読み込めませんでした。ファイルが壊れていないか確認してください。")
    else:
        # NOTE(review): the model is reloaded on every click. Caching it
        # (e.g. st.cache_resource) would be faster, but the attribute
        # assignments below would then be shared across user sessions.
        with st.spinner('Loading the model...'):
            model = read_pretrain(pretrained)
        with st.spinner('Updating configuration...'):
            model.conf = float(conf)
            model.max_det = int(max_det)
            model.iou = float(iou)
            model.agnostic = agnostic
            model.multi_label = multi_label
            model.amp = amp
        with st.spinner('Predicting...'):
            results = model(input_image, size=int(imgsz))
        for img in results.render():
            st.image(img)
        st.write(results.pandas().xyxy[0])
else:
    if not uploaded_file:
        st.warning("画像ファイルをアップロードしてください。")