# PixNerd: src/diffusion/base/sampling.py
# Last commit: 56238f0 "init" (wangshuai6)
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn

from src.diffusion.base.scheduling import BaseScheduler
class BaseSampler(nn.Module):
    """Base class for samplers: stores shared settings and shapes ``forward``'s output.

    Subclasses implement ``_impl_sampling``. ``forward`` calls it under CUDA
    bfloat16 autocast and decides which trajectories to hand back to the caller.
    """

    def __init__(
        self,
        scheduler: Optional[BaseScheduler] = None,
        guidance_fn: Optional[Callable] = None,
        num_steps: int = 250,
        guidance: Union[float, List[float]] = 1.0,
        *args,
        **kwargs,
    ):
        """Store the sampler configuration on the module.

        Args:
            scheduler: Stored as ``self.scheduler``; not used by this base class.
            guidance_fn: Stored as ``self.guidance_fn``; not used by this base class.
            num_steps: Stored as ``self.num_steps``; not used by this base class.
            guidance: A float or a list of floats, stored as ``self.guidance``;
                not used by this base class.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): extra positional/keyword arguments are silently dropped,
        # presumably so configs carrying subclass-specific options can still
        # construct this class — confirm this is intended.
        super().__init__()
        self.num_steps = num_steps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.scheduler = scheduler

    def _impl_sampling(self, net, noise, condition, uncondition):
        """Subclass hook that runs the actual sampling loop.

        Must return a pair ``(x_trajs, v_trajs)``. ``forward`` treats
        ``x_trajs[-1]`` as the final sample, so ``x_trajs`` must be a
        non-empty indexable sequence.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @torch.autocast("cuda", dtype=torch.bfloat16)
    def forward(self, net, noise, condition, uncondition, return_x_trajs=False, return_v_trajs=False):
        """Run the sampler and return the final sample, optionally with trajectories.

        The call runs under ``torch.autocast("cuda", dtype=torch.bfloat16)``.

        Args:
            net: Passed unchanged to ``_impl_sampling``.
            noise: Passed unchanged to ``_impl_sampling``.
            condition: Passed unchanged to ``_impl_sampling``.
            uncondition: Passed unchanged to ``_impl_sampling``.
            return_x_trajs: If true, also return the full ``x_trajs`` sequence.
            return_v_trajs: If true, also return the ``v_trajs`` sequence.

        Returns:
            ``x_trajs[-1]`` when neither flag is set. Otherwise a tuple that
            starts with ``x_trajs[-1]``, followed by ``x_trajs`` (if
            requested) and then ``v_trajs`` (if requested).
        """
        x_trajs, v_trajs = self._impl_sampling(net, noise, condition, uncondition)
        final_sample = x_trajs[-1]

        # Order matters: x_trajs always precedes v_trajs when both are requested.
        extras = []
        if return_x_trajs:
            extras.append(x_trajs)
        if return_v_trajs:
            extras.append(v_trajs)

        if not extras:
            return final_sample
        return (final_sample, *extras)