ReFlex / src /callback /callback_fn.py
SahilCarterr's picture
Upload 77 files
f056744 verified
import torch
import torch.nn.functional as F
from diffusers.callbacks import PipelineCallback
from scipy.ndimage import binary_dilation
from skimage.filters import threshold_otsu
from ..attn_utils.mask_utils import get_mask
class CallbackLatentStore(PipelineCallback):
    """Step-end callback that records the latents handed to it at every call.

    Each invocation appends ``callback_kwargs['latents']`` to ``self.latents``
    in call order. The object is stored as received (no clone or detach is
    performed here).
    """

    # Names of the pipeline tensors this callback asks to receive.
    tensor_inputs = ['latents']

    def __init__(self):
        # Grows by one entry per callback invocation.
        self.latents = list()

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        """Store the current latents and hand ``callback_kwargs`` back unchanged.

        Args:
            pipeline: The calling pipeline (unused).
            step_index: Index of the current step (unused).
            timestep: Current timestep (unused).
            callback_kwargs: Mapping that must contain the key ``'latents'``.

        Returns:
            The same ``callback_kwargs`` object that was passed in.
        """
        step_latents = callback_kwargs['latents']
        self.latents.append(step_latents)
        return callback_kwargs
class CallbackAll(PipelineCallback):
    """Step-end callback that injects stored latents and mask-blends the rest.

    On every call it can do two things to ``callback_kwargs['latents']``,
    both in place:

    1. While ``cur_step < mid_step_index`` (and ``latents`` was provided),
       overwrite batch entry ``[:1]`` with ``latents[mid_step_index]``.
    2. When ``use_mask`` is set and ``cur_step < mask_steps``, replace batch
       entries ``[1:]`` with ``stored * (1 - mask) + current * mask`` where
       ``stored`` is ``latents[cur_step + 1]``.

    ``cur_step`` is ``step_index + step_start``.

    NOTE(review): presumably batch entry 0 is the source/reconstruction branch
    and entries ``1:`` are the edited targets -- confirm against the caller.
    """

    # Names of the pipeline tensors this callback asks to receive.
    tensor_inputs = ['latents']

    def __init__(
        self,
        latents,
        attn_collector,
        feature_collector,
        feature_inject_steps,
        mid_step_index=0,
        step_start=0,
        use_mask=False,
        use_ca_mask=False,
        source_ca_index=None,
        target_ca_index=None,
        mask_steps=18,
        mask_kwargs=None,
        mask=None,
        height=1024,
        width=1024,
    ):
        """Configure the callback.

        Args:
            latents: Indexable collection of stored latents, or ``None`` to
                disable the injection in step 1. It is indexed with
                ``mid_step_index`` and ``cur_step + 1``, so it must be
                provided whenever mask blending is active.
            attn_collector: Object exposing ``controller.source_ca`` and
                ``controller.target_ca``; only read when ``use_ca_mask`` is set.
            feature_collector: Stored on the instance; not used by this class.
            feature_inject_steps: Stored on the instance; not used by this
                class.
            mid_step_index: Index into ``latents`` of the latent injected
                while ``cur_step < mid_step_index``.
            step_start: Offset added to the pipeline's ``step_index``.
            use_mask: Enable mask blending of batch entries ``[1:]``.
            use_ca_mask: Build the mask with ``get_mask`` from the collected
                attention instead of using the fixed ``mask``.
            source_ca_index: Index passed to ``get_mask`` together with
                ``controller.source_ca``; takes precedence over
                ``target_ca_index``.
            target_ca_index: Index passed to ``get_mask`` together with
                ``controller.target_ca``.
            mask_steps: Blending is applied only while
                ``cur_step < mask_steps``.
            mask_kwargs: Extra keyword arguments forwarded to ``get_mask``.
                ``None`` means no extra arguments.
            mask: Fixed mask used when ``use_ca_mask`` is false.
            height: Height forwarded to ``pipeline._unpack_latents``.
            width: Width forwarded to ``pipeline._unpack_latents``.
        """
        self.latents = latents
        self.attn_collector = attn_collector
        self.feature_collector = feature_collector
        self.feature_inject_steps = feature_inject_steps
        self.mid_step_index = mid_step_index
        self.step_start = step_start
        self.mask = mask
        self.mask_steps = mask_steps
        self.use_mask = use_mask
        self.use_ca_mask = use_ca_mask
        self.source_ca_index = source_ca_index
        self.target_ca_index = target_ca_index
        # A fresh dict per instance; a ``{}`` default in the signature would
        # be shared by every instance created without this argument.
        self.mask_kwargs = {} if mask_kwargs is None else mask_kwargs
        self.height = height
        self.width = width

    def latent_blend(self, s, t, mask):
        """Return ``s`` where ``mask`` is 0 and ``t`` where ``mask`` is 1.

        Computed as ``s * (1 - mask) + t * mask``, so intermediate mask
        values interpolate linearly between the two inputs.
        """
        return s * (1 - mask) + t * mask

    def _resolve_mask(self, cur_step):
        """Return the mask to blend with at ``cur_step``, or ``None`` to skip.

        A mask computed from attention is also cached on ``self.mask``.

        Raises:
            ValueError: If ``use_ca_mask`` is set but neither
                ``source_ca_index`` nor ``target_ca_index`` was given.
        """
        if not self.use_ca_mask:
            # Fixed mask supplied at construction time (may be None).
            return self.mask

        if self.source_ca_index is not None:
            source_ca = self.attn_collector.controller.source_ca
            mask = get_mask(source_ca, self.source_ca_index, **self.mask_kwargs)
        elif self.target_ca_index is not None:
            # No target-attention mask is built for the first 5 steps.
            # NOTE(review): presumably the target attention is not yet
            # informative this early -- confirm.
            if cur_step < 5:
                return None
            target_ca = self.attn_collector.controller.target_ca
            mask = get_mask(target_ca, self.target_ca_index, **self.mask_kwargs)
        else:
            raise ValueError(
                "use_ca_mask=True requires source_ca_index or target_ca_index"
            )

        self.mask = mask
        return mask

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        """Apply latent injection and mask blending for the current step.

        Args:
            pipeline: Calling pipeline. When blending it must provide
                ``dtype``, ``vae_scale_factor``, ``_unpack_latents`` and
                ``_pack_latents``.
            step_index: Index of the current step as reported by the pipeline.
            timestep: Current timestep (unused).
            callback_kwargs: Mapping that must contain the key ``'latents'``;
                that tensor is modified in place.

        Returns:
            The same ``callback_kwargs`` object that was passed in.

        Raises:
            ValueError: If ``use_mask`` and ``use_ca_mask`` are set but no
                attention index was configured.
        """
        cur_step = step_index + self.step_start
        latents = callback_kwargs['latents']

        # Step 1: pin batch entry 0 to the stored mid-step latent.
        if self.latents is not None and cur_step < self.mid_step_index:
            latents[:1] = self.latents[self.mid_step_index]

        if not self.use_mask:
            return callback_kwargs

        # The mask is resolved (and cached) even on steps that end up not
        # blending, so ``self.mask`` always reflects the latest attention.
        mask = self._resolve_mask(cur_step)
        if mask is None:
            return callback_kwargs

        # Step 2: blend entries 1: with the stored latent for the next step.
        if cur_step < self.mask_steps:
            mask = mask.to(pipeline.dtype)
            target_latent = latents[1:]
            blend_latent = self.latents[cur_step + 1]
            scale = pipeline.vae_scale_factor
            # Blending happens in the unpacked layout; the result is packed
            # again before being written back.
            new_latent = self.latent_blend(
                pipeline._unpack_latents(blend_latent, self.height, self.width, scale),
                pipeline._unpack_latents(target_latent, self.height, self.width, scale),
                mask,
            )
            latents[1:] = pipeline._pack_latents(new_latent, *new_latent.shape)

        return callback_kwargs