# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

# Text-to-Image generation models from the Hugging Face community.

import os
from abc import ABC, abstractmethod

import torch
from diffusers import (
    ChromaPipeline,
    Cosmos2TextToImagePipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    KolorsPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor

__all__ = [
    "build_hf_image_pipeline",
]


class BasePipelineLoader(ABC):
    """Builds and returns a ready-to-use diffusers pipeline."""

    def __init__(self, device: str = "cuda"):
        self.device = device

    @abstractmethod
    def load(self):
        pass


class BasePipelineRunner(ABC):
    """Wraps a loaded pipeline behind a uniform `run` interface."""

    def __init__(self, pipe):
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        pass


# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    def load(self):
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        # `enable_model_cpu_offload` manages device placement itself, so the
        # pipeline is not moved with an explicit `.to()` call first.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class SD35Runner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id
        self.local_dir = local_dir

    def _patch(self):
        """Monkey-patch `from_pretrained` defaults for components loaded
        internally by the pipeline (text encoder and image processor)."""

        def patch_model(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                # Requires the `flash-attn` package to be installed.
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        def patch_processor(cls):
            orig = cls.from_pretrained

            def new(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            cls.from_pretrained = new

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        self._patch()
        # Recent huggingface_hub versions resume interrupted downloads by
        # default, so the deprecated `resume_download` and
        # `local_dir_use_symlinks` arguments are omitted.
        snapshot_download(repo_id=self.model_id, local_dir=self.local_dir)
        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            # Cosmos2 is transformer-based and has no UNet component.
            components_to_quantize=["text_encoder", "transformer"],
        )
        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe


class CosmosRunner(BasePipelineRunner):
    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    def load(self):
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe


class KolorsRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    def load(self):
        # Reduce CUDA memory fragmentation when weights move between
        # CPU and GPU during offloading.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        # Offloading handles device placement; moving the pipeline back to
        # the GPU with `.to()` afterwards would defeat it.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe


class FluxRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    def load(self):
        return ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        ).to(self.device)


class ChromaRunner(BasePipelineRunner):
    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images


PIPELINE_REGISTRY = {
    "sd35": (SD35Loader, SD35Runner),
    "cosmos": (CosmosLoader, CosmosRunner),
    "kolors": (KolorsLoader, KolorsRunner),
    "flux": (FluxLoader, FluxRunner),
    "chroma": (ChromaLoader, ChromaRunner),
}


def build_hf_image_pipeline(name: str, device: str = "cuda") -> BasePipelineRunner:
    """Instantiate the loader/runner pair registered under `name`."""
    if name not in PIPELINE_REGISTRY:
        raise ValueError(f"Unsupported model: {name}")
    loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    pipe = loader_cls(device=device).load()

    return runner_cls(pipe)


if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: For pipeline testing only; generation quality at low
    # resolution is poor.
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )
    for i, img in enumerate(images):
        img.save(f"image_{model_name}_{i}.jpg")
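
# --- Extension sketch (illustrative only, not part of the module) ---
# New models plug into the same registry/factory pattern. The names below
# (`MyLoader`, `MyRunner`, `"my_model"`) are hypothetical placeholders:
#
#     class MyLoader(BasePipelineLoader):
#         def load(self):
#             # Build and return a diffusers pipeline placed on self.device.
#             ...
#
#     class MyRunner(BasePipelineRunner):
#         def run(self, prompt: str, **kwargs) -> list[Image.Image]:
#             return self.pipe(prompt=prompt, **kwargs).images
#
#     PIPELINE_REGISTRY["my_model"] = (MyLoader, MyRunner)
#     runner = build_hf_image_pipeline("my_model")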