Img_test / app.py
Kingoteam's picture
Update app.py
d053cfa verified
raw
history blame contribute delete
772 Bytes
import torch
from torchvision.utils import save_image
import os
import numpy as np
import tempfile
from model import Generator # فرضاً مدل StyleGAN2 یا مشابه را import می‌کنیم
# لود مدل pretrained (مثلاً unconditional anime face)
device = "cpu"
G = Generator() # مدل از قبل train شده
G.load_state_dict(torch.load("unconditional_anime_face.pth", map_location=device))
G.to(device)
G.eval()
# تولید تصویر بدون پرامپت
z = torch.randn(1, 512, device=device) # latent vector
with torch.no_grad():
img = G(z)
# ذخیره تصویر
tmpdir = tempfile.mkdtemp()
file_path = os.path.join(tmpdir, "output.png")
save_image((img + 1) / 2, file_path) # normalize [-1,1] -> [0,1]
print("Saved:", file_path)