File size: 2,924 Bytes
e9e5323 c1115fd c9ce858 5d2367f 2ccd650 94239c6 97b86b5 688e29a dc89f9d c9ce858 5d2367f dc89f9d 5d2367f 0492776 5e3b678 e46d69b dc89f9d 0492776 5d2367f dc89f9d 5d2367f b2bee1d 0492776 dc89f9d 0492776 c9ce858 6b95018 dc89f9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import torch
from PIL import Image
import streamlit as st
from pathlib import Path
import io # メモリ内バッファを扱うため
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    """Check that every path in *file* ends with one of the accepted suffixes.

    Args:
        file: A single path, or a list/tuple of paths. A falsy value skips the
            check entirely.
        suffix: Accepted suffix or suffixes, e.g. ``'.pt'`` or ``('.pt',)``.
            Comparison is case-insensitive. A falsy value skips the check.
        msg: Text prepended to the error message.

    Raises:
        AssertionError: If a path has a non-empty suffix that is not accepted.
            Paths with no suffix at all are allowed through.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    # Materialize once (safe even if an iterator was passed) and lowercase so
    # the comparison matches the lowercased file suffix below.
    accepted = tuple(x.lower() for x in suffix)
    files = file if isinstance(file, (list, tuple)) else [file]
    for f in files:
        s = Path(f).suffix.lower()  # file suffix, '' when there is none
        # Raise explicitly rather than via `assert` so the check still runs
        # under `python -O`; AssertionError is kept for caller compatibility.
        if s and s not in accepted:
            raise AssertionError(f"{msg}{f} acceptable suffix is {suffix}")
def read_pretrain(path):
    """Load a custom-weights YOLOv5 model through torch.hub.

    Args:
        path: Location of the weights file handed to the hub ``custom`` entry
            point.

    Returns:
        Whatever ``torch.hub.load`` returns for the ``ultralytics/yolov5``
        ``custom`` entry point (the caller sets ``conf``/``iou``/... on it and
        calls it on an image).

    NOTE(review): this performs a full model load on every call; the app calls
    it on each "Execute" click. Consider caching the result if the deployed
    Streamlit version provides a resource cache.
    """
    hub_repo = 'ultralytics/yolov5'
    hub_entry = 'custom'
    loaded = torch.hub.load(hub_repo, hub_entry, path=path)
    return loaded
# --- UI: page title, image upload and inference settings ---
st.title("Hololive Waifu Classification")
uploaded_file = st.file_uploader(
    "画像ファイルをアップロードしてください (対応形式: JPG, JPEG, PNG, WEBP)",
    type=["jpg", "jpeg", "png", "webp"]
)
# Weights file later passed to read_pretrain(); the newer checkpoint is listed
# first, so it is the default selection.
pretrained = st.selectbox(
    'Select pre-trained',
    ('2022.11.04-YOLOv5x6_1280-Hololive_Waifu_Classification.pt',
     '2022.11.01-YOLOv5x6_1280-Hololive_Waifu_Classification.pt')
)
# Inference size passed as `size=` to the model. A positive floor prevents a
# zero/negative size from being entered; the default appears to match the
# "_1280" in the checkpoint names.
imgsz = st.number_input(label='Image Size', min_value=32, max_value=None, value=1280, step=1)
conf = st.slider(label='Confidence threshold', min_value=0.0, max_value=1.0, value=0.25, step=0.01)
iou = st.slider(label='IoU threshold', min_value=0.0, max_value=1.0, value=0.45, step=0.01)
# Boolean switches copied onto the model's attributes before inference.
multi_label = st.selectbox('Multiple labels per box', (False, True))
agnostic = st.selectbox('Class-agnostic', (False, True))
amp = st.selectbox('Automatic Mixed Precision inference', (False, True))
# Detection cap assigned to model.max_det; must be at least 1 to be meaningful.
max_det = st.number_input(label='Maximum number of detections per image', min_value=1, max_value=None, value=1000, step=1)
clicked = st.button('Execute')
# --- Inference: runs only when "Execute" was clicked with an image uploaded ---
if clicked and uploaded_file is not None:
    with st.spinner('Loading the image...'):
        # Open the upload; WEBP images are re-encoded to PNG in memory.
        input_image = Image.open(uploaded_file)
        if input_image.format == 'WEBP':
            st.info("WEBP形式の画像をPNG形式に変換しています...")
            buffer = io.BytesIO()  # in-memory buffer, no temp file needed
            input_image.save(buffer, format="PNG")  # re-encode as PNG
            buffer.seek(0)  # rewind so Image.open reads from the start
            input_image = Image.open(buffer)  # reopen the PNG copy
        # Normalize to 3-channel RGB. Uploads may be palette ("P"), grayscale,
        # grayscale+alpha ("LA"), RGBA or CMYK images; converting here gives the
        # detector one consistent channel layout. convert() also forces the
        # lazy decode to happen inside this spinner.
        input_image = input_image.convert('RGB')
    with st.spinner('Loading the model...'):
        model = read_pretrain(pretrained)
    with st.spinner('Updating configuration...'):
        # Copy the UI settings onto the model before running it.
        model.conf = float(conf)
        model.max_det = int(max_det)
        model.iou = float(iou)
        model.agnostic = agnostic
        model.multi_label = multi_label
        model.amp = amp
    with st.spinner('Predicting...'):
        results = model(input_image, size=int(imgsz))
    # Show each annotated image, then the detections table for the first image.
    for img in results.render():
        st.image(img)
    st.write(results.pandas().xyxy[0])
elif uploaded_file is None:
    st.warning("画像ファイルをアップロードしてください。")
|