import os import torch from torchvision import transforms from PIL import Image import numpy as np import gradio as gr from watermark_remover import WatermarkRemover # デバイス:CPU専用 device = torch.device("cpu") # モデル読み込み model = WatermarkRemover().to(device) model.load_state_dict(torch.load("model.pth", map_location=device)) model.eval() # 画像変換用トランスフォーム transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor() ]) # 推論関数 def remove_watermark(image: Image.Image, num_samples: int = 10) -> Image.Image: original_size = image.size image = image.convert("RGB") input_tensor = transform(image).unsqueeze(0).to(device) outputs = [] with torch.no_grad(): for _ in range(num_samples): output_tensor = model(input_tensor) outputs.append(output_tensor) # 複数回の出力を平均化 avg_output = torch.stack(outputs, dim=0).mean(dim=0) predicted_image = avg_output.squeeze(0).cpu().permute(1, 2, 0).clamp(0, 1).numpy() predicted_pil = Image.fromarray((predicted_image * 255).astype(np.uint8)) predicted_pil = predicted_pil.resize(original_size, Image.Resampling.LANCZOS) return predicted_pil # Gradio UI app = gr.Interface( fn=remove_watermark, inputs=gr.Image(type="pil", label="ウォーターマーク付き画像をアップロード"), outputs=gr.Image(type="pil", label="ウォーターマーク除去後の画像"), title="ウォーターマーク除去AI (CPU対応)", description="このアプリは、FODUU AIが開発したモデルを使用して画像からウォーターマークを除去します。※ 処理には数秒かかる場合があります。" ) # アプリ実行 if __name__ == "__main__": app.launch()