|
import torch |
|
import numpy as np |
|
from ....utils.general_utils import dict_foreach |
|
from ....pipelines import samplers |
|
|
|
|
|
class ClassifierFreeGuidanceMixin:
    """
    Mixin adding classifier-free guidance (CFG) support to a diffusion model.

    During training, ``get_cond`` independently replaces each batch element's
    conditioning with the negative (unconditional) conditioning with
    probability ``p_uncond``, so the model learns both the conditional and the
    unconditional distribution. At inference time, ``get_sampler`` returns a
    CFG-aware sampler that combines the two.
    """

    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-sample probability of dropping the conditioning during training.
        self.p_uncond = p_uncond

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for training.

        Args:
            cond: The conditioning — a batch-first tensor, a list, or a dict
                whose values are tensors/lists sharing one batch size.
            neg_cond: Negative (unconditional) conditioning with the same
                structure as ``cond``.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            ``cond`` with each batch element independently replaced by the
            corresponding ``neg_cond`` element with probability ``p_uncond``.

        Raises:
            AssertionError: If ``neg_cond`` is not provided.
            ValueError: If a conditioning value is neither a tensor nor a list.
        """
        # NOTE: assert is stripped under `python -O`; kept (rather than an
        # explicit raise) to preserve the exception type callers may rely on.
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"

        if self.p_uncond > 0:

            def get_batch_size(c):
                # Leading dimension for tensors, length for lists.
                if isinstance(c, torch.Tensor):
                    return c.shape[0]
                elif isinstance(c, list):
                    return len(c)
                else:
                    raise ValueError(f"Unsupported type of cond: {type(c)}")

            # For dicts, any single value serves as the batch-size reference.
            ref_cond = cond if not isinstance(cond, dict) else next(iter(cond.values()))
            B = get_batch_size(ref_cond)

            def select(c, nc, mask):
                # Per batch element: pick nc where mask is True, c otherwise.
                if isinstance(c, torch.Tensor):
                    mask = torch.tensor(mask, device=c.device).reshape(-1, *[1] * (c.ndim - 1))
                    return torch.where(mask, nc, c)
                elif isinstance(c, list):
                    return [neg if drop else pos for pos, neg, drop in zip(c, nc, mask)]
                else:
                    raise ValueError(f"Unsupported type of cond: {type(c)}")

            # .tolist() yields native Python bools rather than np.bool_ scalars.
            mask = (np.random.rand(B) < self.p_uncond).tolist()
            if not isinstance(cond, dict):
                cond = select(cond, neg_cond, mask)
            else:
                # Apply the same drop mask to every entry of the dict so all
                # modalities of one sample are dropped together.
                cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))

        return cond

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.

        Returns:
            A dict containing ``cond``, ``neg_cond`` and any extra keyword
            arguments, as expected by the CFG sampler.

        Raises:
            AssertionError: If ``neg_cond`` is not provided.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
        return {'cond': cond, 'neg_cond': neg_cond, **kwargs}

    def get_sampler(self, **kwargs) -> "samplers.FlowEulerCfgSampler":
        """
        Get the sampler for the diffusion process.

        Returns:
            A CFG-aware flow-Euler sampler built from ``self.sigma_min``
            (expected to be provided by the host class — confirm at call site).
        """
        return samplers.FlowEulerCfgSampler(self.sigma_min)
|
|