Viewgen / lora_inference.py
RingL's picture
updated view scripts
853a5cb
from diffusers import AutoPipelineForText2Image
import torch
# Model identifier of the base checkpoint. Nothing is loaded here: the id is
# passed to `from_pretrained` when a pipeline is built, and the per-view LoRA
# weights are then loaded on top of that pipeline.
base_model = "runwayml/stable-diffusion-v1-5"
# Checkpoint directory holding the trained LoRA weights for each supported view.
_LORA_CHECKPOINT_DIRS = {
    "front": "/content/out_front/checkpoint-100",
    "back": "/content/out_back/checkpoint-100",
    "side": "/content/out_side/checkpoint-100",
}


def create_pipeline_with_lora(view):
    """Build a text-to-image pipeline with the LoRA weights for *view* loaded.

    Args:
        view: Name of the view; one of "front", "back" or "side".

    Returns:
        A pipeline created from ``base_model`` in float16, moved to CUDA,
        with the LoRA weights of the requested view loaded into it.

    Raises:
        ValueError: If *view* is not one of the supported views.
    """
    # Validate before touching the model, so an unsupported view fails
    # immediately instead of after the base pipeline has been loaded onto
    # the GPU. TypeError covers unhashable arguments, which the previous
    # equality checks also reported as an unsupported view.
    try:
        checkpoint_dir = _LORA_CHECKPOINT_DIRS[view]
    except (KeyError, TypeError):
        raise ValueError("Unsupported view: {}".format(view)) from None

    pipeline = AutoPipelineForText2Image.from_pretrained(
        base_model, torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights(
        checkpoint_dir, weight_name="pytorch_lora_weights.safetensors"
    )
    return pipeline
# Base prompt shared by all views; a "<... view>" tag is prepended per image.
prompt = "high quality traffic in street."

# Generate and save one image per view.
views = ["front", "back", "side"]
for view in views:
    pipeline = create_pipeline_with_lora(view)
    view_prompt = f"<{view} view> {prompt}"
    image = pipeline(view_prompt).images[0]
    image.save(f"traffic_{view}.png")

    # Every view builds its own full pipeline on the GPU. Drop this one and
    # hand PyTorch's cached CUDA blocks back before the next pipeline is
    # created, so the next load does not have to fit alongside the old one.
    del pipeline
    torch.cuda.empty_cache()

print("Images for all views generated and saved successfully.")