Spaces:
Running
on
Zero
Running
on
Zero
from typing import Optional, Tuple | |
import torch | |
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX | |
class TruncAdaLayerNorm(AdaLayerNorm):
    """`AdaLayerNorm` subclass that trims conditioning inputs to the batch size of `x`.

    NOTE(review): `forward_old` is not defined in this part of the file; it is
    presumably the original `AdaLayerNorm.forward`, stored under that name by
    patching code elsewhere -- confirm before relying on it.
    """

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Call `forward_old` with `timestep` and `temb` cut to the first `x.shape[0]` entries.

        Args:
            x: Input tensor; its leading dimension is taken as the batch size.
            timestep: Optional timestep tensor, forwarded under its own keyword.
            temb: Optional conditioning embedding, forwarded under its own keyword.

        Returns:
            Whatever `forward_old` returns for the trimmed inputs.
        """
        batch_size = x.shape[0]

        # A 0-dim (scalar) tensor cannot be sliced, so only trim real sequences.
        if torch.is_tensor(timestep) and timestep.dim() > 0:
            timestep = timestep[:batch_size]
        if temb is not None:
            temb = temb[:batch_size]

        # Forward by keyword: the second positional slot of this signature is
        # `timestep`, so a positional `temb` would land in the wrong parameter,
        # and the caller's `timestep` must not be silently discarded.
        return self.forward_old(x, timestep=timestep, temb=temb)
class TruncAdaLayerNormContinuous(AdaLayerNormContinuous):
    """`AdaLayerNormContinuous` subclass that trims the conditioning embedding to `x`'s batch size.

    NOTE(review): `forward_old` is not defined in this part of the file; it is
    presumably the original forward saved by patching code elsewhere -- confirm.
    """

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """Delegate to `forward_old` using only the first `x.shape[0]` embedding entries.

        Args:
            x: Input tensor; its leading dimension is taken as the batch size.
            conditioning_embedding: Conditioning tensor, sliced along its first dimension.

        Returns:
            Whatever `forward_old` returns for `x` and the sliced embedding.
        """
        num_rows = x.shape[0]
        trimmed_embedding = conditioning_embedding[:num_rows]
        return self.forward_old(x, trimmed_embedding)
class TruncAdaLayerNormZero(AdaLayerNormZero):
    """`AdaLayerNormZero` subclass that trims per-sample inputs to `x`'s batch size.

    NOTE(review): `forward_old` is not defined in this part of the file; it is
    presumably the original forward saved by patching code elsewhere -- confirm.
    """

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Delegate to `forward_old` after slicing every provided per-sample input.

        Args:
            x: Input tensor; its leading dimension is taken as the batch size.
            timestep: Optional tensor, sliced to the batch size when given.
            class_labels: Optional tensor, sliced to the batch size when given.
            hidden_dtype: Passed through to `forward_old` unchanged.
            emb: Optional tensor, sliced to the batch size when given.

        Returns:
            Whatever `forward_old` returns for the trimmed inputs.
        """
        rows = x.shape[0]

        # Inputs left as None are forwarded as None; the rest keep their first `rows` entries.
        if timestep is not None:
            timestep = timestep[:rows]
        if class_labels is not None:
            class_labels = class_labels[:rows]
        if emb is not None:
            emb = emb[:rows]

        return self.forward_old(x, timestep, class_labels, hidden_dtype, emb)
class TruncSD35AdaLayerNormZeroX(SD35AdaLayerNormZeroX):
    """`SD35AdaLayerNormZeroX` subclass that trims `emb` to the batch size of `hidden_states`.

    NOTE(review): `forward_old` is not defined in this part of the file; it is
    presumably the original forward saved by patching code elsewhere -- confirm.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Delegate to `forward_old` using only the first `hidden_states.shape[0]` entries of `emb`.

        Args:
            hidden_states: Input tensor; its leading dimension is taken as the batch size.
            emb: Optional tensor, sliced along its first dimension when given.

        Returns:
            Whatever `forward_old` returns for `hidden_states` and the trimmed `emb`.
        """
        leading = hidden_states.shape[0]
        trimmed_emb = None if emb is None else emb[:leading]
        return self.forward_old(hidden_states, trimmed_emb)