from dataclasses import dataclass
from typing import Tuple, Union

import torch

from .outputs import BaseOutput
from model.diffusers.models.unet_2d_condition import UNet2DConditionModel


# Module-level store for intermediate U-Net feature maps, keyed by spatial
# resolution: "low" (8x8), "mid" (16x16), "high" (32x32), "highest" (64x64).
all_feature_dic = {}

def clear_feature_dic():
    """Reset the global feature store to empty per-resolution lists."""
    global all_feature_dic
    all_feature_dic = {}
    all_feature_dic["low"] = []
    all_feature_dic["mid"] = []
    all_feature_dic["high"] = []
    all_feature_dic["highest"] = []


def get_feature_dic():
    """Return the global feature store populated by UNet2D.forward."""
    global all_feature_dic
    return all_feature_dic
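

# The sketch below (not part of the original module) illustrates the intended
# calling pattern for the feature store: clear it, run a UNet2D forward pass at
# a low timestep, then read the grouped features back.  The argument names
# (`unet`, `noisy_latents`, `text_embeddings`) are hypothetical caller-side
# objects, not definitions from this repository.
def example_collect_features(unet, noisy_latents, text_embeddings):
    clear_feature_dic()
    # Features are only recorded inside UNet2D.forward when the timestep is 0 or 1.
    timestep = torch.tensor([0], device=noisy_latents.device)
    # The batch size must be even: forward() pairs consecutive samples channel-wise.
    unet(noisy_latents, timestep, encoder_hidden_states=text_embeddings)
    # Returns {"low": [...], "mid": [...], "high": [...], "highest": [...]}.
    return get_feature_dic()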
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of the last layer of the model.
    """

    sample: torch.FloatTensor

class UNet2D(UNet2DConditionModel):

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
"""r |
|
Args: |
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor |
|
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps |
|
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. |
|
|
|
Returns: |
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: |
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When |
|
returning a tuple, the first element is the sample tensor. |
|
""" |
|
|
|
        # Center the input if the model was trained on inputs in [-1, 1].
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # Normalize the timestep(s) into a 1-D tensor on the right device.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps.to(dtype=torch.float32)
            timesteps = timesteps[None].to(device=sample.device)

        # Intermediate features are only recorded at timesteps 0 or 1,
        # i.e. at the lowest noise levels.
        if timesteps[0] == 0 or timesteps[0] == 1:
            flag_time = True
        else:
            flag_time = False

        # Broadcast the timestep to the full batch and compute the time embedding.
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)

        # Initial convolution.
        sample = self.conv_in(sample)

        if flag_time:
            # Record the conv_in features: pair consecutive batch elements
            # channel-wise, (2B, C, H, W) -> (B, 2C, H, W), then bucket the map
            # by its spatial resolution.
            reshape_h = sample.reshape(sample.size(0) // 2, sample.size(1) * 2, sample.size(2), sample.size(3))
            if reshape_h.size(2) == 8:
                all_feature_dic["low"].append(reshape_h)
            elif reshape_h.size(2) == 16:
                all_feature_dic["mid"].append(reshape_h)
            elif reshape_h.size(2) == 32:
                all_feature_dic["high"].append(reshape_h)
            elif reshape_h.size(2) == 64:
                all_feature_dic["highest"].append(reshape_h)

        # Down blocks.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            if flag_time:
                # Record each residual feature map, paired channel-wise and
                # bucketed by resolution as above.
                for h in res_samples:
                    reshape_h = h.reshape(h.size(0) // 2, h.size(1) * 2, h.size(2), h.size(3))
                    if reshape_h.size(2) == 8:
                        all_feature_dic["low"].append(reshape_h)
                    elif reshape_h.size(2) == 16:
                        all_feature_dic["mid"].append(reshape_h)
                    elif reshape_h.size(2) == 32:
                        all_feature_dic["high"].append(reshape_h)
                    elif reshape_h.size(2) == 64:
                        all_feature_dic["highest"].append(reshape_h)

            down_block_res_samples += res_samples

        # Mid block.
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
        if flag_time:
            reshape_h = sample.reshape(sample.size(0) // 2, sample.size(1) * 2, sample.size(2), sample.size(3))
            if reshape_h.size(2) == 8:
                all_feature_dic["low"].append(reshape_h)
            elif reshape_h.size(2) == 16:
                all_feature_dic["mid"].append(reshape_h)
            elif reshape_h.size(2) == 32:
                all_feature_dic["high"].append(reshape_h)
            elif reshape_h.size(2) == 64:
                all_feature_dic["highest"].append(reshape_h)

        # Up blocks.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample, up_samples = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, up_samples = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples
                )

            if flag_time:
                for h in up_samples:
                    reshape_h = h.reshape(h.size(0) // 2, h.size(1) * 2, h.size(2), h.size(3))
                    if reshape_h.size(2) == 8:
                        all_feature_dic["low"].append(reshape_h)
                    elif reshape_h.size(2) == 16:
                        all_feature_dic["mid"].append(reshape_h)
                    elif reshape_h.size(2) == 32:
                        all_feature_dic["high"].append(reshape_h)
                    elif reshape_h.size(2) == 64:
                        all_feature_dic["highest"].append(reshape_h)

        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
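

# A hedged end-to-end sketch (not part of the original module): build the
# modified U-Net from a diffusers-style checkpoint and harvest multi-scale
# features at timestep 0.  `pretrained_model_path` is a hypothetical argument,
# and `from_pretrained` is assumed to be inherited unchanged from the vendored
# diffusers `UNet2DConditionModel` / `ModelMixin` base.
def example_unet2d_feature_extraction(pretrained_model_path, noisy_latents, text_embeddings):
    unet = UNet2D.from_pretrained(pretrained_model_path, subfolder="unet")
    unet.to(noisy_latents.device).eval()
    clear_feature_dic()
    with torch.no_grad():
        out = unet(
            noisy_latents,  # latent batch; must have an even batch size
            torch.tensor([0], device=noisy_latents.device),
            encoder_hidden_states=text_embeddings,
        )
    return out.sample, get_feature_dic()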