# pair-images-aug-test / test_app.py
# yeq6x's picture
# r
# 095a64f
import gradio as gr
import os
from PIL import Image
import tempfile
from dataset_aug import process_images
import convert_source_to_sketch # スケッチ変換用のモジュールをインポート
import random
def process_multiple_images(
    source_images,
    target_images,
    output_size,
    num_copies,
    is_flip,
    rotation_range,
    min_scale,
    max_scale,
    source_is_avg_color_fill,
    source_is_edge_mode_fill,
    target_is_avg_color_fill,
    target_is_edge_mode_fill,
    expand_to_long_side,
    progress=gr.Progress()
):
    """Augment every source/target image pair and collect the results.

    Source and target files are paired by position. If the two lists differ
    in length, only the pairs that exist in both are processed.

    Args:
        source_images: Uploaded source files (objects exposing a ``.name``
            path), or None/empty when nothing was selected.
        target_images: Uploaded target files, same form as ``source_images``.
        output_size: Edge length; passed to ``process_images`` as the square
            ``(output_size, output_size)``.
        num_copies, is_flip, rotation_range, min_scale, max_scale,
        source_is_avg_color_fill, source_is_edge_mode_fill,
        target_is_avg_color_fill, target_is_edge_mode_fill,
        expand_to_long_side: Forwarded unchanged to ``process_images``.
        progress: Gradio progress tracker used to report status.

    Returns:
        A tuple ``(source_results, target_results, message)`` where the first
        two items are the concatenated outputs of ``process_images`` for all
        pairs and ``message`` is the completion text.
    """
    result_source_images = []
    result_target_images = []
    progress(0, desc="処理を開始します...")

    # Either input may be None when no file was selected; treat that as "no
    # pairs". Counting the zipped pairs (not just the sources) keeps the
    # progress fraction correct when the two uploads differ in length.
    pairs = list(zip(source_images or [], target_images or []))
    total_pairs = len(pairs)

    # Run the augmentation for each image pair.
    for idx, (source_path, target_path) in enumerate(pairs, 1):
        progress(idx / total_pairs, desc=f"画像ペア {idx}/{total_pairs} を処理中...")

        # Load both files as PIL images.
        source_img = Image.open(source_path.name)
        target_img = Image.open(target_path.name)

        # Augment the pair; the two returned lists are collected below.
        aug_sources, aug_targets = process_images(
            source_img,
            target_img,
            num_copies=num_copies,
            output_size=(output_size, output_size),
            is_flip=is_flip,
            rotation_range=rotation_range,
            min_scale=min_scale,
            max_scale=max_scale,
            source_is_avg_color_fill=source_is_avg_color_fill,
            source_is_edge_mode_fill=source_is_edge_mode_fill,
            target_is_avg_color_fill=target_is_avg_color_fill,
            target_is_edge_mode_fill=target_is_edge_mode_fill,
            expand_to_long_side=expand_to_long_side
        )

        # Collect the generated images.
        result_source_images.extend(aug_sources)
        result_target_images.extend(aug_targets)

    progress(1, desc="処理が完了しました")
    return result_source_images, result_target_images, "処理が完了しました!"
def update_source_preview(source_files):
    """Return the paths of the selected source files for the preview gallery.

    Args:
        source_files: Iterable of file objects exposing a ``.name`` path, or
            None/empty when nothing is selected.

    Returns:
        A list with the ``.name`` of every file, in order; empty when there
        is no selection.
    """
    if not source_files:
        return []
    return [uploaded.name for uploaded in source_files]
def update_target_preview(target_files):
    """Return the paths of the selected target files for the preview gallery.

    Args:
        target_files: Iterable of file objects exposing a ``.name`` path, or
            None/empty when nothing is selected.

    Returns:
        A list with the ``.name`` of every file, in order; empty when there
        is no selection.
    """
    return [uploaded.name for uploaded in target_files] if target_files else []
# Prefix of the temporary directories created by convert_to_sketch. It lets
# cleanup_temp_files recognise, and delete only, this application's own
# directories instead of every "tmp*" directory in the system temp folder.
_SKETCH_TEMP_PREFIX = "pair_images_aug_"


def convert_to_sketch(source_files):
    """Convert each given image to a sketch and return the saved file paths.

    Args:
        source_files: Iterable of file objects exposing a ``.name`` path, or
            None/empty when nothing is selected.

    Returns:
        A list of paths to the converted images, one per input and in input
        order. Returns an empty list when there is no input or when any
        conversion fails (the error is printed, not raised).
    """
    import shutil

    if not source_files:
        return []

    # The directory is intentionally kept on success: the returned paths are
    # handed back to the caller, which still needs the files. It is removed
    # later by cleanup_temp_files.
    temp_dir = tempfile.mkdtemp(prefix=_SKETCH_TEMP_PREFIX)
    converted_images = []
    try:
        for source in source_files:
            # Close the input file as soon as the sketch has been written.
            with Image.open(source.name) as image:
                sketch = convert_source_to_sketch.convert_pil_to_sketch(image)

                # Save under the original file name; if two inputs share a
                # basename, add a numeric suffix so that neither result
                # overwrites the other.
                file_name = os.path.basename(source.name)
                stem, ext = os.path.splitext(file_name)
                temp_path = os.path.join(temp_dir, file_name)
                duplicate = 1
                while os.path.exists(temp_path):
                    temp_path = os.path.join(temp_dir, f"{stem}_{duplicate}{ext}")
                    duplicate += 1
                sketch.save(temp_path)
            converted_images.append(temp_path)
    except Exception as e:
        # Best effort: report the failure, discard partial results and let
        # the UI continue with an empty result.
        print(f"Error during conversion: {e}")
        shutil.rmtree(temp_dir, ignore_errors=True)
        return []

    return converted_images


def cleanup_temp_files():
    """Delete the temporary directories created by convert_to_sketch.

    Only directories in the system temp folder whose name starts with
    ``_SKETCH_TEMP_PREFIX`` are removed. Failures are printed and do not
    stop the cleanup of the remaining directories.
    """
    import shutil

    temp_root = tempfile.gettempdir()
    for item in os.listdir(temp_root):
        if not item.startswith(_SKETCH_TEMP_PREFIX):
            continue
        item_path = os.path.join(temp_root, item)
        if not os.path.isdir(item_path):
            continue
        try:
            shutil.rmtree(item_path)
        except OSError as e:
            print(f"Error cleaning up {item_path}: {e}")
def randomize_params():
    """Draw a random value for every augmentation parameter.

    Returns:
        An 11-tuple in this order: output_size, num_copies, is_flip,
        rotation_range, min_scale, max_scale, source_is_avg_color_fill,
        source_is_edge_mode_fill, target_is_avg_color_fill,
        target_is_edge_mode_fill, expand_to_long_side.
    """
    def coin_flip():
        return random.choice([True, False])

    size = random.choice([512, 768, 1024, 1536, 2048])
    copies = random.randint(1, 5)
    flip = coin_flip()
    rotation = random.randint(0, 180)
    smallest_scale = round(random.uniform(0.1, 1.0), 1)
    largest_scale = round(random.uniform(1.0, 2.0), 1)
    # Remaining switches, drawn in order: source avg-color fill, source
    # edge-mode fill, target avg-color fill, target edge-mode fill,
    # expand-to-long-side.
    switches = [coin_flip() for _ in range(5)]

    return (
        size,
        copies,
        flip,
        rotation,
        smallest_scale,
        largest_scale,
        *switches,
    )
def reset_params():
    """Return the initial value of every augmentation parameter.

    Returns:
        An 11-tuple in this order: output_size, num_copies, is_flip,
        rotation_range, min_scale, max_scale, source_is_avg_color_fill,
        source_is_edge_mode_fill, target_is_avg_color_fill,
        target_is_edge_mode_fill, expand_to_long_side.
    """
    defaults = {
        "output_size": 1024,
        "num_copies": 1,
        "is_flip": True,
        "rotation_range": 0,
        "min_scale": 1.0,
        "max_scale": 1.0,
        "source_is_avg_color_fill": True,
        "source_is_edge_mode_fill": False,
        "target_is_avg_color_fill": False,
        "target_is_edge_mode_fill": False,
        "expand_to_long_side": False,
    }
    # Dicts keep insertion order, so the values come out in the order above.
    return tuple(defaults.values())
def test_process_image_pair_with_expand_to_long_side():
    """Check the expand-to-long-side option on a non-square image pair.

    An 800x400 white pair is run through ``process_images`` with every random
    transform disabled; the first result of each side must be a 512x512
    square.
    """
    input_size = (800, 400)
    source_image = Image.new('RGB', input_size, color='white')
    target_image = Image.new('RGB', input_size, color='white')

    options = dict(
        num_copies=1,
        output_size=(512, 512),
        is_flip=False,
        rotation_range=0,
        min_scale=1.0,
        max_scale=1.0,
        source_is_avg_color_fill=True,
        source_is_edge_mode_fill=False,
        target_is_avg_color_fill=True,
        target_is_edge_mode_fill=False,
        expand_to_long_side=True,
    )
    aug_sources, aug_targets = process_images(source_image, target_image, **options)
    first_results = (aug_sources[0], aug_targets[0])

    # The results must be square...
    for image in first_results:
        assert image.size[0] == image.size[1]

    # ...and exactly the requested output size.
    for image in first_results:
        assert image.size == (512, 512)
# Build the Gradio interface.
# NOTE(review): the nesting of the layout contexts below was restored from a
# copy of this file that had lost its indentation. Confirm the layout
# (especially where the fill-option row and the run button sit) against the
# running app.
with gr.Blocks() as demo:
    gr.Markdown("# Pair Image Augmentation Test")
    gr.Markdown("Code : https://github.com/Yeq6X/pair-images-aug")

    with gr.Row():
        # Left column: file inputs, sample set and previews.
        with gr.Column():
            with gr.Row():
                # Source images
                source_files = gr.File(
                    label="Source画像を選択",
                    file_count="multiple",
                    file_types=["image"],
                    height=150
                )
                # Target images
                target_files = gr.File(
                    label="Target画像を選択",
                    file_count="multiple",
                    file_types=["image"],
                    height=150
                )
            # Sample image set that fills both file inputs.
            gr.Examples(
                examples=[
                    [["samples/source/sample1.png", "samples/source/sample2.png"],
                     ["samples/target/sample1.png", "samples/target/sample2.png"]],
                ],
                inputs=[source_files, target_files],
                label="サンプル画像セット",
                examples_per_page=5
            )
            source_preview = gr.Gallery(
                label="Source画像プレビュー",
                show_label=True,
                object_fit="contain",
                columns=4,
            )
            # Buttons that run convert_to_sketch in either direction
            # (source -> target with the down arrow, target -> source with
            # the up arrow).
            with gr.Row():
                gr.Markdown("### scribble_xdogで変換")
                convert_src_to_tgt_btn = gr.Button("↓", variant="primary", size="sm")
                convert_tgt_to_src_btn = gr.Button("↑", variant="primary", size="sm")
                gr.Markdown("")  # empty Markdown; presumably a spacer
            target_preview = gr.Gallery(
                label="Target画像プレビュー",
                show_label=True,
                object_fit="contain",
                columns=4,
            )

        # Right column: parameter settings, run button and results.
        with gr.Column():
            # Parameter settings
            with gr.Accordion("パラメータ設定", open=True):
                # Buttons that overwrite all parameters at once.
                with gr.Row():
                    randomize_btn = gr.Button("🎲 ランダム設定", variant="secondary")
                    reset_btn = gr.Button("↺ 初期設定に戻す", variant="secondary")
                with gr.Row():
                    with gr.Column():
                        output_size = gr.Slider(
                            minimum=256,
                            maximum=2048,
                            value=1024,
                            step=256,
                            label="出力画像サイズ"
                        )
                        num_copies = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=1,
                            step=1,
                            label="リピート回数"
                        )
                        is_flip = gr.Checkbox(
                            label="ランダムフリップを適用",
                            value=True
                        )
                        expand_to_long_side = gr.Checkbox(
                            label="長辺に合わせて拡張する",
                            value=False
                        )
                        rotation_range = gr.Slider(
                            minimum=0,
                            maximum=180,
                            value=0,
                            step=1,
                            label="回転角度の範囲"
                        )
                    with gr.Column():
                        min_scale = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            label="最小スケール"
                        )
                        max_scale = gr.Slider(
                            minimum=1.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="最大スケール"
                        )
                # Fill options, one column for the source side and one for
                # the target side.
                with gr.Row():
                    with gr.Column():
                        source_is_edge_mode_fill = gr.Checkbox(
                            label="Source: 外周の最頻色で埋める",
                            value=False
                        )
                        source_is_avg_color_fill = gr.Checkbox(
                            label="Source: 画像の平均色で埋める",
                            value=True
                        )
                    with gr.Column():
                        target_is_edge_mode_fill = gr.Checkbox(
                            label="Target: 外周の最頻色で埋める",
                            value=False
                        )
                        target_is_avg_color_fill = gr.Checkbox(
                            label="Target: 画像の平均色で埋める",
                            value=False
                        )
            process_btn = gr.Button("処理開始", variant="primary")
            # Text box that receives the status message; created hidden
            # (visible=False).
            log_output = gr.Textbox(
                label="処理ログ",
                value="",
                lines=1,
                max_lines=1,
                interactive=False,
                visible=False
            )
            # Result galleries
            result_source_gallery = gr.Gallery(
                label="生成結果 (Source)",
                show_label=True,
                object_fit="contain",
                columns=4,
                type="pil"
            )
            result_target_gallery = gr.Gallery(
                label="生成結果 (Target)",
                show_label=True,
                object_fit="contain",
                columns=4,
                type="pil"
            )

    # Event handlers
    # Refresh each preview gallery whenever its file input changes.
    source_files.change(
        fn=update_source_preview,
        inputs=[source_files],
        outputs=source_preview
    )
    target_files.change(
        fn=update_target_preview,
        inputs=[target_files],
        outputs=target_preview
    )
    # Sketch conversion: the converted files replace the other file input.
    convert_src_to_tgt_btn.click(
        fn=convert_to_sketch,
        inputs=[source_files],
        outputs=[target_files]
    )
    convert_tgt_to_src_btn.click(
        fn=convert_to_sketch,
        inputs=[target_files],
        outputs=[source_files]
    )
    # Components updated by randomize_params / reset_params; the order must
    # match the order of the tuples those functions return.
    param_outputs = [
        output_size,
        num_copies,
        is_flip,
        rotation_range,
        min_scale,
        max_scale,
        source_is_avg_color_fill,
        source_is_edge_mode_fill,
        target_is_avg_color_fill,
        target_is_edge_mode_fill,
        expand_to_long_side
    ]
    randomize_btn.click(
        fn=randomize_params,
        inputs=[],
        outputs=param_outputs
    )
    reset_btn.click(
        fn=reset_params,
        inputs=[],
        outputs=param_outputs
    )
    # Run the augmentation; the inputs are listed in the positional order of
    # process_multiple_images' parameters.
    process_btn.click(
        fn=process_multiple_images,
        inputs=[
            source_files,
            target_files,
            output_size,
            num_copies,
            is_flip,
            rotation_range,
            min_scale,
            max_scale,
            source_is_avg_color_fill,
            source_is_edge_mode_fill,
            target_is_avg_color_fill,
            target_is_edge_mode_fill,
            expand_to_long_side
        ],
        outputs=[result_source_gallery, result_target_gallery, log_output]
    )
if __name__ == "__main__":
    # Launch options. Optional overrides that are currently disabled:
    #   "server_name": "0.0.0.0", "server_port": 8000
    launch_options = {"debug": True}
    try:
        demo.launch(**launch_options)
    finally:
        # Clean up temporary files when the application exits.
        cleanup_temp_files()