import gc
import json
import os
from collections import OrderedDict
from copy import deepcopy
from math import ceil
from typing import Any, Callable, Dict

import safetensors.torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import yaml
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    DDPMScheduler,
    UNet2DConditionModel,
)


def remove_all_forward_hooks(model: torch.nn.Module) -> None:
    """Recursively clear all forward hooks registered on a module's children."""
    for _name, child in model._modules.items():  # pylint: disable=protected-access
        if child is not None:
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_forward_hooks(child)
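

# Hedged sketch (illustrative only, not part of the original pipeline): hooks
# added via register_forward_hook live in each submodule's _forward_hooks
# dict, which the helper above resets.
def _remove_hooks_demo():
    net = nn.Sequential(nn.Linear(4, 4))
    net[0].register_forward_hook(lambda m, i, o: None)
    remove_all_forward_hooks(net)
    assert len(net[0]._forward_hooks) == 0  # pylint: disable=protected-access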


# Inspired by transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock
class SparseMoeBlock(nn.Module):
    def __init__(self, config, experts):
        super().__init__()
        self.hidden_dim = config["hidden_size"]
        self.num_experts = config["num_local_experts"]
        self.top_k = config["num_experts_per_tok"]
        self.out_dim = config.get("out_dim", self.hidden_dim)

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([deepcopy(exp) for exp in experts])

    def forward(self, hidden_states: torch.Tensor, scale=None) -> torch.Tensor:  # pylint: disable=unused-argument
        batch_size, sequence_length, f_map_sz = hidden_states.shape
        hidden_states = hidden_states.view(-1, f_map_sz)

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        # Unlike Mixtral, experts are selected globally: the logits are summed
        # over all tokens, and the same top-k experts serve the whole batch.
        _, selected_experts = torch.topk(
            router_logits.sum(dim=0, keepdim=True), self.top_k, dim=1
        )
        routing_weights = F.softmax(
            router_logits[:, selected_experts[0]], dim=1, dtype=torch.float
        )
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, self.out_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # Accumulate each selected expert's output, weighted by its per-token
        # routing weight.
        for i, expert_idx in enumerate(selected_experts[0].tolist()):
            expert_layer = self.experts[expert_idx]
            current_hidden_states = routing_weights[:, i].view(
                batch_size * sequence_length, -1
            ) * expert_layer(hidden_states)
            final_hidden_states = final_hidden_states + current_hidden_states
        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, self.out_dim
        )
        return final_hidden_states
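

# A minimal routing sketch (assumed sizes, not from the original source): four
# nn.Linear "experts" with top-2 selection. The two experts with the highest
# summed router logits process every token, and their outputs are blended by
# per-token softmax weights; the output shape matches the input.
def _sparse_moe_block_demo():
    experts = [nn.Linear(64, 64) for _ in range(4)]
    block = SparseMoeBlock(
        {"hidden_size": 64, "num_local_experts": 4, "num_experts_per_tok": 2},
        experts,
    )
    out = block(torch.randn(2, 16, 64))  # (batch=2, seq=16, hidden=64)
    assert out.shape == (2, 16, 64)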


def getActivation(activation, name):
    # Returns a forward hook that records the module's *input* tuple (not its
    # output) under `name`; used below to collect routing features.
    def hook(model, inp, output):  # pylint: disable=unused-argument
        activation[name] = inp

    return hook
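

# Hedged usage sketch for getActivation (illustrative only): register the hook
# on any module, and its forward inputs are captured into the dict under the
# given name each time the module runs.
def _activation_hook_demo():
    captured = {}
    layer = nn.Linear(8, 8)
    handle = layer.register_forward_hook(getActivation(captured, "probe"))
    layer(torch.randn(1, 8))
    handle.remove()
    return captured["probe"]  # tuple of forward inputs, here a single tensor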


class SegMoEPipeline:
    def __init__(self, config_or_path, **kwargs) -> None:
        """
        Instantiates the SegMoEPipeline. SegMoEPipeline implements the Segmind Mixture of Diffusion Experts, efficiently combining Stable Diffusion and Stable Diffusion XL models.

        Usage:
            from segmoe import SegMoEPipeline
            pipeline = SegMoEPipeline(config_or_path, **kwargs)

        config_or_path: Path to a config file, a directory containing a SegMoE checkpoint, or the Hugging Face Hub ID of a SegMoE checkpoint.

        Other Keyword Arguments:
            torch_dtype: Data type to load the pipeline in. (Default: torch.float16)
            variant: Variant of the model. (Default: fp16)
            device: Device to load the model on. (Default: cuda)
        Other arguments supported by diffusers.DiffusionPipeline are also supported.

        For more details visit https://github.com/segmind/segmoe.
        """
        self.torch_dtype = kwargs.pop("torch_dtype", torch.float16)
        self.use_safetensors = kwargs.pop("use_safetensors", True)
        self.variant = kwargs.pop("variant", "fp16")
        self.device = kwargs.pop("device", "cuda")
        if os.path.isfile(config_or_path):
            self.load_from_scratch(config_or_path, **kwargs)
        else:
            if not os.path.isdir(config_or_path):
                cached_folder = DiffusionPipeline.download(config_or_path)
            else:
                cached_folder = config_or_path
            unet = self.create_empty(cached_folder)
            unet.load_state_dict(
                safetensors.torch.load_file(
                    f"{cached_folder}/unet/diffusion_pytorch_model.safetensors"
                )
            )
            self.pipe = DiffusionPipeline.from_pretrained(
                cached_folder,
                unet=unet,
                torch_dtype=self.torch_dtype,
                use_safetensors=self.use_safetensors,
            )
            self.pipe.to(self.device)
            self.pipe.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )

    def to(self, *args, **kwargs):
        # Delegates device/dtype moves to the underlying diffusers pipeline.
        self.pipe.to(*args, **kwargs)
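
    # Hedged sketch of the YAML config consumed by load_from_scratch (field
    # names follow the config.get(...) calls below; the model IDs and prompts
    # are illustrative placeholders, not from the original source):
    #
    #   base_model: SG161222/RealVisXL_V3.0
    #   num_experts_per_tok: 2
    #   moe_layers: all        # "attn", "ff", or anything else for both
    #   experts:
    #     - source_model: stablediffusionapi/juggernaut-xl-v5
    #       positive_prompt: "photorealistic, ultra detailed"
    #       negative_prompt: "blurry, low quality"
    #     - source_model: Lykon/dreamshaper-xl-1-0
    #       positive_prompt: "anime, digital art"
    #       negative_prompt: "photograph, deformed"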

    def load_from_scratch(self, config: str, **kwargs) -> None:
        """Builds the MoE pipeline from a YAML config: loads the base model and
        experts, swaps in SparseMoeBlock layers, and initializes the gates."""
        # Load Config
        with open(config, "r", encoding="utf8") as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
        self.config = config
        if self.config.get("num_experts", None):
            self.num_experts = self.config["num_experts"]
        elif self.config.get("experts", None):
            self.num_experts = len(self.config["experts"])
        elif self.config.get("loras", None):
            self.num_experts = len(self.config["loras"])
        else:
            self.num_experts = 1
        num_experts_per_tok = self.config.get("num_experts_per_tok", 1)
        self.config["num_experts_per_tok"] = num_experts_per_tok
        moe_layers = self.config.get("moe_layers", "attn")
        self.config["moe_layers"] = moe_layers

        # Load Base Model
        if self.config["base_model"].startswith(
            "https://civitai.com/api/download/models/"
        ):
            os.makedirs("base", exist_ok=True)
            if not os.path.isfile("base/model.safetensors"):
                os.system(
                    "wget -O base/model.safetensors "
                    + self.config["base_model"]
                    + " --content-disposition"
                )
            self.config["base_model"] = "base/model.safetensors"
            self.pipe = DiffusionPipeline.from_single_file(
                self.config["base_model"], torch_dtype=self.torch_dtype
            )
        else:
            try:
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"],
                    torch_dtype=self.torch_dtype,
                    use_safetensors=self.use_safetensors,
                    variant=self.variant,
                    **kwargs,
                )
            except Exception:
                # Fall back to the default variant if the requested one is
                # unavailable.
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"], torch_dtype=self.torch_dtype, **kwargs
                )

        if self.pipe.__class__ == StableDiffusionPipeline:
            self.up_idx_start = 1
            self.up_idx_end = len(self.pipe.unet.up_blocks)
            self.down_idx_start = 0
            self.down_idx_end = len(self.pipe.unet.down_blocks) - 1
        elif self.pipe.__class__ == StableDiffusionXLPipeline:
            self.up_idx_start = 0
            self.up_idx_end = len(self.pipe.unet.up_blocks) - 1
            self.down_idx_start = 1
            self.down_idx_end = len(self.pipe.unet.down_blocks)
        else:
            raise ValueError(
                "Base model must be a Stable Diffusion or Stable Diffusion XL "
                f"pipeline, got {self.pipe.__class__.__name__}"
            )
        self.config["up_idx_start"] = self.up_idx_start
        self.config["up_idx_end"] = self.up_idx_end
        self.config["down_idx_start"] = self.down_idx_start
        self.config["down_idx_end"] = self.down_idx_end

        # TODO: Add Support for Scheduler Selection
        self.pipe.scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config)

        # Load Experts
        experts = []
        positive = []
        negative = []
        if self.config.get("experts", None):
            for i, exp in enumerate(self.config["experts"]):
                positive.append(exp["positive_prompt"])
                negative.append(exp["negative_prompt"])
                if exp["source_model"].startswith(
                    "https://civitai.com/api/download/models/"
                ):
                    try:
                        os.makedirs(f"expert_{i}", exist_ok=True)
                        if not os.path.isfile(f"expert_{i}/model.safetensors"):
                            os.system(
                                f"wget {exp['source_model']} -O "
                                + f"expert_{i}/model.safetensors"
                                + " --content-disposition"
                            )
                        exp["source_model"] = f"expert_{i}/model.safetensors"
                        expert = DiffusionPipeline.from_single_file(
                            exp["source_model"],
                        ).to(self.device, self.torch_dtype)
                    except Exception as e:
                        print(f"Expert {i} {exp['source_model']} failed to load")
                        print("Error:", e)
                else:
                    try:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"],
                            torch_dtype=self.torch_dtype,
                            use_safetensors=self.use_safetensors,
                            variant=self.variant,
                            **kwargs,
                        )
                        # TODO: Add Support for Scheduler Selection
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                    except Exception:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"], torch_dtype=self.torch_dtype, **kwargs
                        )
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                if exp.get("loras", None):
                    for j, lora in enumerate(exp["loras"]):
                        if lora.get("positive_prompt", None):
                            positive[-1] += " " + lora["positive_prompt"]
                        if lora.get("negative_prompt", None):
                            negative[-1] += " " + lora["negative_prompt"]
                        if lora["source_model"].startswith(
                            "https://civitai.com/api/download/models/"
                        ):
                            try:
                                os.makedirs(f"expert_{i}/lora_{j}", exist_ok=True)
                                if not os.path.isfile(
                                    f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                ):
                                    os.system(
                                        f"wget {lora['source_model']} -O "
                                        + f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                        + " --content-disposition"
                                    )
                                lora["source_model"] = f"expert_{i}/lora_{j}"
                                expert.load_lora_weights(lora["source_model"])
                                if len(exp["loras"]) == 1:
                                    expert.fuse_lora()
                            except Exception as e:
                                print(
                                    f"Expert {i} LoRA {j} {lora['source_model']} failed to load"
                                )
                                print("Error:", e)
                        else:
                            expert.load_lora_weights(lora["source_model"])
                            if len(exp["loras"]) == 1:
                                expert.fuse_lora()
                experts.append(expert)
        else:
            experts = [deepcopy(self.pipe) for _ in range(self.num_experts)]
if self.config.get("experts", None): | |
if self.config.get("loras", None): | |
for i, lora in enumerate(self.config["loras"]): | |
if lora["source_model"].startswith( | |
"https://civitai.com/api/download/models/" | |
): | |
try: | |
os.makedirs(f"lora_{i}", exist_ok=True) | |
if not os.path.isfile( | |
f"lora_{i}/pytorch_lora_weights.safetensors" | |
): | |
os.system( | |
f"wget {lora['source_model']} -O " | |
+ f"lora_{i}/pytorch_lora_weights.safetensors" | |
+ " --content-disposition" | |
) | |
lora["source_model"] = f"lora_{i}" | |
self.pipe.load_lora_weights(lora["source_model"]) | |
if len(self.config["loras"]) == 1: | |
self.pipe.fuse_lora() | |
except Exception as e: | |
print(f"LoRA {i} {lora['source_model']} failed to load") | |
print("Error:", e) | |
else: | |
self.pipe.load_lora_weights(lora["source_model"]) | |
if len(self.config["loras"]) == 1: | |
self.pipe.fuse_lora() | |
else: | |
if self.config.get("loras", None): | |
j = [] | |
n_loras = len(self.config["loras"]) | |
i = 0 | |
positive = [""] * len(experts) | |
negative = [""] * len(experts) | |
while n_loras: | |
n = ceil(n_loras / len(experts)) | |
j += [i] * n | |
n_loras -= n | |
i += 1 | |
for i, lora in enumerate(self.config["loras"]): | |
positive[j[i]] += lora["positive_prompt"] + " " | |
negative[j[i]] += lora["negative_prompt"] + " " | |
if lora["source_model"].startswith( | |
"https://civitai.com/api/download/models/" | |
): | |
try: | |
os.makedirs(f"lora_{i}", exist_ok=True) | |
if not os.path.isfile( | |
f"lora_{i}/pytorch_lora_weights.safetensors" | |
): | |
os.system( | |
f"wget {lora['source_model']} -O " | |
+ f"lora_{i}/pytorch_lora_weights.safetensors" | |
+ " --content-disposition" | |
) | |
lora["source_model"] = f"lora_{i}" | |
experts[j[i]].load_lora_weights(lora["source_model"]) | |
experts[j[i]].fuse_lora() | |
except Exception: | |
print(f"LoRA {i} {lora['source_model']} failed to load") | |
else: | |
experts[j[i]].load_lora_weights(lora["source_model"]) | |
experts[j[i]].fuse_lora() | |

        # Replace FF and Attention Layers with Sparse MoE Layers
        for i in range(self.down_idx_start, self.down_idx_end):
            for j in range(len(self.pipe.unet.down_blocks[i].attentions)):
                for k in range(
                    len(self.pipe.unet.down_blocks[i].attentions[j].transformer_blocks)
                ):
                    block = (
                        self.pipe.unet.down_blocks[i]
                        .attentions[j]
                        .transformer_blocks[k]
                    )
                    expert_blocks = [
                        e.unet.down_blocks[i].attentions[j].transformer_blocks[k]
                        for e in experts
                    ]
                    if moe_layers != "attn":
                        # FF Layers
                        config = {
                            "hidden_size": next(block.ff.parameters()).size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.ff = SparseMoeBlock(
                            config, [deepcopy(b.ff) for b in expert_blocks]
                        )
                    if moe_layers != "ff":
                        # Self-attention (attn1): q, k and v all project from
                        # the hidden size, so one config serves all three.
                        config = {
                            "hidden_size": block.attn1.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn1.to_q = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_q) for b in expert_blocks]
                        )
                        block.attn1.to_k = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_k) for b in expert_blocks]
                        )
                        block.attn1.to_v = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_v) for b in expert_blocks]
                        )
                        # Cross-attention (attn2): to_k/to_v read text-encoder
                        # embeddings, so out_dim is set explicitly from the
                        # weight's output dimension.
                        config = {
                            "hidden_size": block.attn2.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_q = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_q) for b in expert_blocks]
                        )
                        config = {
                            "hidden_size": block.attn2.to_k.weight.size()[-1],
                            "out_dim": block.attn2.to_k.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_k = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_k) for b in expert_blocks]
                        )
                        config = {
                            "hidden_size": block.attn2.to_v.weight.size()[-1],
                            "out_dim": block.attn2.to_v.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_v = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_v) for b in expert_blocks]
                        )
        for i in range(self.up_idx_start, self.up_idx_end):
            for j in range(len(self.pipe.unet.up_blocks[i].attentions)):
                for k in range(
                    len(self.pipe.unet.up_blocks[i].attentions[j].transformer_blocks)
                ):
                    block = (
                        self.pipe.unet.up_blocks[i]
                        .attentions[j]
                        .transformer_blocks[k]
                    )
                    expert_blocks = [
                        e.unet.up_blocks[i].attentions[j].transformer_blocks[k]
                        for e in experts
                    ]
                    if moe_layers != "attn":
                        # FF Layers
                        config = {
                            "hidden_size": next(block.ff.parameters()).size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.ff = SparseMoeBlock(
                            config, [deepcopy(b.ff) for b in expert_blocks]
                        )
                    if moe_layers != "ff":
                        # Attns
                        config = {
                            "hidden_size": block.attn1.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn1.to_q = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_q) for b in expert_blocks]
                        )
                        block.attn1.to_k = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_k) for b in expert_blocks]
                        )
                        block.attn1.to_v = SparseMoeBlock(
                            config, [deepcopy(b.attn1.to_v) for b in expert_blocks]
                        )
                        config = {
                            "hidden_size": block.attn2.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_q = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_q) for b in expert_blocks]
                        )
                        config = {
                            "hidden_size": block.attn2.to_k.weight.size()[-1],
                            "out_dim": block.attn2.to_k.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_k = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_k) for b in expert_blocks]
                        )
                        config = {
                            "hidden_size": block.attn2.to_v.weight.size()[-1],
                            "out_dim": block.attn2.to_v.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": len(experts),
                        }
                        block.attn2.to_v = SparseMoeBlock(
                            config, [deepcopy(b.attn2.to_v) for b in expert_blocks]
                        )

        # Routing Weight Initialization
        if self.config.get("init", "hidden") == "hidden":
            gate_params = self.get_gate_params(experts, positive, negative)
            for i in range(self.down_idx_start, self.down_idx_end):
                for j in range(len(self.pipe.unet.down_blocks[i].attentions)):
                    for k in range(
                        len(
                            self.pipe.unet.down_blocks[i]
                            .attentions[j]
                            .transformer_blocks
                        )
                    ):
                        block = (
                            self.pipe.unet.down_blocks[i]
                            .attentions[j]
                            .transformer_blocks[k]
                        )
                        # FF Layers
                        if moe_layers != "attn":
                            block.ff.gate.weight = nn.Parameter(
                                gate_params[f"d{i}a{j}t{k}"]
                            )
                        # Attns
                        if moe_layers != "ff":
                            block.attn1.to_q.gate.weight = nn.Parameter(
                                gate_params[f"sattnqd{i}a{j}t{k}"]
                            )
                            block.attn1.to_k.gate.weight = nn.Parameter(
                                gate_params[f"sattnkd{i}a{j}t{k}"]
                            )
                            block.attn1.to_v.gate.weight = nn.Parameter(
                                gate_params[f"sattnvd{i}a{j}t{k}"]
                            )
                            block.attn2.to_q.gate.weight = nn.Parameter(
                                gate_params[f"cattnqd{i}a{j}t{k}"]
                            )
                            block.attn2.to_k.gate.weight = nn.Parameter(
                                gate_params[f"cattnkd{i}a{j}t{k}"]
                            )
                            block.attn2.to_v.gate.weight = nn.Parameter(
                                gate_params[f"cattnvd{i}a{j}t{k}"]
                            )
            for i in range(self.up_idx_start, self.up_idx_end):
                for j in range(len(self.pipe.unet.up_blocks[i].attentions)):
                    for k in range(
                        len(
                            self.pipe.unet.up_blocks[i].attentions[j].transformer_blocks
                        )
                    ):
                        block = (
                            self.pipe.unet.up_blocks[i]
                            .attentions[j]
                            .transformer_blocks[k]
                        )
                        # FF Layers
                        if moe_layers != "attn":
                            block.ff.gate.weight = nn.Parameter(
                                gate_params[f"u{i}a{j}t{k}"]
                            )
                        if moe_layers != "ff":
                            block.attn1.to_q.gate.weight = nn.Parameter(
                                gate_params[f"sattnqu{i}a{j}t{k}"]
                            )
                            block.attn1.to_k.gate.weight = nn.Parameter(
                                gate_params[f"sattnku{i}a{j}t{k}"]
                            )
                            block.attn1.to_v.gate.weight = nn.Parameter(
                                gate_params[f"sattnvu{i}a{j}t{k}"]
                            )
                            block.attn2.to_q.gate.weight = nn.Parameter(
                                gate_params[f"cattnqu{i}a{j}t{k}"]
                            )
                            block.attn2.to_k.gate.weight = nn.Parameter(
                                gate_params[f"cattnku{i}a{j}t{k}"]
                            )
                            block.attn2.to_v.gate.weight = nn.Parameter(
                                gate_params[f"cattnvu{i}a{j}t{k}"]
                            )
        self.config["num_experts"] = len(experts)
        # Clear any forward hooks that were carried over into the MoE layers
        # during gate initialization.
        remove_all_forward_hooks(self.pipe.unet)
        try:
            del experts
            del expert
        except Exception:
            pass
        # Move Model to Device
        self.pipe.to(self.device)
        self.pipe.unet.to(
            device=self.device,
            dtype=self.torch_dtype,
            memory_format=torch.channels_last,
        )
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """
        Runs inference with the SegMoEPipeline.

        Calls the underlying diffusers.DiffusionPipeline with the given arguments. See https://github.com/segmind/segmoe#usage for detailed usage.
        """
        return self.pipe(*args, **kwds)
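
    # Hedged usage sketch (prompts, paths and settings are illustrative
    # placeholders, not from the original source):
    #
    #   pipeline = SegMoEPipeline("segmoe_config.yaml", device="cuda")
    #   image = pipeline(
    #       "painting of a chubby cat, orange city background",
    #       negative_prompt="bad quality, worse quality",
    #       num_inference_steps=25,
    #       guidance_scale=7.5,
    #   ).images[0]
    #   image.save("image.png")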

    def create_empty(self, path):
        """Rebuilds the MoE UNet skeleton from a saved checkpoint's config so
        that the state dict written by save_pretrained can be loaded into it."""
        with open(f"{path}/unet/config.json", encoding="utf8") as f:
            config = json.load(f)
        self.config = config["segmoe_config"]
        unet = UNet2DConditionModel.from_config(config)
        num_experts_per_tok = self.config["num_experts_per_tok"]
        num_experts = self.config["num_experts"]
        moe_layers = self.config["moe_layers"]
        self.up_idx_start = self.config["up_idx_start"]
        self.up_idx_end = self.config["up_idx_end"]
        self.down_idx_start = self.config["down_idx_start"]
        self.down_idx_end = self.config["down_idx_end"]
        for i in range(self.down_idx_start, self.down_idx_end):
            for j in range(len(unet.down_blocks[i].attentions)):
                for k in range(
                    len(unet.down_blocks[i].attentions[j].transformer_blocks)
                ):
                    block = unet.down_blocks[i].attentions[j].transformer_blocks[k]
                    if moe_layers != "attn":
                        # FF Layers. Repeating the same module is safe here:
                        # SparseMoeBlock deep-copies each entry, and the real
                        # weights are loaded from the checkpoint afterwards.
                        config = {
                            "hidden_size": next(block.ff.parameters()).size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.ff = SparseMoeBlock(config, [block.ff] * num_experts)
                    if moe_layers != "ff":
                        # Attns
                        config = {
                            "hidden_size": block.attn1.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn1.to_q = SparseMoeBlock(
                            config, [block.attn1.to_q] * num_experts
                        )
                        block.attn1.to_k = SparseMoeBlock(
                            config, [block.attn1.to_k] * num_experts
                        )
                        block.attn1.to_v = SparseMoeBlock(
                            config, [block.attn1.to_v] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_q = SparseMoeBlock(
                            config, [block.attn2.to_q] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_k.weight.size()[-1],
                            "out_dim": block.attn2.to_k.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_k = SparseMoeBlock(
                            config, [block.attn2.to_k] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_v.weight.size()[-1],
                            "out_dim": block.attn2.to_v.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_v = SparseMoeBlock(
                            config, [block.attn2.to_v] * num_experts
                        )
        for i in range(self.up_idx_start, self.up_idx_end):
            for j in range(len(unet.up_blocks[i].attentions)):
                for k in range(len(unet.up_blocks[i].attentions[j].transformer_blocks)):
                    block = unet.up_blocks[i].attentions[j].transformer_blocks[k]
                    if moe_layers != "attn":
                        # FF Layers
                        config = {
                            "hidden_size": next(block.ff.parameters()).size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.ff = SparseMoeBlock(config, [block.ff] * num_experts)
                    if moe_layers != "ff":
                        # Attns
                        config = {
                            "hidden_size": block.attn1.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn1.to_q = SparseMoeBlock(
                            config, [block.attn1.to_q] * num_experts
                        )
                        block.attn1.to_k = SparseMoeBlock(
                            config, [block.attn1.to_k] * num_experts
                        )
                        block.attn1.to_v = SparseMoeBlock(
                            config, [block.attn1.to_v] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_q.weight.size()[-1],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_q = SparseMoeBlock(
                            config, [block.attn2.to_q] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_k.weight.size()[-1],
                            "out_dim": block.attn2.to_k.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_k = SparseMoeBlock(
                            config, [block.attn2.to_k] * num_experts
                        )
                        config = {
                            "hidden_size": block.attn2.to_v.weight.size()[-1],
                            "out_dim": block.attn2.to_v.weight.size()[0],
                            "num_experts_per_tok": num_experts_per_tok,
                            "num_local_experts": num_experts,
                        }
                        block.attn2.to_v = SparseMoeBlock(
                            config, [block.attn2.to_v] * num_experts
                        )
        return unet

    def save_pretrained(self, path):
        """
        Save the SegMoEPipeline to disk.

        Usage:
            pipeline.save_pretrained(path)

        Parameters:
            path: Path to the directory to save the model in.
        """
        for param in self.pipe.unet.parameters():
            param.data = param.data.contiguous()
        self.pipe.unet.config["segmoe_config"] = self.config
        self.pipe.save_pretrained(path)
        safetensors.torch.save_file(
            self.pipe.unet.state_dict(),
            f"{path}/unet/diffusion_pytorch_model.safetensors",
        )
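
    # Hedged round-trip sketch (the path is a placeholder): saving embeds the
    # MoE config in the UNet config, and reloading goes through create_empty()
    # to rebuild the skeleton before restoring the weights.
    #
    #   pipeline.save_pretrained("segmoe_checkpoint")
    #   reloaded = SegMoEPipeline("segmoe_checkpoint")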

    def cast_hook(self, pipe, dicts):
        # Registers forward hooks on every FF and attention projection that is
        # turned into a SparseMoeBlock in the combined model, capturing their
        # inputs under the same keys used later to assign the gate weights.
        for i in range(self.down_idx_start, self.down_idx_end):
            for j in range(len(pipe.unet.down_blocks[i].attentions)):
                for k in range(
                    len(pipe.unet.down_blocks[i].attentions[j].transformer_blocks)
                ):
                    block = pipe.unet.down_blocks[i].attentions[j].transformer_blocks[k]
                    block.ff.register_forward_hook(
                        getActivation(dicts, f"d{i}a{j}t{k}")
                    )
                    # Down Self Attns
                    block.attn1.to_q.register_forward_hook(
                        getActivation(dicts, f"sattnqd{i}a{j}t{k}")
                    )
                    block.attn1.to_k.register_forward_hook(
                        getActivation(dicts, f"sattnkd{i}a{j}t{k}")
                    )
                    block.attn1.to_v.register_forward_hook(
                        getActivation(dicts, f"sattnvd{i}a{j}t{k}")
                    )
                    # Down Cross Attns
                    block.attn2.to_q.register_forward_hook(
                        getActivation(dicts, f"cattnqd{i}a{j}t{k}")
                    )
                    block.attn2.to_k.register_forward_hook(
                        getActivation(dicts, f"cattnkd{i}a{j}t{k}")
                    )
                    block.attn2.to_v.register_forward_hook(
                        getActivation(dicts, f"cattnvd{i}a{j}t{k}")
                    )
        for i in range(self.up_idx_start, self.up_idx_end):
            for j in range(len(pipe.unet.up_blocks[i].attentions)):
                for k in range(
                    len(pipe.unet.up_blocks[i].attentions[j].transformer_blocks)
                ):
                    block = pipe.unet.up_blocks[i].attentions[j].transformer_blocks[k]
                    block.ff.register_forward_hook(
                        getActivation(dicts, f"u{i}a{j}t{k}")
                    )
                    # Up Self Attns
                    block.attn1.to_q.register_forward_hook(
                        getActivation(dicts, f"sattnqu{i}a{j}t{k}")
                    )
                    block.attn1.to_k.register_forward_hook(
                        getActivation(dicts, f"sattnku{i}a{j}t{k}")
                    )
                    block.attn1.to_v.register_forward_hook(
                        getActivation(dicts, f"sattnvu{i}a{j}t{k}")
                    )
                    # Up Cross Attns
                    block.attn2.to_q.register_forward_hook(
                        getActivation(dicts, f"cattnqu{i}a{j}t{k}")
                    )
                    block.attn2.to_k.register_forward_hook(
                        getActivation(dicts, f"cattnku{i}a{j}t{k}")
                    )
                    block.attn2.to_v.register_forward_hook(
                        getActivation(dicts, f"cattnvu{i}a{j}t{k}")
                    )

    def get_hidden_states(self, model, positive, negative, average: bool = True):
        intermediate = {}
        self.cast_hook(model, intermediate)
        with torch.no_grad():
            _ = model(positive, negative_prompt=negative, num_inference_steps=25)
        hidden = {}
        for key in intermediate:
            # Each hook stored the module's input tuple; take the first input
            # tensor's last batch element (with classifier-free guidance this
            # is the prompt-conditioned half).
            hidden_states = intermediate[key][0][-1]
            if average:
                # average over the sequence dimension
                hidden_states = hidden_states.mean(dim=0)
            else:
                # take the last token's value
                hidden_states = hidden_states[-1]
            hidden[key] = hidden_states.to(self.device)
        del intermediate
        gc.collect()
        torch.cuda.empty_cache()
        return hidden

    def get_gate_params(
        self,
        experts,
        positive,
        negative,
    ):
        gate_vects = {}
        for i, expert in enumerate(tqdm.tqdm(experts, desc="Expert Prompts")):
            expert.to(self.device)
            expert.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )
            hidden_states = self.get_hidden_states(expert, positive[i], negative[i])
            del expert
            gc.collect()
            torch.cuda.empty_cache()
            for h in hidden_states:
                if i == 0:
                    gate_vects[h] = []
                # L2-normalize so the gate's scores are cosine similarities.
                hidden_states[h] /= (
                    hidden_states[h].norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
                )
                gate_vects[h].append(hidden_states[h])
        for h in gate_vects:
            # (num_experts, hidden_size), matching the weight shape of the
            # gate's nn.Linear(hidden_size, num_experts).
            gate_vects[h] = torch.stack(gate_vects[h], dim=0)
        return gate_vects
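

# A minimal sketch of why the "hidden" gate initialization works (illustrative
# only; sizes and values are made up). Each gate row is an expert's normalized
# average activation for its own prompt, so the gate logits are cosine
# similarities between the current hidden state and each expert's "signature",
# and tokens route to the experts whose prompts they most resemble.
def _gate_init_demo():
    hidden_size, num_experts = 8, 2
    signatures = F.normalize(torch.randn(num_experts, hidden_size), dim=-1)
    gate = nn.Linear(hidden_size, num_experts, bias=False)
    gate.weight = nn.Parameter(signatures)
    # A state aligned with expert 0's signature scores highest for expert 0.
    logits = gate(signatures[0])
    assert logits.argmax().item() == 0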