import base64
import difflib
import json
import os

import diffusers
import numpy as np
import requests
import torch
import torch.nn.functional as F
import transformers
from diffusers import (AutoencoderKL, DiffusionPipeline,
                       FlowMatchEulerDiscreteScheduler, FluxPipeline,
                       FluxTransformer2DModel, SD3Transformer2DModel,
                       StableDiffusion3Pipeline)
from diffusers.callbacks import PipelineCallback
from torchao.quantization import int8_weight_only, quantize_
from torchvision import transforms
from transformers import (AutoModelForCausalLM, AutoProcessor, CLIPTextModel,
                          CLIPTextModelWithProjection, T5EncoderModel)

def get_flux_pipeline(
    model_id="black-forest-labs/FLUX.1-dev",
    pipeline_class=FluxPipeline,
    torch_dtype=torch.bfloat16,
    quantize=False,
):
    """Load the FLUX components individually, optionally quantize them to
    int8 weights with torchao, and assemble them into a pipeline."""
    ############ Diffusion Transformer ############
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch_dtype
    )
    ############ Text Encoder ############
    text_encoder = CLIPTextModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=torch_dtype
    )
    ############ Text Encoder 2 ############
    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype
    )
    ############ VAE ############
    vae = AutoencoderKL.from_pretrained(
        model_id, subfolder="vae", torch_dtype=torch_dtype
    )
    # int8 weight-only quantization roughly halves the memory footprint
    # of each component while staying close to bf16 quality.
    if quantize:
        quantize_(transformer, int8_weight_only())
        quantize_(text_encoder, int8_weight_only())
        quantize_(text_encoder_2, int8_weight_only())
        quantize_(vae, int8_weight_only())
    # Initialize the pipeline with the (possibly quantized) components.
    pipe = pipeline_class.from_pretrained(
        model_id,
        transformer=transformer,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch_dtype,
    )
    return pipe
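
# Usage sketch (an assumption, not part of the Space's entry point): this
# presumes a CUDA GPU with enough memory for FLUX.1-dev and Hub access to
# the gated model.
#
#   pipe = get_flux_pipeline(quantize=True).to("cuda")
#   image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
#   image.save("flux_sample.png")
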
def mask_decode(encoded_mask, image_shape=(512, 512)):
    """Decode a run-length encoded mask: `encoded_mask` is a flat list of
    (start_offset, run_length) pairs over the row-major flattened image."""
    length = image_shape[0] * image_shape[1]
    mask_array = np.zeros((length,))
    for i in range(0, len(encoded_mask), 2):
        # Clamp the run so it never writes past the end of the image.
        splice_len = min(encoded_mask[i + 1], length - encoded_mask[i])
        for j in range(splice_len):
            mask_array[encoded_mask[i] + j] = 1
    mask_array = mask_array.reshape(image_shape[0], image_shape[1])
    # Force the one-pixel border to 1 to avoid annotation errors at the
    # boundary.
    mask_array[0, :] = 1
    mask_array[-1, :] = 1
    mask_array[:, 0] = 1
    mask_array[:, -1] = 1
    return mask_array
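
# Illustrative decoding sketch (made-up values): the pairs below mark the
# first ten pixels of rows 0 and 1 of a 512-wide image, and the decoder
# additionally forces the border to 1.
#
#   mask = mask_decode([0, 10, 512, 10])  # -> (512, 512) array of 0s and 1s
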
def mask_interpolate(mask, size=128):
    """Resize a 2D mask to `size` x `size` with bicubic interpolation."""
    mask = torch.tensor(mask)
    # Add batch and channel dims for F.interpolate, then drop them again.
    mask = F.interpolate(mask[None, None, ...], size, mode='bicubic')
    mask = mask.squeeze()
    return mask
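
# Sketch: shrink a decoded 512x512 mask to the default 128x128 (e.g. a
# latent-space resolution). Bicubic output is continuous, so threshold it
# if a hard mask is needed:
#
#   latent_mask = mask_interpolate(mask_decode([0, 10]))
#   hard_mask = (latent_mask > 0.5).float()
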
def get_blend_word_index(prompt, word, tokenizer):
    """Return the positions in the tokenized prompt whose token IDs also
    appear in the tokenization of `word`."""
    input_ids = tokenizer(prompt).input_ids
    blend_ids = tokenizer(word, add_special_tokens=False).input_ids
    index = []
    for i, token_id in enumerate(input_ids):
        # Heuristic: skip special and very common tokens, which sit at
        # low IDs in the vocabulary.
        if token_id < 100:
            continue
        if token_id in blend_ids:
            index.append(i)
    return index
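
# Sketch with an assumed T5 tokenizer (the one paired with text_encoder_2
# above); the prompt and word are placeholders:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2"
#   )
#   get_blend_word_index("a photo of a corgi", "corgi", tok)
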
def find_token_id_differences(prompt1, prompt2, tokenizer):
    """Diff two prompts at the token level and return, for each prompt,
    the positions and IDs of the tokens that are not shared."""
    # Tokenize inputs and get input IDs.
    tokens1 = tokenizer.encode(prompt1, add_special_tokens=False)
    tokens2 = tokenizer.encode(prompt2, add_special_tokens=False)
    # Align the two token sequences.
    seq_matcher = difflib.SequenceMatcher(None, tokens1, tokens2)
    diff1_indices, diff1_ids = [], []
    diff2_indices, diff2_ids = [], []
    for opcode, a_start, a_end, b_start, b_end in seq_matcher.get_opcodes():
        # 'replace'/'delete' spans belong to prompt 1;
        # 'replace'/'insert' spans belong to prompt 2.
        if opcode in ['replace', 'delete']:
            diff1_indices.extend(range(a_start, a_end))
            diff1_ids.extend(tokens1[a_start:a_end])
        if opcode in ['replace', 'insert']:
            diff2_indices.extend(range(b_start, b_end))
            diff2_ids.extend(tokens2[b_start:b_end])
    return {
        'prompt_1': {'index': diff1_indices, 'id': diff1_ids},
        'prompt_2': {'index': diff2_indices, 'id': diff2_ids}
    }
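
# Sketch (reusing the placeholder `tok` from above): diff two edit prompts
# that differ in a single word; the returned indices point at the swapped
# tokens in each prompt.
#
#   diff = find_token_id_differences(
#       "a photo of a cat", "a photo of a dog", tok
#   )
#   diff["prompt_1"]["index"]  # positions unique to the first prompt
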
def find_word_token_indices(prompt, word, tokenizer):
    """Return the positions of the sub-tokens of `word` inside the
    tokenized prompt (requires a fast tokenizer for `.tokens()`)."""
    encoding = tokenizer(prompt, add_special_tokens=False)
    tokens = encoding.tokens()
    word_indices = []
    # Tokenize the word the same way for comparison.
    word_tokens = tokenizer(word, add_special_tokens=False).tokens()
    # Slide over the prompt tokens and record every exact match of the
    # word's token sequence.
    for i in range(len(tokens) - len(word_tokens) + 1):
        if tokens[i : i + len(word_tokens)] == word_tokens:
            word_indices.extend(range(i, i + len(word_tokens)))
    return word_indices
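
# Sketch (fast tokenizer assumed, e.g. the `tok` above). Note that some
# tokenizers split a word differently in isolation than in context (e.g.
# SentencePiece leading-space handling), in which case no match is found:
#
#   find_word_token_indices("a photo of a corgi", "corgi", tok)
#   # -> indices of the sub-tokens spelling "corgi" in the prompt
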