# KeyframesAI evaluate.py: get metrics for user edits
from main import extract_frames, run_eval #run
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch
import torchvision.transforms as transforms
import lpips
from pytorch_fid.fid_score import calculate_fid_given_paths
from cdfvd import fvd
import os
import json
import cv2
from huggingface_hub import snapshot_download
outdir = 'outputs/' #'/data/out/'
# Convert PIL to numpy
def pil_to_np(img):
return np.array(img).astype(np.float32) / 255.0
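# Encode a list of PIL frames as an MP4 at 12 fps (used to build temporary clips for FVD)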
def save_mp4(images, name):
width, height = images[0].size
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for MP4
video = cv2.VideoWriter(name, fourcc, 12, (width, height))
for image in images:
img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
video.write(img)
video.release()
# SSIM
def compute_ssim(img1, img2):
img1_np = pil_to_np(img1)
img2_np = pil_to_np(img2)
h, w = img1_np.shape[:2]
min_dim = min(h, w)
win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1) # ensure odd
return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)
# PSNR
def compute_psnr(img1, img2):
img1_np = pil_to_np(img1)
img2_np = pil_to_np(img2)
return psnr(img1_np, img2_np, data_range=1.0)
# LPIPS
lpips_model = lpips.LPIPS(net='alex')
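# LPIPS inputs are resized to 256x256 and normalized from [0, 1] to [-1, 1]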
lpips_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5]*3, [0.5]*3)
])
def compute_lpips(img1, img2):
img1_tensor = lpips_transform(img1).unsqueeze(0)
img2_tensor = lpips_transform(img2).unsqueeze(0)
return lpips_model(img1_tensor, img2_tensor).item()
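# Prepare a video tensor for FVD-style models: expand greyscale to 3 channels and permute BTCHW -> BCTHW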
def trans(x):
# if greyscale images add channel
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
# permute BTCHW -> BCTHW
x = x.permute(0, 2, 1, 3, 4)
return x
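# FVD: write both frame sequences to temporary MP4s and score them with the cd-fvd I3D evaluator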
def compute_fvd(item, gt_imgs, results):
os.makedirs('temp/gt', exist_ok=True)
os.makedirs('temp/result', exist_ok=True)
save_mp4(gt_imgs, "temp/gt/gt.mp4")
save_mp4(results, "temp/result/result.mp4")
evaluator = fvd.cdfvd('i3d', ckpt_path=None, device='cuda', n_real=1, n_fake=1)
evaluator.compute_real_stats(evaluator.load_videos('temp/gt', data_type='video_folder'))
evaluator.compute_fake_stats(evaluator.load_videos('temp/result', data_type='video_folder'))
score = evaluator.compute_fvd_from_stats()
evaluator.offload_model_to_cpu()
print(score)
return score
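# FID over full frame sets: save ground-truth and generated frames to temp folders, then compute FID between the two folders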
def compute_fidx(item, gt_imgs, results):
os.makedirs('temp/'+item+'_gt', exist_ok=True)
os.makedirs('temp/'+item, exist_ok=True)
c = 0
for img in gt_imgs:
img.save('temp/'+item+'_gt/'+str(c)+'.png')
c = c+1
c = 0
    for img in results:
img.save('temp/'+item+'/'+str(c)+'.png')
c = c+1
fid = calculate_fid_given_paths(['temp/'+item+'_gt', 'temp/'+item], batch_size=8, device='cuda', dims=2048)
return fid
# FID: Save images to temp folders for FID calculation
def compute_fid(img1, img2):
os.makedirs('temp/img1', exist_ok=True)
os.makedirs('temp/img2', exist_ok=True)
img1.save('temp/img1/0.png')
img2.save('temp/img2/0.png')
fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cuda', dims=2048)
return fid
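# Generate (or reload cached) frames for one test item and compute per-frame SSIM/PSNR/LPIPS/FID for the fine-tuned and baseline models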
def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
images = []
results = []
results_base = []
gt_frames = []
max_frame_count = 200
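    # Reuse previously generated frames for this item if they are already on disk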
if os.path.isdir(outdir+item):
        # Sort by numeric frame index so ground-truth, result, and baseline frames stay aligned when zipped below
        for filename in sorted(os.listdir(outdir+item), key=lambda f: int(f.rsplit('_', 1)[-1].split('.')[0])):
img = Image.open(outdir+item+'/'+filename)
if filename.startswith('result_'):
results.append(img)
elif filename.startswith('base_'):
results_base.append(img)
elif filename.startswith('frame_'):
gt_frames.append(img)
#results = results[:max_frame_count]
#results_base = results_base[:max_frame_count]
#gt_frames = gt_frames[:max_frame_count]
else:
if not isinstance(image_paths[0], str):
images = image_paths
else:
for path in image_paths:
print(path)
img = Image.open(path)
images.append([img])
gt_frames = extract_frames(video_path, fps)
gt_frames = gt_frames[:max_frame_count]
for f in gt_frames:
f.thumbnail((512,512))
#results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
        results, results_base = run_eval(images, video_path, train_steps=train_steps, inference_steps=inference_steps, fps=fps, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=bg_remove, resize_inputs=False)
os.makedirs(outdir+item, exist_ok=True)
for i, frame in enumerate(gt_frames):
frame.save(outdir+item+"/frame_"+str(i)+".png")
for i, result in enumerate(results):
result.save(outdir+item+"/result_"+str(i)+".png")
for i, result in enumerate(results_base):
result.save(outdir+item+"/base_"+str(i)+".png")
ssim = []
psnr = []
lpips = []
fid = []
ssim2 = []
psnr2 = []
lpips2 = []
fid2 = []
c = 0
#print(len(gt_frames), len(results), len(results_base))
for gt, result, base in zip(gt_frames, results, results_base):
ssim.append(float(compute_ssim(gt, result)))
psnr.append(float(compute_psnr(gt, result)))
lpips.append(float(compute_lpips(gt, result)))
ssim2.append(float(compute_ssim(gt, base)))
psnr2.append(float(compute_psnr(gt, base)))
lpips2.append(float(compute_lpips(gt, base)))
        if c < 50:  # limit per-frame FID to the first 50 frames to keep runtime manageable
            print(c)
fid.append(float(compute_fid(gt, result)))
fid2.append(float(compute_fid(gt, base)))
c = c+1
#fvd = float(compute_fvd(item, gt_frames, results))
#fvd2 = float(compute_fvd(item, gt_frames, results_base))
print("SSIM:", sum(ssim)/len(ssim))
print("PSNR:", sum(psnr)/len(psnr))
print("LPIPS:", sum(lpips)/len(lpips))
print("FID:", sum(fid)/len(fid))
#print("FVD:", fvd)
print('baseline:')
print("SSIM:", sum(ssim2)/len(ssim2))
print("PSNR:", sum(psnr2)/len(psnr2))
print("LPIPS:", sum(lpips2)/len(lpips2))
print("FID:", sum(fid2)/len(fid2))
#print("FVD:", fvd2)
metrics = {}
metrics[item] = {'ft': {}, 'base': {}, 'n_frames': len(gt_frames), 'complexity': len(images)}
metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
metrics[item]['ft']['fid'] = {'avg': sum(fid)/len(fid), 'vals': fid}
#metrics[item]['ft']['fvd'] = fvd
metrics[item]['base']['ssim'] = {'avg': sum(ssim2)/len(ssim2), 'vals': ssim2}
metrics[item]['base']['psnr'] = {'avg': sum(psnr2)/len(psnr2), 'vals': psnr2}
metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
#metrics[item]['base']['fvd'] = fvd2
#print(metrics)
return metrics[item]
def get_files(directory_path):
"""
Returns a list of all files in the specified directory.
"""
files = []
for entry in os.listdir(directory_path):
full_path = os.path.join(directory_path, entry)
if os.path.isfile(full_path):
files.append(entry)
return files
def run_evaluate():
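    """
    Downloads the evaluation dataset, scores each test item not already in metrics.json,
    and prints per-item and averaged metrics for the fine-tuned and baseline models.
    """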
print("run_evaluate")
snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
    # Resume from previously saved metrics so already-scored items are skipped; start fresh if none exist
    metrics = {}
    if os.path.isfile('/data/metrics.json'):
        with open('/data/metrics.json', 'r') as file:
            metrics = json.load(file)
items = os.listdir('test')
    items = [it for it in items if not it.startswith('.') and not os.path.isfile('test/'+it)]  # keep only non-hidden item subdirectories
print(items)
#items = ['sidewalk'] #['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
for item in items:
if item in metrics:
continue
print(item)
try:
files = get_files('test/'+item)
images = list(filter(lambda x: not x.endswith('.mp4'), files))
images = ['test/'+item+'/'+img for img in images]
videos = [x for x in files if x.endswith('.mp4')]
print(images, videos)
if len(videos) == 1:
metrics[item] = get_score(item, images, 'test/'+item+'/'+videos[0])
#get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
with open('/data/metrics.json', "w", encoding="utf-8") as json_file:
json.dump(metrics, json_file, ensure_ascii=False, indent=4)
            else:
                print('Error: expected exactly one mp4 for', item, 'but found', len(videos))
except Exception as e:
print("Error", item, e)
ssim = []
psnr = []
lpips = []
fid = []
ssim2 = []
psnr2 = []
lpips2 = []
fid2 = []
for item in metrics.keys():
ssim.append(metrics[item]['ft']['ssim']['avg'])
psnr.append(metrics[item]['ft']['psnr']['avg'])
lpips.append(metrics[item]['ft']['lpips']['avg'])
fid.append(metrics[item]['ft']['fid']['avg'])
ssim2.append(metrics[item]['base']['ssim']['avg'])
psnr2.append(metrics[item]['base']['psnr']['avg'])
lpips2.append(metrics[item]['base']['lpips']['avg'])
fid2.append(metrics[item]['base']['fid']['avg'])
print(item)
print("SSIM:", metrics[item]['ft']['ssim']['avg'], metrics[item]['base']['ssim']['avg'])
print("PSNR:", metrics[item]['ft']['psnr']['avg'], metrics[item]['base']['psnr']['avg'])
print("LPIPS:", metrics[item]['ft']['lpips']['avg'], metrics[item]['base']['lpips']['avg'])
print("FID:", metrics[item]['ft']['fid']['avg'], metrics[item]['base']['fid']['avg'])
print('Results:')
print("SSIM:", sum(ssim)/len(ssim))
print("PSNR:", sum(psnr)/len(psnr))
print("LPIPS:", sum(lpips)/len(lpips))
print("FID:", sum(fid)/len(fid))
print('baseline:')
print("SSIM:", sum(ssim2)/len(ssim2))
print("PSNR:", sum(psnr2)/len(psnr2))
print("LPIPS:", sum(lpips2)/len(lpips2))
print("FID:", sum(fid2)/len(fid2))