"""Modular diffusers pipeline block that expands a short user prompt into a
vivid, detailed image-generation prompt using Google's Gemini API."""

import logging
import os
from typing import List

import google.generativeai as genai
from diffusers.modular_pipelines import (
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)

logger = logging.getLogger(__name__)

# System instruction given to Gemini; steers it to return only the expanded
# prompt text (no prefix/suffix chatter) and to stay under ~512 tokens.
SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)


class GeminiPromptExpander(ModularPipelineBlocks):
    """Pipeline block that rewrites the incoming ``prompt`` via a Gemini LLM.

    The block reads the user-provided ``prompt`` from the pipeline state,
    asks Gemini to expand it into a richer image-generation prompt, stores
    the expansion back as ``prompt``, and preserves the user's original text
    as ``old_prompt``.
    """

    # NOTE(review): presumably ties this block to Flux-family pipelines in
    # the modular-pipeline registry — confirm against diffusers docs.
    model_name = "flux"

    def __init__(self, model_id="gemini-2.5-flash-lite", system_prompt=SYSTEM_PROMPT, api_key=None):
        """Configure the Gemini client.

        Args:
            model_id: Gemini model identifier to use for expansion.
            system_prompt: System instruction controlling the expansion style.
            api_key: Optional explicit Google API key. When ``None`` (the
                default, preserving prior behavior), the key is read from the
                ``GOOGLE_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key is provided and ``GOOGLE_API_KEY`` is
                not set.
        """
        super().__init__()
        if api_key is None:
            api_key = os.getenv("GOOGLE_API_KEY")
        if api_key is None:
            raise ValueError("Must provide an API key for Gemini through the `GOOGLE_API_KEY` env variable.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id, system_instruction=system_prompt)

    @property
    def expected_components(self):
        # This block needs no diffusers components (no models/schedulers).
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Old prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Expand the prompt in ``state`` and return the updated pipeline state."""
        block_state = self.get_block_state(state)
        old_prompt = block_state.prompt
        logger.info("Actual prompt: %s", old_prompt)
        # NOTE(review): `.text` raises if the response was blocked or empty —
        # consider surfacing a clearer error if that occurs in practice.
        block_state.prompt = self.model.generate_content(old_prompt).text
        block_state.old_prompt = old_prompt
        logger.info("Expanded prompt: %s", block_state.prompt)
        self.set_block_state(state, block_state)
        return components, state