File size: 1,592 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List, Optional

import torch
import torch.nn as nn

class BaseConditioner(nn.Module):
    def __init__(self):
        super(BaseConditioner, self).__init__()

    def _impl_condition(self, y, metadata)->torch.Tensor:
        raise NotImplementedError()

    def _impl_uncondition(self, y, metadata)->torch.Tensor:
        raise NotImplementedError()

    @torch.no_grad()
    @torch.autocast("cuda", dtype=torch.bfloat16)
    def __call__(self, y, metadata:dict={}):
        condition = self._impl_condition(y, metadata)
        uncondition = self._impl_uncondition(y, metadata)
        if condition.dtype in [torch.float64, torch.float32, torch.float16]:
            condition = condition.to(torch.bfloat16)
        if uncondition.dtype in [torch.float64,torch.float32, torch.float16]:
            uncondition = uncondition.to(torch.bfloat16)
        return condition, uncondition


class ComposeConditioner(BaseConditioner):
    """Concatenate the outputs of several conditioners along dim 1.

    Each child's ``_impl_condition`` / ``_impl_uncondition`` is called
    directly, bypassing the child's own ``__call__``. The bfloat16 down-cast
    therefore happens once, on the concatenated result, in
    ``BaseConditioner.__call__``.

    Args:
        conditioners: Conditioners whose outputs are concatenated, in list
            order. Their outputs must be concatenable along dim 1 by
            ``torch.cat``; an empty list makes ``torch.cat`` raise.
    """

    def __init__(self, conditioners: List[BaseConditioner]):
        super().__init__()
        # NOTE(review): stored as a plain list, so the children are not
        # registered as submodules (not moved by ``.to()``, not included in
        # ``parameters()`` / ``state_dict()``). ``nn.ModuleList`` would register
        # them but also change checkpoint keys — confirm intent before changing.
        self.conditioners = conditioners

    def _impl_condition(self, y, metadata):
        """Concatenate every child's conditional output along dim 1."""
        return torch.cat(
            [conditioner._impl_condition(y, metadata) for conditioner in self.conditioners],
            dim=1,
        )

    def _impl_uncondition(self, y, metadata):
        """Concatenate every child's unconditional output along dim 1."""
        return torch.cat(
            [conditioner._impl_uncondition(y, metadata) for conditioner in self.conditioners],
            dim=1,
        )