Spaces:
Running
on
Zero
Running
on
Zero
# Project EmbodiedGen | |
# | |
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
# implied. See the License for the specific language governing | |
# permissions and limitations under the License. | |
# Text-to-image generation pipelines for models from the Hugging Face community.
import os | |
from abc import ABC, abstractmethod | |
import torch | |
from diffusers import ( | |
ChromaPipeline, | |
Cosmos2TextToImagePipeline, | |
DPMSolverMultistepScheduler, | |
FluxPipeline, | |
KolorsPipeline, | |
StableDiffusion3Pipeline, | |
) | |
from diffusers.quantizers import PipelineQuantizationConfig | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
from transformers import AutoModelForCausalLM, SiglipProcessor | |
# Public API of this module: only the factory function is exported.
__all__ = ["build_hf_image_pipeline"]
class BasePipelineLoader(ABC):
    """Abstract base for objects that build a pipeline for a target device.

    Subclasses must implement :meth:`load`. Because ``load`` is abstract, a
    subclass that forgets to implement it fails loudly at instantiation
    instead of silently returning ``None``.

    Args:
        device: Device identifier stored as ``self.device`` for subclasses
            to use (default ``"cuda"``).
    """

    def __init__(self, device="cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        """Build and return the pipeline object."""
class BasePipelineRunner(ABC):
    """Abstract base that wraps a loaded pipeline behind a ``run`` call.

    Subclasses must implement :meth:`run`.

    Args:
        pipe: The pipeline object produced by a loader's ``load()``; stored
            as ``self.pipe``.
    """

    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Extra keyword arguments are meant to be forwarded to the pipeline.
        Implementations return the pipeline output's ``images`` collection,
        which is a list of images rather than a single image.
        """
# ===== SD3.5-medium ===== | |
class SD35Loader(BasePipelineLoader):
    """Load Stable Diffusion 3.5 Medium in float16 with model CPU offload."""

    def load(self):
        """Build the SD3.5-medium pipeline.

        Returns:
            A ``StableDiffusion3Pipeline`` with model CPU offload targeting
            ``self.device`` and xFormers attention enabled.
        """
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        # With model CPU offload, diffusers moves each sub-model onto the
        # execution device only while it runs. Moving the whole pipeline
        # there first would put every component on the GPU for nothing.
        # Passing ``device=`` makes offload honour ``self.device``; without
        # it, offload always uses the default accelerator.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        # Attention slicing is intentionally not enabled. diffusers documents
        # that combining it with xFormers can lead to serious slowdowns.
        return pipe
class SD35Runner(BasePipelineRunner):
    """Invoke a Stable Diffusion 3.5 pipeline and return its images."""

    def run(self, prompt: str, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Args:
            prompt: Text prompt.
            **kwargs: Forwarded unchanged to the pipeline call (e.g.
                ``height``, ``width``, ``num_inference_steps``).

        Returns:
            The pipeline output's ``images`` attribute. It is a list, not a
            single image; element type depends on the pipeline's
            ``output_type`` (PIL images by default).
        """
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Cosmos2 ===== | |
class CosmosLoader(BasePipelineLoader):
    """Load NVIDIA Cosmos-Predict2 text-to-image with bitsandbytes 4-bit NF4.

    Args:
        model_id: Hugging Face repo id to download.
        local_dir: Directory the repo snapshot is written to and loaded from.
        device: Device the loaded pipeline is moved to.
    """

    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        """Globally inject default kwargs into two ``from_pretrained`` methods.

        ``AutoModelForCausalLM.from_pretrained`` defaults to
        ``attn_implementation="flash_attention_2"`` and bfloat16.
        ``SiglipProcessor.from_pretrained`` defaults to ``use_fast=True``.
        Kwargs passed explicitly by the caller still win (``setdefault``).

        NOTE(review): the code that calls these classes is not visible here.
        Presumably they are auxiliary models loaded during pipeline
        construction; confirm.

        Each class is wrapped at most once, so repeated ``load()`` calls do
        not stack wrapper on top of wrapper.
        """
        marker = "_embodiedgen_defaults_patched"

        def patch(cls, defaults):
            if getattr(cls.from_pretrained, marker, False):
                return  # already wrapped by an earlier load()
            orig = cls.from_pretrained  # classmethod already bound to ``cls``

            def new(*args, **kwargs):
                for key, value in defaults.items():
                    kwargs.setdefault(key, value)
                return orig(*args, **kwargs)

            setattr(new, marker, True)
            cls.from_pretrained = new

        patch(
            AutoModelForCausalLM,
            {
                "attn_implementation": "flash_attention_2",
                "torch_dtype": torch.bfloat16,
            },
        )
        patch(SiglipProcessor, {"use_fast": True})

    def load(self):
        """Download the snapshot and build the quantized Cosmos2 pipeline.

        Returns:
            A ``Cosmos2TextToImagePipeline`` loaded from ``self.local_dir``,
            with 4-bit NF4 quantization and moved to ``self.device``.
        """
        self._patch()
        # snapshot_download returns the folder it populated (``local_dir``).
        # Loading from that path instead of the repo id avoids a second full
        # download into the default Hugging Face cache. The former
        # ``local_dir_use_symlinks`` / ``resume_download`` arguments are
        # deprecated in current huggingface_hub (downloads always resume and
        # real files are written to local_dir), so they are omitted.
        model_path = snapshot_download(
            repo_id=self.model_id,
            local_dir=self.local_dir,
        )
        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            # NOTE(review): "unet" does not appear to be a Cosmos2 component;
            # confirm it is harmlessly ignored.
            components_to_quantize=["text_encoder", "transformer", "unet"],
        )
        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe
class CosmosRunner(BasePipelineRunner):
    """Invoke a Cosmos2 text-to-image pipeline and return its images."""

    def run(self, prompt: str, negative_prompt=None, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Args:
            prompt: Text prompt.
            negative_prompt: Optional negative prompt, passed explicitly to
                the pipeline (``None`` by default).
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The pipeline output's ``images`` attribute. It is a list, not a
            single image; element type depends on the pipeline's
            ``output_type`` (PIL images by default).
        """
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
# ===== Kolors ===== | |
class KolorsLoader(BasePipelineLoader):
    """Load Kwai-Kolors (fp16 weights) with CPU offload and Karras DPM-Solver."""

    def load(self):
        """Build the Kolors pipeline.

        Returns:
            A ``KolorsPipeline`` with model CPU offload targeting
            ``self.device``, xFormers attention, and its scheduler replaced by
            ``DPMSolverMultistepScheduler`` using Karras sigmas.
        """
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        # Model CPU offload handles per-module device placement itself.
        # Moving the whole pipeline onto the device first only wastes memory
        # and time. ``device=`` makes offload honour ``self.device``.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        # Rebuild the scheduler from the original one's config, switching to
        # DPM-Solver multistep with Karras sigmas.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe
class KolorsRunner(BasePipelineRunner):
    """Invoke a Kolors pipeline and return its images."""

    def run(self, prompt: str, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Args:
            prompt: Text prompt.
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The pipeline output's ``images`` attribute. It is a list, not a
            single image; element type depends on the pipeline's
            ``output_type`` (PIL images by default).
        """
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Flux ===== | |
class FluxLoader(BasePipelineLoader):
    """Load FLUX.1-schnell in bfloat16 with model CPU offload."""

    def load(self):
        """Build the FLUX.1-schnell pipeline.

        Returns:
            A ``FluxPipeline`` with model CPU offload targeting
            ``self.device`` and xFormers attention enabled.
        """
        # The CUDA caching allocator reads this variable when it initialises,
        # so it only helps if CUDA memory has not been allocated yet.
        # ``setdefault`` keeps any value the user already exported instead of
        # overwriting it.
        os.environ.setdefault(
            "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        # Offload moves each sub-model to ``self.device`` only while it runs.
        # Calling ``pipe.to(device)`` afterwards, as before, defeats the
        # memory savings; diffusers warns against it.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        # Attention slicing is intentionally not enabled. diffusers documents
        # that combining it with xFormers can lead to serious slowdowns.
        return pipe
class FluxRunner(BasePipelineRunner):
    """Invoke a FLUX pipeline and return its images."""

    def run(self, prompt: str, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Args:
            prompt: Text prompt.
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The pipeline output's ``images`` attribute. It is a list, not a
            single image; element type depends on the pipeline's
            ``output_type`` (PIL images by default).
        """
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Chroma ===== | |
class ChromaLoader(BasePipelineLoader):
    """Load the ``lodestones/Chroma`` pipeline in bfloat16 on ``self.device``."""

    def load(self):
        """Build the Chroma pipeline and move it to the loader's device."""
        repo_id = "lodestones/Chroma"
        chroma = ChromaPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
        return chroma.to(self.device)
class ChromaRunner(BasePipelineRunner):
    """Invoke a Chroma pipeline and return its images."""

    def run(self, prompt: str, negative_prompt=None, **kwargs) -> "list[Image.Image]":
        """Generate images for ``prompt``.

        Args:
            prompt: Text prompt.
            negative_prompt: Optional negative prompt, passed explicitly to
                the pipeline (``None`` by default).
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The pipeline output's ``images`` attribute. It is a list, not a
            single image; element type depends on the pipeline's
            ``output_type`` (PIL images by default).
        """
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
# Public model name -> (loader class, runner class). build_hf_image_pipeline
# looks names up here, calls the loader's load(), and wraps the result in
# the runner.
PIPELINE_REGISTRY = {
    # stabilityai/stable-diffusion-3.5-medium
    "sd35": (SD35Loader, SD35Runner),
    # nvidia/Cosmos-Predict2-2B-Text2Image (4-bit quantized)
    "cosmos": (CosmosLoader, CosmosRunner),
    # Kwai-Kolors/Kolors-diffusers
    "kolors": (KolorsLoader, KolorsRunner),
    # black-forest-labs/FLUX.1-schnell
    "flux": (FluxLoader, FluxRunner),
    # lodestones/Chroma
    "chroma": (ChromaLoader, ChromaRunner),
}
def build_hf_image_pipeline(
    name: str, device="cuda", **loader_kwargs
) -> BasePipelineRunner:
    """Load the named model and return a runner wrapping its pipeline.

    Args:
        name: Key in ``PIPELINE_REGISTRY`` (e.g. ``"sd35"``, ``"cosmos"``).
        device: Device passed to the loader's constructor.
        **loader_kwargs: Extra keyword arguments forwarded to the loader's
            constructor, e.g. ``model_id`` / ``local_dir`` for ``"cosmos"``.
            Loaders that do not accept them raise ``TypeError``.

    Returns:
        An instance of the registered runner class wrapping the loaded
        pipeline.

    Raises:
        ValueError: If ``name`` is not a registered model.
    """
    try:
        loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    except KeyError:
        supported = ", ".join(sorted(PIPELINE_REGISTRY))
        raise ValueError(
            f"Unsupported model: {name}. Supported models: {supported}"
        ) from None
    pipe = loader_cls(device=device, **loader_kwargs).load()
    return runner_cls(pipe)
if __name__ == "__main__":
    # Smoke test: build one pipeline and write its outputs as JPEG files.
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)

    # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
    generation_kwargs = {
        "height": 512,
        "width": 512,
        "num_inference_steps": 10,
        "guidance_scale": 6,
        "num_images_per_prompt": 1,
    }
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'", **generation_kwargs
    )
    for idx, image in enumerate(images):
        image.save(f"image_{model_name}_{idx}.jpg")