import os
import io
import base64
import requests
import tempfile
from typing import Optional, Dict, Any
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import InferenceClient

from utils import create_temp_media_url, compress_media_for_data_uri, validate_video_html
from config import HF_TOKEN


class MediaGenerator:
    """Handles generation of images, videos, and music"""

    def __init__(self):
        self.hf_client = None
        if HF_TOKEN:
            self.hf_client = InferenceClient(
                provider="auto",
                api_key=HF_TOKEN,
                bill_to="huggingface"
            )

    def generate_image_with_qwen(self, prompt: str, image_index: int = 0, token: Optional[gr.OAuthToken] = None) -> str:
        """Generate image using Qwen image model"""
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            print(f"[ImageGen] Generating image with prompt: {prompt}")

            # Generate image using Qwen/Qwen-Image model
            image = self.hf_client.text_to_image(
                prompt,
                model="Qwen/Qwen-Image",
            )

            # Resize image to reduce size while maintaining quality
            max_size = 1024
            if image.width > max_size or image.height > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            # Convert to bytes
            buffer = io.BytesIO()
            image.convert('RGB').save(buffer, format='JPEG', quality=90, optimize=True)
            image_bytes = buffer.getvalue()

            # Create temporary URL
            filename = f"generated_image_{image_index}.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url

            return f'<img src="{temp_url}" alt="{prompt}" style="max-width: 100%; height: auto;" />'
        except Exception as e:
            print(f"Image generation error: {str(e)}")
            return f"Error generating image: {str(e)}"

    def generate_image_to_image(self, input_image_data, prompt: str, token: Optional[gr.OAuthToken] = None) -> str:
        """Generate image using image-to-image with Qwen-Image-Edit"""
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            print(f"[Image2Image] Processing with prompt: {prompt}")

            # Normalize input image to bytes
            pil_image = self._process_input_image(input_image_data)

            # Resize input image to avoid request body size limits
            max_input_size = 1024
            if pil_image.width > max_input_size or pil_image.height > max_input_size:
                pil_image.thumbnail((max_input_size, max_input_size), Image.Resampling.LANCZOS)

            # Convert to bytes
            buf = io.BytesIO()
            pil_image.save(buf, format='JPEG', quality=85, optimize=True)
            input_bytes = buf.getvalue()

            # Call image-to-image
            image = self.hf_client.image_to_image(
                input_bytes,
                prompt=prompt,
                model="Qwen/Qwen-Image-Edit",
            )

            # Resize and optimize output
            max_size = 1024
            if image.width > max_size or image.height > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            out_buf = io.BytesIO()
            image.convert('RGB').save(out_buf, format='JPEG', quality=90, optimize=True)
            image_bytes = out_buf.getvalue()

            # Create temporary URL
            filename = "image_to_image_result.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url

            return f'<img src="{temp_url}" alt="{prompt}" style="max-width: 100%; height: auto;" />'
        except Exception as e:
            print(f"Image-to-image generation error: {str(e)}")
            return f"Error generating image (image-to-image): {str(e)}"

    def generate_video_from_image(self, input_image_data, prompt: str, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
        """Generate video from input image using Lightricks LTX-Video"""
        try:
            print("[Image2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            # Process input image
            pil_image = self._process_input_image(input_image_data)
            print(f"[Image2Video] Input image size: {pil_image.size}")

            # Compress image for API limits
            input_bytes = self._compress_image_for_video(pil_image, max_size_mb=3.9)

            # Check for image-to-video method
            image_to_video_method = getattr(self.hf_client, "image_to_video", None)
            if not callable(image_to_video_method):
                return ("Error: Your huggingface_hub version does not support image_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")

            model_id = "Lightricks/LTX-Video-0.9.8-13B-distilled"
            print(f"[Image2Video] Calling API with model: {model_id}")

            video_bytes = image_to_video_method(
                input_bytes,
                prompt=prompt,
                model=model_id,
            )
            print(f"[Image2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")

            # Create temporary URL
            filename = "image_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url

            video_html = self._create_video_html(temp_url)
            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"

            print(f"[Image2Video] Successfully generated video: {temp_url}")
            return video_html
        except Exception as e:
            print(f"[Image2Video] Error: {str(e)}")
            return f"Error generating video (image-to-video): {str(e)}"

    def generate_video_from_text(self, prompt: str, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
        """Generate video from text prompt using Wan-AI text-to-video model"""
        try:
            print("[Text2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            # Check for text-to-video method
            text_to_video_method = getattr(self.hf_client, "text_to_video", None)
            if not callable(text_to_video_method):
                return ("Error: Your huggingface_hub version does not support text_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")

            model_id = "Wan-AI/Wan2.2-T2V-A14B"
            prompt_str = (prompt or "").strip()
            print(f"[Text2Video] Using model: {model_id}, prompt length: {len(prompt_str)}")

            video_bytes = text_to_video_method(
                prompt_str,
                model=model_id,
            )
            print(f"[Text2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")

            # Create temporary URL
            filename = "text_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url

            video_html = self._create_video_html(temp_url)
            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"

            print(f"[Text2Video] Successfully generated video: {temp_url}")
            return video_html
        except Exception as e:
            print(f"[Text2Video] Error: {str(e)}")
            return f"Error generating video (text-to-video): {str(e)}"

    def generate_music_from_text(self, prompt: str, music_length_ms: int = 30000, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
        """Generate music using ElevenLabs Music API"""
        try:
            api_key = os.getenv('ELEVENLABS_API_KEY')
            if not api_key:
                return "Error: ELEVENLABS_API_KEY environment variable is not set."
            print(f"[MusicGen] Generating music: {prompt}")

            headers = {
                'Content-Type': 'application/json',
                'xi-api-key': api_key,
            }
            payload = {
                'prompt': prompt or 'Epic orchestral theme with soaring strings and powerful brass',
                'music_length_ms': int(music_length_ms) if music_length_ms else 30000,
            }

            resp = requests.post(
                'https://api.elevenlabs.io/v1/music/compose',
                headers=headers,
                json=payload,
                timeout=60
            )
            try:
                resp.raise_for_status()
            except Exception:
                # Surface the API's error body rather than a bare status code
                return f"Error generating music: {resp.text}"

            # Create temporary URL
            filename = "generated_music.mp3"
            temp_url = self._upload_media_to_hf(resp.content, filename, "audio", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url

            audio_html = self._create_audio_html(temp_url)
            print(f"[MusicGen] Successfully generated music: {temp_url}")
            return audio_html
        except Exception as e:
            print(f"[MusicGen] Error: {str(e)}")
            return f"Error generating music: {str(e)}"

    def _process_input_image(self, input_image_data) -> Image.Image:
        """Convert various image formats to PIL Image"""
        if hasattr(input_image_data, 'read'):
            # File-like object
            raw = input_image_data.read()
            pil_image = Image.open(io.BytesIO(raw))
        elif hasattr(input_image_data, 'mode') and hasattr(input_image_data, 'size'):
            # Already a PIL Image
            pil_image = input_image_data
        elif isinstance(input_image_data, np.ndarray):
            pil_image = Image.fromarray(input_image_data)
        elif isinstance(input_image_data, (bytes, bytearray)):
            pil_image = Image.open(io.BytesIO(input_image_data))
        else:
            pil_image = Image.open(io.BytesIO(bytes(input_image_data)))

        # Ensure RGB
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        return pil_image

    def _compress_image_for_video(self, pil_image: Image.Image, max_size_mb: float = 3.9) -> bytes:
        """Compress image for video generation API limits"""
        MAX_BYTES = int(max_size_mb * 1024 * 1024)
        max_dim = 1024
        quality = 90

        def encode_current(pil: Image.Image, q: int) -> bytes:
            tmp = io.BytesIO()
            pil.save(tmp, format='JPEG', quality=q, optimize=True)
            return tmp.getvalue()

        # Downscale while too large
        while max(pil_image.size) > max_dim:
            ratio = max_dim / float(max(pil_image.size))
            new_size = (max(1, int(pil_image.size[0] * ratio)), max(1, int(pil_image.size[1] * ratio)))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)

        encoded = encode_current(pil_image, quality)

        # Reduce quality or dimensions if still too large
        while len(encoded) > MAX_BYTES and (quality > 40 or max(pil_image.size) > 640):
            if quality > 40:
                quality -= 10
            else:
                new_w = max(1, int(pil_image.size[0] * 0.85))
                new_h = max(1, int(pil_image.size[1] * 0.85))
                pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
            encoded = encode_current(pil_image, quality)
        return encoded

    def _upload_media_to_hf(self, media_bytes: bytes, filename: str, media_type: str, token: Optional[gr.OAuthToken] = None, use_temp: bool = True) -> str:
        """Upload media to HF or create temporary file"""
        if use_temp:
            return create_temp_media_url(media_bytes, filename, media_type)
        # HF upload logic would go here for permanent URLs
        # For now, always use temp files
        return create_temp_media_url(media_bytes, filename, media_type)

    def _create_video_html(self, video_url: str) -> str:
        """Create HTML video element"""
        return f'''<video controls preload="metadata" style="max-width: 100%; height: auto;">
    <source src="{video_url}" type="video/mp4">
    Your browser does not support the video tag.
</video>'''

    def _create_audio_html(self, audio_url: str) -> str:
        """Create HTML audio player"""
        return f'''<div>
    🎵 Generated music
    <audio controls style="width: 100%;">
        <source src="{audio_url}" type="audio/mpeg">
        Your browser does not support the audio element.
    </audio>
</div>'''


# Global media generator instance
media_generator = MediaGenerator()


# Export main functions
def generate_image_with_qwen(prompt: str, image_index: int = 0, token: Optional[gr.OAuthToken] = None) -> str:
    return media_generator.generate_image_with_qwen(prompt, image_index, token)


def generate_image_to_image(input_image_data, prompt: str, token: Optional[gr.OAuthToken] = None) -> str:
    return media_generator.generate_image_to_image(input_image_data, prompt, token)


def generate_video_from_image(input_image_data, prompt: str, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
    return media_generator.generate_video_from_image(input_image_data, prompt, session_id, token)


def generate_video_from_text(prompt: str, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
    return media_generator.generate_video_from_text(prompt, session_id, token)


def generate_music_from_text(prompt: str, music_length_ms: int = 30000, session_id: Optional[str] = None, token: Optional[gr.OAuthToken] = None) -> str:
    return media_generator.generate_music_from_text(prompt, music_length_ms, session_id, token)
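

# Minimal smoke-test sketch, not part of the module's public API. It assumes HF_TOKEN is
# configured in config.py (for image/video generation) and that ELEVENLABS_API_KEY is set
# in the environment if music generation is attempted; the prompt below is illustrative only.
if __name__ == "__main__":
    # Each exported helper returns either an HTML snippet referencing a temporary media URL
    # or a string beginning with "Error", so callers can branch on that prefix.
    result = generate_image_with_qwen("A watercolor lighthouse at dusk", image_index=0)
    if result.startswith("Error"):
        print(result)
    else:
        print("Image generated:", result[:80], "...")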