Spaces:
Running
on
Zero
Running
on
Zero
from typing import Union, List | |
import torch | |
import torch.nn as nn | |
from typing import Callable | |
from src.diffusion.base.scheduling import BaseScheduler | |
class BaseSampler(nn.Module):
    """Common base class for samplers.

    Stores the sampling configuration and provides a ``forward`` that
    delegates the actual sampling to ``_impl_sampling``, which subclasses
    must override.

    Args:
        scheduler: Scheduler object, stored unchanged on ``self.scheduler``.
        guidance_fn: Callable stored unchanged on ``self.guidance_fn``.
            NOTE(review): presumably combines conditional and unconditional
            network outputs; confirm against the concrete samplers.
        num_steps: Number of sampling steps, stored on ``self.num_steps``.
        guidance: A single guidance value or a list of values, stored on
            ``self.guidance``.
        *args: Accepted for signature compatibility; not used here.
        **kwargs: Accepted for signature compatibility; not used here.
    """

    def __init__(
        self,
        scheduler: Union[BaseScheduler, None] = None,
        guidance_fn: Union[Callable, None] = None,
        num_steps: int = 250,
        guidance: Union[float, List[float]] = 1.0,
        *args,
        **kwargs
    ):
        super().__init__()
        self.num_steps = num_steps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.scheduler = scheduler

    def _impl_sampling(self, net, noise, condition, uncondition):
        """Run the sampling procedure; must be overridden by subclasses.

        ``forward`` unpacks the result as ``(x_trajs, v_trajs)`` and indexes
        ``x_trajs[-1]``, so an implementation has to return a pair whose
        first element supports negative indexing.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self, net, noise, condition, uncondition, return_x_trajs=False, return_v_trajs=False):
        """Sample via ``_impl_sampling`` and return the last ``x_trajs`` entry.

        Args:
            net: Forwarded unchanged to ``_impl_sampling``.
            noise: Forwarded unchanged to ``_impl_sampling``.
            condition: Forwarded unchanged to ``_impl_sampling``.
            uncondition: Forwarded unchanged to ``_impl_sampling``.
            return_x_trajs: If truthy, also return the full ``x_trajs``.
            return_v_trajs: If truthy, also return the full ``v_trajs``.

        Returns:
            ``x_trajs[-1]`` alone when neither flag is set. Otherwise a tuple
            that starts with ``x_trajs[-1]``, followed by ``x_trajs`` (if
            requested) and then ``v_trajs`` (if requested).
        """
        x_trajs, v_trajs = self._impl_sampling(net, noise, condition, uncondition)
        final_x = x_trajs[-1]

        # The common case returns the bare final sample, not a 1-tuple.
        if not (return_x_trajs or return_v_trajs):
            return final_x

        outputs = [final_x]
        if return_x_trajs:
            outputs.append(x_trajs)
        if return_v_trajs:
            outputs.append(v_trajs)
        return tuple(outputs)