# Author: jamesoncrate
# Commit 3eb8145: "move pipeline to gpu" (file size 5.43 kB)
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile
import os
# Lazily created DeepFloyd IF text pipeline; stays None until load_model() builds it.
text_pipe = None

# Target device for the pipeline: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model():
    """
    Return the DeepFloyd IF text pipeline, building it on the first call.

    The T5 text encoder is loaded with ``load_in_8bit=True`` / ``variant="8bit"``
    and plugged into a DiffusionPipeline created with ``unet=None`` (only text
    encoding is needed). The pipeline is moved to ``device`` and cached in the
    module-level ``text_pipe`` so later calls reuse it.

    Returns:
        The cached pipeline.
    """
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        repo_id = "DeepFloyd/IF-I-L-v1.0"
        # Hugging Face access token, read from the HF_TOKEN environment variable.
        token = os.getenv("HF_TOKEN")
        text_encoder = T5EncoderModel.from_pretrained(
            repo_id,
            subfolder="text_encoder",
            load_in_8bit=True,
            variant="8bit",
            token=token,
        )
        pipe = DiffusionPipeline.from_pretrained(
            repo_id,
            text_encoder=text_encoder,
            unet=None,
            token=token,
        )
        # Publish to the module-level cache only once fully set up. Otherwise a
        # failure in .to(device) would leave an un-moved pipeline cached that
        # every later call would silently reuse.
        text_pipe = pipe.to(device)
        print("Model loaded successfully!")
    return text_pipe
def _parse_prompts(prompts_text):
    """Return the stripped, non-blank, unique lines of *prompts_text* in input order."""
    if not prompts_text:
        return []
    stripped = (line.strip() for line in prompts_text.split('\n'))
    # dict.fromkeys drops repeats while keeping first-seen order. Repeated
    # prompts would otherwise be encoded twice yet collapse to one dict key.
    return list(dict.fromkeys(line for line in stripped if line))


@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Generate T5 text embeddings for prompts and save them to a .pth file.

    Args:
        prompts_text: String with one prompt per line. Blank lines are ignored
            and repeated prompts are encoded only once.

    Returns:
        Tuple ``(path, status_message)``. ``path`` points to a file named
        ``prompt_embeds_dict.pth`` holding a dict that maps each prompt (plus
        ``''``, added for classifier-free guidance) to its embedding tensor on
        the CPU. It is ``None`` when there is nothing to encode or an error
        occurred; the status message then explains why.
    """
    try:
        prompts = _parse_prompts(prompts_text)
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add the empty prompt used as the unconditional embedding for CFG
        # (Classifier Free Guidance). Blank lines were filtered out above, so it
        # is never already present.
        prompts.append('')

        # Load the model only once there is something to encode (cached afterwards).
        pipe = load_model()

        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_dict = {}
        for prompt in prompts:
            # Only the positive embedding is kept. Disabling CFG stops
            # encode_prompt from also encoding a negative prompt we would discard.
            embeds, _ = pipe.encode_prompt(prompt, do_classifier_free_guidance=False)
            # Move embeddings to CPU before saving.
            prompt_embeds_dict[prompt] = embeds.cpu() if isinstance(embeds, torch.Tensor) else embeds

        # Save under a fixed name in a fresh temporary directory, so the download
        # is "prompt_embeds_dict.pth" rather than a random temp-file name.
        # torch.save opens and closes the file itself, so nothing is leaked on error.
        output_path = os.path.join(tempfile.mkdtemp(), 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, output_path)

        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join(f" - '{p}'" for p in prompts)
        return output_path, status_msg
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        import traceback
        error_details = traceback.format_exc()
        # Also log server-side; otherwise only the requester ever sees the error.
        print(error_details)
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"
# Create Gradio interface: prompt entry and button on the left, status text and
# the downloadable .pth file on the right.
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
# 🔤 CS180 HW5: T5 Text Encoder Embeddings Generator
This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.
### How to use:
1. Enter your prompts in the text box (one prompt per line)
2. Click "Generate Embeddings"
3. Download the generated `.pth` file containing the embeddings
### About the embeddings:
- Each embedding has shape: `[1, 77, 4096]`
- `77` = max sequence length
- `4096` = embedding dimension of the T5 encoder
- An empty prompt (`''`) is automatically added for Classifier Free Guidance (CFG)
""")
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil""",
            )
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False,
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)",
            )
    # generate_embeddings returns (file path, status message), in this output order.
    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output],
    )
    gr.Markdown("""
### 📝 Note:
- The first run may take a while as the model needs to download (~8GB)
- Subsequent runs will be faster
- The generated `.pth` file can be loaded in PyTorch using: `torch.load('prompt_embeds_dict.pth')`
""")

if __name__ == "__main__":
    demo.launch()