Spaces:
Paused
Paused
from main import extract_frames, run_eval #run | |
from PIL import Image | |
import numpy as np | |
from skimage.metrics import structural_similarity as ssim | |
from skimage.metrics import peak_signal_noise_ratio as psnr | |
import torch | |
import torchvision.transforms as transforms | |
import lpips | |
from pytorch_fid.fid_score import calculate_fid_given_paths | |
from cdfvd import fvd | |
import os | |
import json | |
import cv2 | |
from huggingface_hub import snapshot_download | |
# Root folder (with trailing slash) under which each evaluated item gets its own
# sub-folder of saved frames. Alternative location noted by the author: '/data/out/'
outdir = 'outputs/'
# Convert PIL to numpy
def pil_to_np(img):
    """Return *img* as a float32 NumPy array with every value divided by 255.

    Args:
        img: Any object ``np.array`` accepts; callers in this file pass PIL images.

    Returns:
        A ``float32`` array with the same shape as ``np.array(img)``.
    """
    pixels = np.array(img)
    as_float = pixels.astype(np.float32)
    return as_float / 255.0
def save_mp4(images, name):
    """Write a sequence of images to a video file using the 'mp4v' codec.

    Args:
        images: Non-empty sequence of PIL-style images (objects with a ``size``
            attribute that ``np.array`` can convert). The first image's size
            is used as the frame size of the whole video.
        name: Path of the video file to create.

    Raises:
        IndexError: If ``images`` is empty.
    """
    width, height = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
    video = cv2.VideoWriter(name, fourcc, 12, (width, height))
    try:
        for image in images:
            # Frames are converted from RGB to the BGR channel order before writing.
            frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            video.write(frame)
    finally:
        # Release the writer even when a frame fails to convert or write, so
        # the handle is not leaked and the file is finalized.
        video.release()
# SSIM
def compute_ssim(img1, img2):
    """Compute the structural similarity between two images.

    Both images are converted with ``pil_to_np`` (values scaled by 1/255) and
    compared with ``data_range=1.0`` and the last axis treated as channels.

    Args:
        img1: First image.
        img2: Second image.

    Returns:
        The value returned by ``skimage.metrics.structural_similarity``.
    """
    first = pil_to_np(img1)
    second = pil_to_np(img2)
    height, width = first.shape[:2]
    shortest_side = min(height, width)
    # The window must be odd: use the largest odd number that does not exceed
    # the shortest side, capped at 7.
    if shortest_side % 2 == 1:
        odd_limit = shortest_side
    else:
        odd_limit = shortest_side - 1
    window = min(7, odd_limit)
    return ssim(first, second, win_size=window, channel_axis=-1, data_range=1.0)
# PSNR
def compute_psnr(img1, img2):
    """Compute the peak signal-to-noise ratio between two images.

    Both images are converted with ``pil_to_np`` (values scaled by 1/255), so
    the comparison uses ``data_range=1.0``.

    Args:
        img1: Reference image.
        img2: Image compared against the reference.

    Returns:
        The value returned by ``skimage.metrics.peak_signal_noise_ratio``.
    """
    reference = pil_to_np(img1)
    candidate = pil_to_np(img2)
    return psnr(reference, candidate, data_range=1.0)
# LPIPS
# LPIPS model with the 'alex' backbone, built once at import time and reused
# by every compute_lpips() call.
lpips_model = lpips.LPIPS(net='alex')
# Preprocessing applied to both images: resize to 256x256, convert to a tensor,
# then normalize three channels with mean 0.5 and std 0.5.
lpips_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def compute_lpips(img1, img2):
    """Compute the LPIPS distance between two images.

    Args:
        img1: First image, in a form accepted by ``lpips_transform``.
        img2: Second image, in a form accepted by ``lpips_transform``.

    Returns:
        The model output as a Python float.
    """
    img1_tensor = lpips_transform(img1).unsqueeze(0)  # add a batch dimension
    img2_tensor = lpips_transform(img2).unsqueeze(0)
    # Only the scalar value is needed, so skip autograd bookkeeping for the
    # forward pass; this lowers memory use without changing the result.
    with torch.no_grad():
        return lpips_model(img1_tensor, img2_tensor).item()
def trans(x):
    """Reorder a 5-D tensor from BTCHW to BCTHW, expanding greyscale to 3 channels.

    Args:
        x: 5-D tensor; the third-from-last axis is treated as the channel axis.

    Returns:
        The tensor with axes 1 and 2 swapped. If the channel axis had size 1 it
        is first repeated three times.
    """
    channel_count = x.shape[-3]
    if channel_count == 1:
        # greyscale input: replicate the single channel three times
        x = x.repeat(1, 1, 3, 1, 1)
    # permute BTCHW -> BCTHW
    return x.permute(0, 2, 1, 3, 4)
def compute_fvd(item, gt_imgs, results):
    """Compute an FVD score between ground-truth frames and generated frames.

    Each frame sequence is written as a single MP4 into its own folder under
    ``temp/``, and the two folders are then handed to a ``cdfvd`` evaluator
    configured with the 'i3d' model on the 'cuda' device.

    Args:
        item: Unused by this function.
        gt_imgs: Ground-truth frames, passed to ``save_mp4``.
        results: Generated frames, passed to ``save_mp4``.

    Returns:
        The score returned by ``evaluator.compute_fvd_from_stats()``; it is
        also printed.
    """
    gt_dir = 'temp/gt'
    result_dir = 'temp/result'
    for folder in (gt_dir, result_dir):
        os.makedirs(folder, exist_ok=True)

    save_mp4(gt_imgs, "temp/gt/gt.mp4")
    save_mp4(results, "temp/result/result.mp4")

    evaluator = fvd.cdfvd('i3d', ckpt_path=None, device='cuda', n_real=1, n_fake=1)

    real_videos = evaluator.load_videos(gt_dir, data_type='video_folder')
    evaluator.compute_real_stats(real_videos)

    fake_videos = evaluator.load_videos(result_dir, data_type='video_folder')
    evaluator.compute_fake_stats(fake_videos)

    score = evaluator.compute_fvd_from_stats()
    evaluator.offload_model_to_cpu()
    print(score)
    return score
def compute_fidx(item, gt_imgs, results):
    """Compute FID between a set of ground-truth images and a set of results.

    Ground-truth images are saved to ``temp/<item>_gt`` and generated images
    to ``temp/<item>`` as ``<index>.png``; FID is then computed between the
    two folders on the 'cuda' device.

    Args:
        item: Name used to build the two temporary folder paths.
        gt_imgs: Iterable of ground-truth images (objects with ``save``).
        results: Iterable of generated images (objects with ``save``).

    Returns:
        The value returned by ``calculate_fid_given_paths``.
    """
    gt_dir = 'temp/' + item + '_gt'
    result_dir = 'temp/' + item
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    for index, img in enumerate(gt_imgs):
        img.save(gt_dir + '/' + str(index) + '.png')
    # Save the *generated* images here. The previous code iterated gt_imgs a
    # second time, so FID compared the ground truth against itself.
    for index, img in enumerate(results):
        img.save(result_dir + '/' + str(index) + '.png')

    return calculate_fid_given_paths([gt_dir, result_dir], batch_size=8, device='cuda', dims=2048)
# FID: Save images to temp folders for FID calculation
def compute_fid(img1, img2):
    """Compute FID between two single images via temporary folders.

    Each image is written as ``0.png`` into its own folder under ``temp/``,
    overwriting whatever an earlier call left there, and FID is computed
    between the two folders on the 'cuda' device.

    Args:
        img1: First image (object with ``save``).
        img2: Second image (object with ``save``).

    Returns:
        The value returned by ``calculate_fid_given_paths``.
    """
    # NOTE(review): each folder holds a single image, so the statistics are
    # estimated from one sample per side -- confirm this is intended.
    folders = ['temp/img1', 'temp/img2']
    for folder in folders:
        os.makedirs(folder, exist_ok=True)
    img1.save('temp/img1/0.png')
    img2.save('temp/img2/0.png')
    return calculate_fid_given_paths(folders, batch_size=1, device='cuda', dims=2048)
def _frame_index(filename):
    """Return the integer N from a name like ``result_N.png``, or None if there is none."""
    stem = os.path.splitext(filename)[0]
    try:
        return int(stem.rsplit('_', 1)[1])
    except (IndexError, ValueError):
        return None


def _load_cached_frames(folder):
    """Load cached (results, results_base, gt_frames) from *folder*, sorted by frame index."""
    indexed = {'result_': [], 'base_': [], 'frame_': []}
    for filename in os.listdir(folder):
        for prefix, bucket in indexed.items():
            if filename.startswith(prefix):
                index = _frame_index(filename)
                if index is not None:
                    bucket.append((index, filename))
                break

    def load(prefix):
        # os.listdir() order is arbitrary, so sort numerically to keep
        # frame_i, result_i and base_i aligned.
        return [Image.open(folder + '/' + name) for _, name in sorted(indexed[prefix])]

    return load('result_'), load('base_'), load('frame_')


def _mean(values):
    """Return the arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
    """Score fine-tuned and baseline generations against ground-truth frames.

    If ``outdir + item`` already exists, previously saved ``frame_N``,
    ``result_N`` and ``base_N`` images are loaded from it. Otherwise frames
    are extracted from the video (at most 200, each thumbnailed to 512x512),
    ``run_eval`` is invoked, and all three image sets are saved to that folder.

    Per-frame SSIM, PSNR and LPIPS are computed for every aligned triple;
    per-frame FID is computed for the first 50 triples only.

    Args:
        item: Name of the evaluation item; also the sub-folder name under ``outdir``.
        image_paths: Either a list of path strings, or an already prepared
            list that is passed to ``run_eval`` unchanged.
        video_path: Path of the ground-truth video.
        train_steps: Forwarded to ``run_eval``.
        inference_steps: Forwarded to ``run_eval``.
        fps: Used for frame extraction and forwarded to ``run_eval``.
        bg_remove: Forwarded to ``run_eval``.

    Returns:
        A dict with keys ``'ft'`` and ``'base'`` (each mapping ``'ssim'``,
        ``'psnr'``, ``'lpips'``, ``'fid'`` to ``{'avg': ..., 'vals': [...]}``),
        ``'n_frames'`` (number of ground-truth frames) and ``'complexity'``
        (number of entries in ``image_paths``).

    Raises:
        ValueError: If there is no aligned (ground truth, result, base) triple
            to score.
    """
    max_frame_count = 200
    max_fid_frames = 50  # per-frame FID is only computed for the first frames
    item_dir = outdir + item

    if os.path.isdir(item_dir):
        results, results_base, gt_frames = _load_cached_frames(item_dir)
    else:
        if not isinstance(image_paths[0], str):
            images = image_paths
        else:
            images = []
            for path in image_paths:
                print(path)
                images.append([Image.open(path)])
        gt_frames = extract_frames(video_path, fps)
        gt_frames = gt_frames[:max_frame_count]
        for frame in gt_frames:
            frame.thumbnail((512, 512))
        # Forward the caller's settings; they used to be hard-coded here, which
        # silently ignored the function's own parameters.
        results, results_base = run_eval(
            images,
            video_path,
            train_steps=train_steps,
            inference_steps=inference_steps,
            fps=fps,
            modelId="fine_tuned_pcdms",
            img_width=1920,
            img_height=1080,
            bg_remove=bg_remove,
            resize_inputs=False,
        )
        os.makedirs(item_dir, exist_ok=True)
        for i, frame in enumerate(gt_frames):
            frame.save(item_dir + "/frame_" + str(i) + ".png")
        for i, result in enumerate(results):
            result.save(item_dir + "/result_" + str(i) + ".png")
        for i, result in enumerate(results_base):
            result.save(item_dir + "/base_" + str(i) + ".png")

    metric_names = ('ssim', 'psnr', 'lpips', 'fid')
    scores = {
        'ft': {name: [] for name in metric_names},
        'base': {name: [] for name in metric_names},
    }
    for index, (gt, result, base) in enumerate(zip(gt_frames, results, results_base)):
        for variant, candidate in (('ft', result), ('base', base)):
            scores[variant]['ssim'].append(float(compute_ssim(gt, candidate)))
            scores[variant]['psnr'].append(float(compute_psnr(gt, candidate)))
            scores[variant]['lpips'].append(float(compute_lpips(gt, candidate)))
        if index < max_fid_frames:
            print(index)
            scores['ft']['fid'].append(float(compute_fid(gt, result)))
            scores['base']['fid'].append(float(compute_fid(gt, base)))

    if not scores['ft']['ssim']:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("no frames to score for item " + repr(item))

    averages = {
        variant: {name: _mean(values) for name, values in per_metric.items()}
        for variant, per_metric in scores.items()
    }

    for name in metric_names:
        print(name.upper() + ":", averages['ft'][name])
    print('baseline:')
    for name in metric_names:
        print(name.upper() + ":", averages['base'][name])

    # Use the caller's input count so cached runs report the same complexity
    # as fresh runs (the cached path used to report 0).
    complexity = len(image_paths) if image_paths is not None else 0
    summary = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': complexity}
    for variant in ('ft', 'base'):
        for name in metric_names:
            summary[variant][name] = {'avg': averages[variant][name], 'vals': scores[variant][name]}
    return summary
def get_files(directory_path):
    """
    Return the names of the regular files located directly in *directory_path*.

    Sub-directories are skipped; names are returned without the directory
    prefix, in ``os.listdir`` order.
    """
    return [
        name
        for name in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, name))
    ]
def run_evaluate():
    """Evaluate every item of the evaluation dataset and print a summary.

    Downloads the ``acmyu/KeyframesAI-eval`` dataset into ``test/``, then for
    each sub-folder not yet present in ``/data/metrics.json`` calls
    ``get_score`` with the folder's non-mp4 files as images and its single
    mp4 file as the video. The metrics file is rewritten after every scored
    item. Finally, per-item averages and the mean over all items are printed
    for the fine-tuned and baseline results.

    Errors raised while processing one item are printed and the loop moves on
    to the next item.
    """
    print("run_evaluate")
    metrics_path = '/data/metrics.json'
    snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")

    # Start from an empty result set on the first run instead of crashing
    # because the metrics file does not exist yet.
    try:
        with open(metrics_path, 'r', encoding="utf-8") as file:
            metrics = json.load(file)
    except FileNotFoundError:
        metrics = {}

    # Keep only visible sub-directories of the dataset folder.
    items = [
        it for it in os.listdir('test')
        if not it.startswith('.') and not os.path.isfile('test/' + it)
    ]
    print(items)

    for item in items:
        if item in metrics:
            continue  # already scored in an earlier run
        print(item)
        try:
            files = get_files('test/' + item)
            images = ['test/' + item + '/' + name for name in files if not name.endswith('.mp4')]
            videos = [name for name in files if name.endswith('.mp4')]
            print(images, videos)
            if len(videos) == 1:
                metrics[item] = get_score(item, images, 'test/' + item + '/' + videos[0])
                # Persist after each item so finished work survives a later failure.
                with open(metrics_path, "w", encoding="utf-8") as json_file:
                    json.dump(metrics, json_file, ensure_ascii=False, indent=4)
            else:
                print('Error: mp4 not found')
        except Exception as e:
            # Best effort: report the failure and continue with the next item.
            print("Error", item, e)

    if not metrics:
        # Nothing to average; avoid dividing by zero below.
        print('No metrics available')
        return

    metric_names = ('ssim', 'psnr', 'lpips', 'fid')
    averages = {
        'ft': {name: [] for name in metric_names},
        'base': {name: [] for name in metric_names},
    }
    for item, entry in metrics.items():
        for variant in ('ft', 'base'):
            for name in metric_names:
                averages[variant][name].append(entry[variant][name]['avg'])
        print(item)
        for name in metric_names:
            print(name.upper() + ":", entry['ft'][name]['avg'], entry['base'][name]['avg'])

    print('Results:')
    for name in metric_names:
        values = averages['ft'][name]
        print(name.upper() + ":", sum(values) / len(values))
    print('baseline:')
    for name in metric_names:
        values = averages['base'][name]
        print(name.upper() + ":", sum(values) / len(values))