import os
import io
import base64
import requests
import tempfile
from typing import Optional, Dict, Any
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import InferenceClient
from utils import create_temp_media_url, compress_media_for_data_uri, validate_video_html
from config import HF_TOKEN


class MediaGenerator:
    """Handles generation of images, videos, and music."""

    def __init__(self):
        self.hf_client = None
        if HF_TOKEN:
            self.hf_client = InferenceClient(
                provider="auto",
                api_key=HF_TOKEN,
                bill_to="huggingface",
            )

    def generate_image_with_qwen(self, prompt: str, image_index: int = 0,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate an image from a text prompt using the Qwen image model."""
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            print(f"[ImageGen] Generating image with prompt: {prompt}")
            # Generate image using the Qwen/Qwen-Image model
            image = self.hf_client.text_to_image(
                prompt,
                model="Qwen/Qwen-Image",
            )
            # Resize image to reduce size while maintaining quality
            max_size = 1024
            if image.width > max_size or image.height > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            # Convert to bytes
            buffer = io.BytesIO()
            image.convert('RGB').save(buffer, format='JPEG', quality=90, optimize=True)
            image_bytes = buffer.getvalue()
            # Create temporary URL
            filename = f"generated_image_{image_index}.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            return f'<img src="{temp_url}" alt="Generated image {image_index}" style="max-width: 100%; height: auto;" />'
        except Exception as e:
            print(f"Image generation error: {str(e)}")
            return f"Error generating image: {str(e)}"

    def generate_image_to_image(self, input_image_data, prompt: str,
                                token: Optional[gr.OAuthToken] = None) -> str:
        """Generate an image from an input image and prompt using Qwen-Image-Edit."""
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            print(f"[Image2Image] Processing with prompt: {prompt}")
            # Normalize input image to a PIL image
            pil_image = self._process_input_image(input_image_data)
            # Resize input image to avoid request body size limits
            max_input_size = 1024
            if pil_image.width > max_input_size or pil_image.height > max_input_size:
                pil_image.thumbnail((max_input_size, max_input_size), Image.Resampling.LANCZOS)
            # Convert to bytes
            buf = io.BytesIO()
            pil_image.save(buf, format='JPEG', quality=85, optimize=True)
            input_bytes = buf.getvalue()
            # Call image-to-image
            image = self.hf_client.image_to_image(
                input_bytes,
                prompt=prompt,
                model="Qwen/Qwen-Image-Edit",
            )
            # Resize and optimize output
            max_size = 1024
            if image.width > max_size or image.height > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            out_buf = io.BytesIO()
            image.convert('RGB').save(out_buf, format='JPEG', quality=90, optimize=True)
            image_bytes = out_buf.getvalue()
            # Create temporary URL
            filename = "image_to_image_result.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            return f'<img src="{temp_url}" alt="Image-to-image result" style="max-width: 100%; height: auto;" />'
        except Exception as e:
            print(f"Image-to-image generation error: {str(e)}")
            return f"Error generating image (image-to-image): {str(e)}"

    def generate_video_from_image(self, input_image_data, prompt: str,
                                  session_id: Optional[str] = None,
                                  token: Optional[gr.OAuthToken] = None) -> str:
        """Generate a video from an input image using Lightricks LTX-Video."""
        try:
            print("[Image2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            # Process input image
            pil_image = self._process_input_image(input_image_data)
            print(f"[Image2Video] Input image size: {pil_image.size}")
            # Compress image to stay under API request size limits
            input_bytes = self._compress_image_for_video(pil_image, max_size_mb=3.9)
            # Check that this huggingface_hub version exposes image_to_video
            image_to_video_method = getattr(self.hf_client, "image_to_video", None)
            if not callable(image_to_video_method):
                return ("Error: Your huggingface_hub version does not support image_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")
            model_id = "Lightricks/LTX-Video-0.9.8-13B-distilled"
            print(f"[Image2Video] Calling API with model: {model_id}")
            video_bytes = image_to_video_method(
                input_bytes,
                prompt=prompt,
                model=model_id,
            )
            print(f"[Image2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")
            # Create temporary URL
            filename = "image_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            video_html = self._create_video_html(temp_url)
            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"
            print(f"[Image2Video] Successfully generated video: {temp_url}")
            return video_html
        except Exception as e:
            print(f"[Image2Video] Error: {str(e)}")
            return f"Error generating video (image-to-video): {str(e)}"

    def generate_video_from_text(self, prompt: str, session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate a video from a text prompt using the Wan-AI text-to-video model."""
        try:
            print("[Text2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            # Check that this huggingface_hub version exposes text_to_video
            text_to_video_method = getattr(self.hf_client, "text_to_video", None)
            if not callable(text_to_video_method):
                return ("Error: Your huggingface_hub version does not support text_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")
            model_id = "Wan-AI/Wan2.2-T2V-A14B"
            prompt_str = (prompt or "").strip()
            print(f"[Text2Video] Using model: {model_id}, prompt length: {len(prompt_str)}")
            video_bytes = text_to_video_method(
                prompt_str,
                model=model_id,
            )
            print(f"[Text2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")
            # Create temporary URL
            filename = "text_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            video_html = self._create_video_html(temp_url)
            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"
            print(f"[Text2Video] Successfully generated video: {temp_url}")
            return video_html
        except Exception as e:
            print(f"[Text2Video] Error: {str(e)}")
            return f"Error generating video (text-to-video): {str(e)}"

    def generate_music_from_text(self, prompt: str, music_length_ms: int = 30000,
                                 session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate music using the ElevenLabs Music API."""
        try:
            api_key = os.getenv('ELEVENLABS_API_KEY')
            if not api_key:
                return "Error: ELEVENLABS_API_KEY environment variable is not set."
            print(f"[MusicGen] Generating music: {prompt}")
            headers = {
                'Content-Type': 'application/json',
                'xi-api-key': api_key,
            }
            payload = {
                'prompt': prompt or 'Epic orchestral theme with soaring strings and powerful brass',
                'music_length_ms': int(music_length_ms) if music_length_ms else 30000,
            }
            resp = requests.post(
                'https://api.elevenlabs.io/v1/music/compose',
                headers=headers,
                json=payload,
                timeout=60,
            )
            try:
                resp.raise_for_status()
            except requests.HTTPError:
                # resp is still in scope, so report its body directly
                return f"Error generating music: {resp.text}"
            # Create temporary URL
            filename = "generated_music.mp3"
            temp_url = self._upload_media_to_hf(resp.content, filename, "audio", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            audio_html = self._create_audio_html(temp_url)
            print(f"[MusicGen] Successfully generated music: {temp_url}")
            return audio_html
        except Exception as e:
            print(f"[MusicGen] Error: {str(e)}")
            return f"Error generating music: {str(e)}"

    def _process_input_image(self, input_image_data) -> Image.Image:
        """Convert various input formats (file-like, PIL, ndarray, bytes) to a PIL Image."""
        if hasattr(input_image_data, 'read'):
            # File-like object
            raw = input_image_data.read()
            pil_image = Image.open(io.BytesIO(raw))
        elif hasattr(input_image_data, 'mode') and hasattr(input_image_data, 'size'):
            # Already a PIL Image
            pil_image = input_image_data
        elif isinstance(input_image_data, np.ndarray):
            pil_image = Image.fromarray(input_image_data)
        elif isinstance(input_image_data, (bytes, bytearray)):
            pil_image = Image.open(io.BytesIO(input_image_data))
        else:
            pil_image = Image.open(io.BytesIO(bytes(input_image_data)))
        # Ensure RGB
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        return pil_image

    def _compress_image_for_video(self, pil_image: Image.Image, max_size_mb: float = 3.9) -> bytes:
        """Compress an image to fit video-generation API request limits."""
        MAX_BYTES = int(max_size_mb * 1024 * 1024)
        max_dim = 1024
        quality = 90

        def encode_current(pil: Image.Image, q: int) -> bytes:
            tmp = io.BytesIO()
            pil.save(tmp, format='JPEG', quality=q, optimize=True)
            return tmp.getvalue()

        # Downscale until the longest side is within max_dim
        while max(pil_image.size) > max_dim:
            ratio = max_dim / float(max(pil_image.size))
            new_size = (max(1, int(pil_image.size[0] * ratio)), max(1, int(pil_image.size[1] * ratio)))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
        encoded = encode_current(pil_image, quality)
        # Reduce quality first, then dimensions, while still over the byte limit
        while len(encoded) > MAX_BYTES and (quality > 40 or max(pil_image.size) > 640):
            if quality > 40:
                quality -= 10
            else:
                new_w = max(1, int(pil_image.size[0] * 0.85))
                new_h = max(1, int(pil_image.size[1] * 0.85))
                pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
            encoded = encode_current(pil_image, quality)
        return encoded

    def _upload_media_to_hf(self, media_bytes: bytes, filename: str, media_type: str,
                            token: Optional[gr.OAuthToken] = None, use_temp: bool = True) -> str:
        """Upload media to HF or create a temporary file URL."""
        if use_temp:
            return create_temp_media_url(media_bytes, filename, media_type)
        # HF upload logic would go here for permanent URLs;
        # for now, always fall back to temp files.
        return create_temp_media_url(media_bytes, filename, media_type)

    def _create_video_html(self, video_url: str) -> str:
        """Create an HTML video element."""
        return f'''<video controls autoplay muted loop playsinline style="max-width: 100%; height: auto;">
    <source src="{video_url}" type="video/mp4">
    Your browser does not support the video tag.
</video>'''

    def _create_audio_html(self, audio_url: str) -> str:
        """Create an HTML audio player."""
        return f'''<audio controls style="width: 100%;">
    <source src="{audio_url}" type="audio/mpeg">
    Your browser does not support the audio element.
</audio>'''
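

if __name__ == "__main__":
    # Rough usage sketch, not part of the class: it assumes HF_TOKEN is set in
    # the environment (and ELEVENLABS_API_KEY if you call
    # generate_music_from_text), and that utils.create_temp_media_url is
    # importable. The prompt below is illustrative only.
    generator = MediaGenerator()
    result = generator.generate_image_with_qwen("A watercolor lighthouse at dawn")
    # On success this prints an <img> tag pointing at a temp URL; on failure,
    # an "Error: ..." string.
    print(result)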