"""Evaluate fine-tuned vs. baseline frame generation against ground-truth video.

For every item folder in the downloaded ``test`` dataset this script
generates (or reloads cached) output frames, scores them against the
ground-truth frames with SSIM, PSNR, LPIPS and FID, and accumulates the
per-item results in ``METRICS_PATH``.
"""
# NOTE: `main` is kept as the first import, as before, in case it configures
# the environment before the libraries below are imported.
from main import extract_frames, run_eval
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch
import torchvision.transforms as transforms
import lpips
from pytorch_fid.fid_score import calculate_fid_given_paths
from cdfvd import fvd
import os
import json
import traceback
import cv2
from huggingface_hub import snapshot_download

# Root folder holding cached per-item frames (frame_*, result_*, base_*).
OUT_ROOT = '/data/out'
# JSON file accumulating per-item metrics across runs.
METRICS_PATH = '/data/metrics.json'
# Maximum number of ground-truth frames extracted from a video.
MAX_FRAME_COUNT = 200
# Per-frame FID is slow, so only the first N frame pairs are scored with it.
MAX_FID_FRAMES = 50
# Metric names, in the order they appear in the returned/saved dicts.
METRIC_NAMES = ('ssim', 'psnr', 'lpips', 'fid')
# Filename prefixes used in the per-item cache directory.
_CACHE_PREFIXES = ('frame_', 'result_', 'base_')


def pil_to_np(img):
    """Convert a PIL image to a float32 array scaled by 1/255 (8-bit -> [0, 1])."""
    return np.array(img).astype(np.float32) / 255.0


def save_mp4(images, name, fps=12):
    """Write a sequence of PIL images to an MP4 file using the ``mp4v`` codec.

    The output frame size is taken from the first image. Every frame is
    converted to RGB first so greyscale/RGBA inputs survive the RGB->BGR
    conversion OpenCV needs.

    Args:
        images: Non-empty sequence of PIL images.
        name: Output file path.
        fps: Frame rate of the written video.

    Raises:
        ValueError: if ``images`` is empty.
        IOError: if OpenCV cannot open the output file for writing.
    """
    if not images:
        raise ValueError("save_mp4 needs at least one frame")
    width, height = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
    video = cv2.VideoWriter(name, fourcc, fps, (width, height))
    if not video.isOpened():
        raise IOError(f"could not open video writer for {name!r}")
    try:
        for image in images:
            frame = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
            video.write(frame)
    finally:
        video.release()


def compute_ssim(img1, img2):
    """Return the SSIM between two PIL images, computed on [0, 1] floats.

    Assumes both images have the same size and a trailing channel axis
    (e.g. RGB) — TODO confirm for greyscale inputs. The window size is capped
    at 7 and reduced to an odd value not larger than the smaller image side,
    because skimage requires an odd window that fits inside the image.
    """
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)
    h, w = img1_np.shape[:2]
    min_dim = min(h, w)
    win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1)  # ensure odd
    return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)


def compute_psnr(img1, img2):
    """Return the PSNR (dB) between two same-sized PIL images, on [0, 1] floats."""
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)
    return psnr(img1_np, img2_np, data_range=1.0)


# LPIPS model (AlexNet backbone) and its preprocessing, created once at import.
lpips_model = lpips.LPIPS(net='alex')
lpips_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # map [0, 1] -> [-1, 1]
])


def compute_lpips(img1, img2):
    """Return the LPIPS distance between two PIL images (lower is more similar).

    Both images are converted to RGB (the normalisation expects 3 channels)
    and resized to 256x256. Evaluation runs under ``torch.no_grad()`` since
    no gradients are needed.
    """
    with torch.no_grad():
        img1_tensor = lpips_transform(img1.convert('RGB')).unsqueeze(0)
        img2_tensor = lpips_transform(img2.convert('RGB')).unsqueeze(0)
        return lpips_model(img1_tensor, img2_tensor).item()


def trans(x):
    """Reorder a 5-D video tensor from (B, T, C, H, W) to (B, C, T, H, W).

    Single-channel inputs are repeated to 3 channels first.
    """
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)
    return x.permute(0, 2, 1, 3, 4)


def compute_fvd(item, gt_imgs, results):
    """Return the FVD (I3D features, via cdfvd, on CUDA) between two frame sequences.

    Each sequence is written as one MP4 into ``temp/gt`` / ``temp/result``,
    which cdfvd loads as video folders. ``item`` is currently unused.
    """
    os.makedirs('temp/gt', exist_ok=True)
    os.makedirs('temp/result', exist_ok=True)
    save_mp4(gt_imgs, "temp/gt/gt.mp4")
    save_mp4(results, "temp/result/result.mp4")
    evaluator = fvd.cdfvd('i3d', ckpt_path=None, device='cuda', n_real=1, n_fake=1)
    evaluator.compute_real_stats(evaluator.load_videos('temp/gt', data_type='video_folder'))
    evaluator.compute_fake_stats(evaluator.load_videos('temp/result', data_type='video_folder'))
    score = evaluator.compute_fvd_from_stats()
    evaluator.offload_model_to_cpu()
    print(score)
    return score


def _write_frames(images, directory):
    """Save ``images`` as ``0.png, 1.png, ...`` in ``directory``.

    Existing PNG files in the directory are removed first so frames left
    over from an earlier, longer run are not picked up.
    """
    os.makedirs(directory, exist_ok=True)
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if name.endswith('.png') and os.path.isfile(path):
            os.remove(path)
    for i, img in enumerate(images):
        img.save(os.path.join(directory, f'{i}.png'))


def compute_fidx(item, gt_imgs, results):
    """Return the set-level FID between all ground-truth frames and all results.

    Frames go to ``temp/<item>_gt`` and ``temp/<item>`` and are scored with
    pytorch-fid (2048-d Inception pool features, CUDA).
    """
    gt_dir = os.path.join('temp', item + '_gt')
    result_dir = os.path.join('temp', item)
    _write_frames(gt_imgs, gt_dir)
    _write_frames(results, result_dir)
    return calculate_fid_given_paths([gt_dir, result_dir], batch_size=8, device='cuda', dims=2048)


def compute_fid(img1, img2):
    """Return pytorch-fid's FID between two single images (CUDA).

    Each image is written as ``0.png`` into ``temp/img1`` / ``temp/img2``.

    NOTE(review): FID estimates a feature covariance per set; with one image
    per set that estimate is degenerate, so these per-frame values are likely
    not meaningful (possibly NaN). ``compute_fidx`` computes set-level FID —
    consider switching to it.
    """
    os.makedirs('temp/img1', exist_ok=True)
    os.makedirs('temp/img2', exist_ok=True)
    img1.save('temp/img1/0.png')
    img2.save('temp/img2/0.png')
    return calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cuda', dims=2048)


def _load_image(path):
    """Open ``path`` with PIL, load its pixels and close the file handle."""
    with Image.open(path) as img:
        return img.copy()


def _frame_index(filename):
    """Return the integer ``N`` of ``<prefix>_N.<ext>``, or -1 if there is none."""
    suffix = os.path.splitext(filename)[0].rpartition('_')[2]
    return int(suffix) if suffix.isdigit() else -1


def _load_inputs(image_paths):
    """Return the keyframe inputs for ``run_eval``.

    Non-path inputs are passed through unchanged. Paths are opened with PIL
    and each image is wrapped in a one-element list (presumably run_eval
    expects one list of images per keyframe — TODO confirm against
    main.run_eval).
    """
    if not image_paths or not isinstance(image_paths[0], str):
        return image_paths
    images = []
    for path in image_paths:
        print(path)
        images.append([Image.open(path)])
    return images


def _load_cached_outputs(out_dir):
    """Reload cached (gt_frames, results, results_base) images from ``out_dir``.

    Files are sorted by their numeric index because ``os.listdir`` order is
    arbitrary; without sorting, ``frame_i``, ``result_i`` and ``base_i``
    would not be guaranteed to line up. Unrecognised files are ignored.
    """
    buckets = {prefix: [] for prefix in _CACHE_PREFIXES}
    for filename in sorted(os.listdir(out_dir), key=lambda f: (_frame_index(f), f)):
        prefix = next((p for p in _CACHE_PREFIXES if filename.startswith(p)), None)
        if prefix is not None:
            buckets[prefix].append(_load_image(os.path.join(out_dir, filename)))
    return buckets['frame_'], buckets['result_'], buckets['base_']


def _generate_outputs(images, video_path, out_dir, train_steps, inference_steps, fps, bg_remove):
    """Extract ground-truth frames, run the models and cache everything in ``out_dir``.

    Returns ``(gt_frames, results, results_base)``.
    """
    gt_frames = extract_frames(video_path, fps)[:MAX_FRAME_COUNT]
    for frame in gt_frames:
        frame.thumbnail((512, 512))  # in-place downscale, keeps aspect ratio
    results, results_base = run_eval(
        images, video_path,
        train_steps=train_steps, inference_steps=inference_steps, fps=fps,
        modelId="fine_tuned_pcdms", img_width=1920, img_height=1080,
        bg_remove=bg_remove, resize_inputs=False,
    )
    os.makedirs(out_dir, exist_ok=True)
    for prefix, frames in zip(_CACHE_PREFIXES, (gt_frames, results, results_base)):
        for i, frame in enumerate(frames):
            frame.save(os.path.join(out_dir, f"{prefix}{i}.png"))
    return gt_frames, results, results_base


def _score_against(gt_frames, outputs):
    """Return per-frame metric lists comparing ``outputs`` to ``gt_frames`` pairwise.

    SSIM/PSNR/LPIPS are computed for every pair; FID only for the first
    ``MAX_FID_FRAMES`` pairs because it is slow.
    """
    vals = {name: [] for name in METRIC_NAMES}
    for idx, (gt, out) in enumerate(zip(gt_frames, outputs)):
        vals['ssim'].append(float(compute_ssim(gt, out)))
        vals['psnr'].append(float(compute_psnr(gt, out)))
        vals['lpips'].append(float(compute_lpips(gt, out)))
        if idx < MAX_FID_FRAMES:
            print(idx)
            vals['fid'].append(float(compute_fid(gt, out)))
    return vals


def _mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def _print_averages(vals):
    """Print ``NAME: mean`` for every metric in ``vals``."""
    for name in METRIC_NAMES:
        print(f"{name.upper()}:", _mean(vals[name]))


def _summarize(vals):
    """Map each metric to ``{'avg': mean, 'vals': per-frame values}``."""
    return {name: {'avg': _mean(vals[name]), 'vals': vals[name]} for name in METRIC_NAMES}


def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
    """Score fine-tuned ("ft") and baseline outputs for one item against its video.

    If ``OUT_ROOT/<item>`` already exists, its cached ``frame_*``,
    ``result_*`` and ``base_*`` PNGs are reused and nothing is generated
    (``train_steps``, ``inference_steps``, ``fps`` and ``bg_remove`` are then
    unused). Otherwise ground-truth frames are extracted from ``video_path``,
    outputs are produced with ``run_eval`` and everything is cached there.

    Args:
        item: Item name; selects the cache directory.
        image_paths: Keyframe file paths, or already-loaded inputs for run_eval.
        video_path: Ground-truth video.
        train_steps, inference_steps, fps, bg_remove: Forwarded to
            ``run_eval`` (``fps`` is also used by ``extract_frames``).

    Returns:
        ``{'ft': {...}, 'base': {...}, 'n_frames': int, 'complexity': int}``
        where ``ft``/``base`` map each of ssim/psnr/lpips/fid to
        ``{'avg': float, 'vals': [float, ...]}``.

    Raises:
        ValueError: if there is no (gt, result, base) frame triple to compare.
    """
    images = _load_inputs(image_paths)
    out_dir = os.path.join(OUT_ROOT, item)
    if os.path.isdir(out_dir):
        gt_frames, results, results_base = _load_cached_outputs(out_dir)
    else:
        gt_frames, results, results_base = _generate_outputs(
            images, video_path, out_dir, train_steps, inference_steps, fps, bg_remove)

    # Only frames present in all three sequences are scored.
    n_pairs = min(len(gt_frames), len(results), len(results_base))
    if n_pairs == 0:
        raise ValueError(
            f"no frames to compare for {item!r}: {len(gt_frames)} gt, "
            f"{len(results)} results, {len(results_base)} base")
    gt_used = gt_frames[:n_pairs]
    ft_vals = _score_against(gt_used, results[:n_pairs])
    base_vals = _score_against(gt_used, results_base[:n_pairs])

    # FVD (compute_fvd) is available but not currently included in the scores.
    _print_averages(ft_vals)
    print('baseline:')
    _print_averages(base_vals)

    return {
        'ft': _summarize(ft_vals),
        'base': _summarize(base_vals),
        'n_frames': len(gt_frames),
        'complexity': len(images),
    }


def get_files(directory_path):
    """Return the names of the regular files in ``directory_path``, sorted by name."""
    return sorted(
        entry for entry in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, entry))
    )


def _load_metrics(path):
    """Return the saved metrics dict from ``path``, or ``{}`` if it does not exist yet."""
    try:
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        return {}


def _save_metrics(metrics, path):
    """Write ``metrics`` to ``path`` as pretty-printed UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as json_file:
        json.dump(metrics, json_file, ensure_ascii=False, indent=4)


def _print_summary(metrics):
    """Print per-item ft/base averages, then the mean of those averages over items."""
    if not metrics:
        print('No metrics to summarize')
        return
    collected = {variant: {name: [] for name in METRIC_NAMES} for variant in ('ft', 'base')}
    for item, entry in metrics.items():
        print(item)
        for name in METRIC_NAMES:
            ft_avg = entry['ft'][name]['avg']
            base_avg = entry['base'][name]['avg']
            collected['ft'][name].append(ft_avg)
            collected['base'][name].append(base_avg)
            print(f"{name.upper()}:", ft_avg, base_avg)
    print('Results:')
    _print_averages(collected['ft'])
    print('baseline:')
    _print_averages(collected['base'])


def run_evaluate():
    """Download the eval dataset, score every item not yet in the metrics file, print a summary.

    Each item is a sub-folder of ``test`` holding keyframe images and exactly
    one ``.mp4``. Metrics are saved after every item so progress survives a
    crash; items already present in the metrics file are skipped.
    """
    print("run_evaluate")
    snapshot_download(repo_id="acmyu/KeyframesAI-eval", local_dir="test", repo_type="dataset")
    metrics = _load_metrics(METRICS_PATH)

    items = [
        it for it in os.listdir('test')
        if not it.startswith('.') and not os.path.isfile(os.path.join('test', it))
    ]
    print(items)

    for item in items:
        if item in metrics:
            continue
        print(item)
        item_dir = os.path.join('test', item)
        try:
            files = get_files(item_dir)
            images = [os.path.join(item_dir, f) for f in files if not f.endswith('.mp4')]
            videos = [f for f in files if f.endswith('.mp4')]
            print(images, videos)
            if len(videos) == 1:
                metrics[item] = get_score(item, images, os.path.join(item_dir, videos[0]))
                _save_metrics(metrics, METRICS_PATH)
            elif not videos:
                print('Error: mp4 not found')
            else:
                print(f'Error: expected exactly one mp4, found {len(videos)}')
        except Exception as e:
            # Top-level boundary: report the failure and continue with the next item.
            print("Error", item, e)
            traceback.print_exc()

    _print_summary(metrics)