ReFlex / src /utils.py
SahilCarterr's picture
Upload 77 files
f056744 verified
import base64
import difflib
import json
import os
import diffusers
import numpy as np
import requests
import torch
import torch.nn.functional as F
import transformers
from diffusers import (AutoencoderKL, DiffusionPipeline,
FlowMatchEulerDiscreteScheduler, FluxPipeline,
FluxTransformer2DModel, SD3Transformer2DModel,
StableDiffusion3Pipeline)
from diffusers.callbacks import PipelineCallback
from torchao.quantization import int8_weight_only, quantize_
from torchvision import transforms
from transformers import (AutoModelForCausalLM, AutoProcessor, CLIPTextModel,
CLIPTextModelWithProjection, T5EncoderModel)
def get_flux_pipeline(
    model_id="black-forest-labs/FLUX.1-dev",
    pipeline_class=FluxPipeline,
    torch_dtype=torch.bfloat16,
    quantize=False
):
    """Assemble a FLUX pipeline from individually loaded components.

    The diffusion transformer, the CLIP text encoder, the T5 text encoder and
    the VAE are each loaded from the subfolder of ``model_id`` that shares
    their keyword name, all in ``torch_dtype``. When ``quantize`` is true,
    torchao int8 weight-only quantization is applied in place to every one of
    those four modules before the pipeline is built.

    Args:
        model_id: Hub id or local path of the FLUX checkpoint.
        pipeline_class: Pipeline class whose ``from_pretrained`` is used to
            assemble the final pipeline (defaults to ``FluxPipeline``).
        torch_dtype: dtype used for every component and for the pipeline.
        quantize: If true, int8 weight-only quantize each loaded component.

    Returns:
        The pipeline instance returned by ``pipeline_class.from_pretrained``.
    """
    # (pipeline keyword / checkpoint subfolder, loader class), in load order.
    component_specs = (
        ("transformer", FluxTransformer2DModel),
        ("text_encoder", CLIPTextModel),
        ("text_encoder_2", T5EncoderModel),
        ("vae", AutoencoderKL),
    )
    components = {}
    for component_name, loader_cls in component_specs:
        components[component_name] = loader_cls.from_pretrained(
            model_id, subfolder=component_name, torch_dtype=torch_dtype
        )

    if quantize:
        # Quantized in the same order the components were loaded.
        for module in components.values():
            quantize_(module, int8_weight_only())

    return pipeline_class.from_pretrained(
        model_id, torch_dtype=torch_dtype, **components
    )
def mask_decode(encoded_mask, image_shape=(512, 512)):
    """Decode a run-length-encoded binary mask into a 2-D float array.

    Args:
        encoded_mask: Flat sequence ``[start_0, len_0, start_1, len_1, ...]``
            of run start offsets and run lengths into the row-major flattened
            image. Runs are clipped at the last pixel; a run starting past the
            end of the image is ignored. Starts are assumed non-negative.
        image_shape: ``(height, width)`` of the decoded mask. Any indexable
            pair works (tuple or list).

    Returns:
        ``np.ndarray`` of shape ``(height, width)`` and dtype float64, with 1
        inside the runs and 0 elsewhere. The outermost rows and columns are
        always forced to 1.
    """
    height, width = image_shape[0], image_shape[1]
    length = height * width
    mask_array = np.zeros((length,))
    for i in range(0, len(encoded_mask), 2):
        start = encoded_mask[i]
        # Clip the run so it never writes past the last pixel; a non-positive
        # result (start beyond the image, or an empty run) writes nothing.
        splice_len = min(encoded_mask[i + 1], length - start)
        if splice_len > 0:
            mask_array[start:start + splice_len] = 1
    mask_array = mask_array.reshape(height, width)
    # to avoid annotation errors in boundary
    mask_array[0, :] = 1
    mask_array[-1, :] = 1
    mask_array[:, 0] = 1
    mask_array[:, -1] = 1
    return mask_array
def mask_interpolate(mask, size=128):
    """Resize a 2-D mask with bicubic interpolation.

    Args:
        mask: 2-D array-like of shape (H, W), e.g. ``np.ndarray`` or
            ``torch.Tensor``. Floating-point inputs keep their dtype.
            Integer/bool masks are cast to float32, because bicubic
            interpolation is only defined for floating-point tensors.
        size: Output size forwarded to ``F.interpolate``: an int for a
            square output or an ``(out_h, out_w)`` pair.

    Returns:
        The resized mask with every singleton dimension squeezed out.
        Bicubic interpolation can overshoot, so values near mask edges may
        fall slightly outside the input's range.
    """
    # as_tensor avoids the extra copy (and the copy-construct warning)
    # torch.tensor() performs when handed a tensor.
    mask = torch.as_tensor(mask)
    if not torch.is_floating_point(mask):
        mask = mask.float()
    # Add batch and channel dims: bicubic mode expects a 4-D (N, C, H, W) input.
    mask = F.interpolate(mask[None, None, ...], size, mode='bicubic')
    return mask.squeeze()
def get_blend_word_index(prompt, word, tokenizer):
    """Return the prompt token positions whose ids occur in ``word``.

    ``prompt`` is tokenized with the tokenizer's default special-token
    handling, while ``word`` is tokenized with ``add_special_tokens=False``.
    Each prompt position whose id appears anywhere among the word's ids is
    reported, so the sub-tokens of a multi-token word match independently and
    every occurrence in the prompt is returned.

    Ids below 100 are always skipped as "common" tokens. NOTE(review): this
    presumably targets special/punctuation ids of the tokenizer in use;
    confirm for other tokenizers.

    Returns:
        list[int]: Matching positions in ascending order.
    """
    prompt_ids = tokenizer(prompt).input_ids
    word_ids = tokenizer(word, add_special_tokens=False).input_ids
    return [
        position
        for position, token_id in enumerate(prompt_ids)
        if token_id >= 100 and token_id in word_ids
    ]
def find_token_id_differences(prompt1, prompt2, tokenizer):
    """Diff the token-id sequences of two prompts.

    Both prompts are encoded with ``add_special_tokens=False`` and aligned
    with ``difflib.SequenceMatcher``. Tokens of ``prompt1`` that were
    replaced or deleted, and tokens of ``prompt2`` that were replaced or
    inserted, are collected together with their positions.

    Returns:
        dict: ``{'prompt_1': {'index': [...], 'id': [...]},
        'prompt_2': {'index': [...], 'id': [...]}}``. Positions index into
        each prompt's own token sequence; ``'id'`` holds the token ids at
        those positions, in the same order.
    """
    ids_a = tokenizer.encode(prompt1, add_special_tokens=False)
    ids_b = tokenizer.encode(prompt2, add_special_tokens=False)

    removed_positions = []
    added_positions = []
    matcher = difflib.SequenceMatcher(None, ids_a, ids_b)
    for tag, a_lo, a_hi, b_lo, b_hi in matcher.get_opcodes():
        # 'replace' contributes to both sides; 'delete' / 'insert' to one.
        if tag in ('replace', 'delete'):
            removed_positions += range(a_lo, a_hi)
        if tag in ('replace', 'insert'):
            added_positions += range(b_lo, b_hi)

    return {
        'prompt_1': {
            'index': removed_positions,
            'id': [ids_a[k] for k in removed_positions],
        },
        'prompt_2': {
            'index': added_positions,
            'id': [ids_b[k] for k in added_positions],
        },
    }
def find_word_token_indices(prompt, word, tokenizer):
    """Return the token positions in ``prompt`` covered by occurrences of ``word``.

    Both strings are tokenized with ``add_special_tokens=False``. The token
    sequence of ``word`` is then searched for as a contiguous run inside the
    prompt's tokens. Every match contributes all of its positions, in order.
    Overlapping matches can therefore repeat a position.

    NOTE(review): ``word`` is tokenized in isolation, so tokenizers whose
    pieces depend on context (e.g. a leading-space marker) may yield a
    different sequence than the one the word produces inside ``prompt``;
    confirm for the tokenizer in use.

    Args:
        prompt: Text to search in.
        word: Word or phrase to locate.
        tokenizer: Callable returning an encoding that exposes ``.tokens()``
            (as Hugging Face fast tokenizers do).

    Returns:
        list[int]: Matched token positions. Empty when ``word`` tokenizes to
        nothing or does not occur.
    """
    prompt_tokens = tokenizer(prompt, add_special_tokens=False).tokens()
    word_tokens = tokenizer(word, add_special_tokens=False).tokens()
    span = len(word_tokens)
    if span == 0:
        return []
    word_indices = []
    for i in range(len(prompt_tokens) - span + 1):
        if prompt_tokens[i:i + span] == word_tokens:
            word_indices.extend(range(i, i + span))
    return word_indices