|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import yaml
|
|
|
|
"""
|
|
Script to create the plots in 'Progressive Compression with Universally Quantized Diffusion Models', Yang et al., 2025.
|
|
"""
|
|
|
|
|
|
def _select_baselines(baselines, available):
    """Return the baselines to plot, in plotting order.

    Group names ('cdc', 'vae', 'vdm', 'uqdm') are expanded into their
    variants, which are appended after the explicitly listed entries in the
    fixed order cdc, vae, vdm, uqdm. Group names themselves, entries that are
    not keys of ``available``, and repeated entries are dropped.

    A new list is returned; ``baselines`` is never modified.
    """
    groups = {
        'cdc': ['cdc-0', 'cdc-p'],
        'vae': ['vae-b', 'vae-m'],
        'vdm': ['vdm-1000d', 'vdm-d', 'vdm-a', 'vdm-f'],
        'uqdm': ['uqdm-d', 'uqdm-a', 'uqdm-f'],
    }
    # Copy first: `list += ...` extends in place and would otherwise mutate
    # the list object the caller passed in.
    expanded = list(baselines)
    for group, members in groups.items():
        if group in expanded:
            expanded += members
    selected = [b for b in expanded if b not in groups and b in available]
    # dict.fromkeys removes duplicates while keeping first-seen order, so a
    # baseline requested both explicitly and via its group is plotted once.
    return list(dict.fromkeys(selected))


def rd_fid(dataset='imagenet', baselines=None, save=None):
    """
    Create R-D and R-FID curves from precomputed results of our models and baselines,
    evaluated on 10,000 images from the de-duplicated evaluation set.

    Results are read from '<dataset>.yml' in the current working directory. Each
    top-level key is a baseline name whose 'bpp', 'psnr' and 'fid' entries are
    plotted (PSNR vs. rate, and FID vs. rate on a symlog axis).

    Inputs:
    -------
    dataset: 'cifar' or 'imagenet'
    baselines: (optional) list of baselines from
        'uqdm' or 'uqdm-d', 'uqdm-a', 'uqdm-f' - our model via (d)enoising, (a)ncestral, or (f)low-based sampling
        'vdm' or 'vdm-d', 'vdm-a', 'vdm-f', 'vdm-1000d' - theoretical results of Gaussian diffusion
        'jpeg', 'jpeg2000', 'bpg' - traditional (non-learned) codecs
        'ctc' - progressive neural codec via hierarchically quantized latent space (Jeon et al., 2023)
        'cdc' or 'cdc-0', 'cdc-p' - non-progressive neural codec with conditional diffusion model (Yang et al., 2023)
        'vae' or 'vae-b', 'vae-m' - non-progressive neural codec with VAE (Ballé et al., 2018) or (Minnen et al., 2020)
        Baselines without results in the YAML file are skipped. The list passed in is not modified.
    save: (optional) filename prefix for the saved plots; the figures are written to
        '<save>_rd.png' and '<save>_fid.png'. Defaults to 'tmp_<dataset>' so that runs
        for different datasets do not overwrite each other.
    """

    with open('%s.yml' % dataset, 'r', encoding='utf-8') as f:
        results = yaml.safe_load(f)
    if baselines is None:
        baselines = ['jpeg', 'jpeg2000', 'bpg', 'ctc', 'cdc', 'vdm', 'uqdm']
    selected = _select_baselines(baselines, results)

    pl_kwargs = {'alpha': 0.8, 'lw': 2}
    # 'ls' holds a matplotlib format string (line style + marker), which is
    # passed positionally to Axes.plot rather than as the `linestyle` kwarg.
    pl_styles = {
        'uqdm-d': dict(ls='-+', color='darkorange', label='UQDM T=4, denoise'),
        'uqdm-a': dict(ls='--x', color='darkorange', label='UQDM T=4, ancestral'),
        'uqdm-f': dict(ls=':x', color='darkorange', label='UQDM T=4, flow-based'),
        'vdm-d': dict(ls='-+', color='blue', label='VDM T=20, denoise', alpha=0.6, lw=1.5),
        'vdm-a': dict(ls='--x', color='blue', label='VDM T=20, ancestral', alpha=0.6, lw=1.5),
        'vdm-f': dict(ls=':x', color='blue', label='VDM T=20, flow-based', alpha=0.6, lw=1.5),
        'vdm-1000d': dict(ls=':+', color='darkturquoise', label='VDM T=1000, denoise', alpha=0.6, lw=1.5),
        'jpeg': dict(ls='-.+', color='red', label='JPEG'),
        'jpeg2000': dict(ls='-x', color='red', label='JPEG2000'),
        'bpg': dict(ls='-x', color='sienna', label='BPG'),
        'ctc': dict(ls='-x', color='fuchsia', label='CTC'),
        'cdc-0': dict(ls='-x', color='green', label='CDC (p=0)'),
        'cdc-p': dict(ls='-.x', color='green', label='CDC (p=0.9)'),
        'vae-b': dict(ls='--x', color='limegreen', label='VAE (Ballé 2018)'),
        'vae-m': dict(ls='-+', color='limegreen', label='VAE (Minnen 2020)'),
    }
    sns.set_style('whitegrid')

    # NOTE(review): 5.5206 looks like a LaTeX \textwidth in inches, scaled by
    # 2.5 to enlarge the figure; confirm against the paper template.
    textwidth = 5.5206 * 2.5
    fig_rd, ax_rd = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
    fig_fid, ax_fid = plt.subplots(figsize=(0.45 * textwidth, 0.36 * textwidth))
    for b in selected:
        bpp, psnr, fid = results[b]['bpp'], results[b]['psnr'], results[b]['fid']
        # Unknown result keys still get plotted, with default styling and
        # their key as the legend label, instead of raising KeyError.
        kwargs = pl_kwargs | pl_styles.get(b, {'label': b})
        fmt = kwargs.pop('ls', None)
        fmt_args = () if fmt is None else (fmt,)
        ax_rd.plot(bpp, psnr, *fmt_args, **kwargs)
        ax_fid.plot(bpp, fid, *fmt_args, **kwargs)
    ax_rd.legend(loc='lower right')
    ax_fid.legend(loc='upper right')
    ax_rd.set(xlabel='Rate (bpp)', ylabel='PSNR (dB)')
    ax_fid.set(xlabel='Rate (bpp)', ylabel='FID')
    ax_rd.grid(visible=True)
    ax_fid.grid(visible=True)
    ax_fid.set_yscale('symlog')
    fig_rd.tight_layout()
    fig_fid.tight_layout()

    prefix = 'tmp_%s' % dataset if save is None else save
    fig_rd.savefig('%s_rd.png' % prefix, bbox_inches='tight', pad_inches=0, dpi=600)
    fig_fid.savefig('%s_fid.png' % prefix, bbox_inches='tight', pad_inches=0, dpi=600)
    plt.show()
    # Release the figures; on non-interactive backends plt.show() does not,
    # so repeated calls would otherwise keep accumulating open figures.
    plt.close(fig_rd)
    plt.close(fig_fid)
|
|
|
|
|
|
if __name__ == '__main__':
    # Produce the R-D and R-FID figures for each evaluation dataset in turn.
    for name in ('cifar', 'imagenet'):
        rd_fid(dataset=name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|