Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,592 Bytes
56238f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import List, Optional

import torch
import torch.nn as nn
class BaseConditioner(nn.Module):
def __init__(self):
super(BaseConditioner, self).__init__()
def _impl_condition(self, y, metadata)->torch.Tensor:
raise NotImplementedError()
def _impl_uncondition(self, y, metadata)->torch.Tensor:
raise NotImplementedError()
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.bfloat16)
def __call__(self, y, metadata:dict={}):
condition = self._impl_condition(y, metadata)
uncondition = self._impl_uncondition(y, metadata)
if condition.dtype in [torch.float64, torch.float32, torch.float16]:
condition = condition.to(torch.bfloat16)
if uncondition.dtype in [torch.float64,torch.float32, torch.float16]:
uncondition = uncondition.to(torch.bfloat16)
return condition, uncondition
class ComposeConditioner(BaseConditioner):
    """Combine several conditioners by concatenating their outputs on dim 1.

    Each child's ``_impl_condition`` / ``_impl_uncondition`` hook is called
    directly, in list order. The children's own ``__call__`` (no_grad,
    autocast, bfloat16 cast) is therefore bypassed; those steps are applied
    once, by this instance's inherited ``__call__``, to the concatenated result.
    ``torch.cat`` requires the child outputs to agree on every dimension except
    dim 1, and it fails if ``conditioners`` is empty.
    """

    def __init__(self, conditioners: List[BaseConditioner]):
        super().__init__()
        # Stored as the caller's plain list, not an nn.ModuleList, so the
        # children are not registered as submodules of this module
        # (they are not seen by .parameters(), .to(), or state_dict()).
        self.conditioners = conditioners

    def _impl_condition(self, y, metadata):
        """Concatenate every child's condition tensor along dim 1."""
        pieces = [child._impl_condition(y, metadata) for child in self.conditioners]
        return torch.cat(pieces, dim=1)

    def _impl_uncondition(self, y, metadata):
        """Concatenate every child's uncondition tensor along dim 1."""
        pieces = [child._impl_uncondition(y, metadata) for child in self.conditioners]
        return torch.cat(pieces, dim=1)