# media_generation.py (repo: yeye, author: mgbam)
# Renamed from media_processing.py to media_generation.py (commit 916c98e, verified)
import os
import io
import base64
import requests
import tempfile
from typing import Optional, Dict, Any
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import InferenceClient
from utils import create_temp_media_url, compress_media_for_data_uri, validate_video_html
from config import HF_TOKEN
class MediaGenerator:
    """Generates images, videos, and music for embedding in HTML output.

    Images and videos go through the Hugging Face Inference API (requires
    HF_TOKEN); music goes through the ElevenLabs Music API (requires
    ELEVENLABS_API_KEY). Every public method returns either an HTML snippet
    on success or a string starting with "Error" on failure — callers test
    for that "Error" prefix, so it must be preserved.
    """

    def __init__(self):
        # Client stays None when HF_TOKEN is missing so each method can
        # return a clear configuration error instead of crashing at import.
        self.hf_client = None
        if HF_TOKEN:
            self.hf_client = InferenceClient(
                provider="auto",
                api_key=HF_TOKEN,
                bill_to="huggingface"
            )

    def generate_image_with_qwen(self, prompt: str, image_index: int = 0,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate an image from *prompt* using Qwen/Qwen-Image.

        Args:
            prompt: Text description of the desired image.
            image_index: Suffix used to build a distinct temp filename.
            token: Optional Gradio OAuth token forwarded to the uploader.

        Returns:
            An ``<img>`` HTML snippet, or an "Error..." string.
        """
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            print(f"[ImageGen] Generating image with prompt: {prompt}")
            image = self.hf_client.text_to_image(
                prompt,
                model="Qwen/Qwen-Image",
            )
            # Shared downscale + JPEG encode (was duplicated per method).
            image_bytes = self._encode_image_jpeg(image)
            filename = f"generated_image_{image_index}.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            return self._create_image_html(temp_url, prompt)
        except Exception as e:
            print(f"Image generation error: {str(e)}")
            return f"Error generating image: {str(e)}"

    def generate_image_to_image(self, input_image_data, prompt: str,
                                token: Optional[gr.OAuthToken] = None) -> str:
        """Edit *input_image_data* according to *prompt* using Qwen/Qwen-Image-Edit.

        Args:
            input_image_data: File-like object, PIL image, numpy array, or bytes.
            prompt: Edit instruction for the model.
            token: Optional Gradio OAuth token forwarded to the uploader.

        Returns:
            An ``<img>`` HTML snippet, or an "Error..." string.
        """
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            print(f"[Image2Image] Processing with prompt: {prompt}")
            pil_image = self._process_input_image(input_image_data)
            # Shrink the request payload to stay under API body-size limits.
            max_input_size = 1024
            if pil_image.width > max_input_size or pil_image.height > max_input_size:
                pil_image.thumbnail((max_input_size, max_input_size), Image.Resampling.LANCZOS)
            buf = io.BytesIO()
            pil_image.save(buf, format='JPEG', quality=85, optimize=True)
            image = self.hf_client.image_to_image(
                buf.getvalue(),
                prompt=prompt,
                model="Qwen/Qwen-Image-Edit",
            )
            # Shared downscale + JPEG encode (was duplicated per method).
            image_bytes = self._encode_image_jpeg(image)
            temp_url = self._upload_media_to_hf(image_bytes, "image_to_image_result.jpg", "image", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            return self._create_image_html(temp_url, prompt)
        except Exception as e:
            print(f"Image-to-image generation error: {str(e)}")
            return f"Error generating image (image-to-image): {str(e)}"

    def generate_video_from_image(self, input_image_data, prompt: str,
                                  session_id: Optional[str] = None,
                                  token: Optional[gr.OAuthToken] = None) -> str:
        """Animate *input_image_data* guided by *prompt* via Lightricks LTX-Video.

        Args:
            input_image_data: File-like object, PIL image, numpy array, or bytes.
            prompt: Motion/content guidance for the video.
            session_id: Unused here; kept for interface parity with callers.
            token: Optional Gradio OAuth token forwarded to the uploader.

        Returns:
            A ``<video>`` HTML snippet, or an "Error..." string.
        """
        try:
            print("[Image2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            pil_image = self._process_input_image(input_image_data)
            print(f"[Image2Video] Input image size: {pil_image.size}")
            # Keep the request body under the API's ~4 MB limit.
            input_bytes = self._compress_image_for_video(pil_image, max_size_mb=3.9)
            # image_to_video only exists in recent huggingface_hub releases.
            image_to_video_method = getattr(self.hf_client, "image_to_video", None)
            if not callable(image_to_video_method):
                return ("Error: Your huggingface_hub version does not support image_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")
            model_id = "Lightricks/LTX-Video-0.9.8-13B-distilled"
            print(f"[Image2Video] Calling API with model: {model_id}")
            video_bytes = image_to_video_method(
                input_bytes,
                prompt=prompt,
                model=model_id,
            )
            print(f"[Image2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")
            return self._publish_video(video_bytes, "image_to_video_result.mp4", token, "Image2Video")
        except Exception as e:
            print(f"[Image2Video] Error: {str(e)}")
            return f"Error generating video (image-to-video): {str(e)}"

    def generate_video_from_text(self, prompt: str, session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate a video from *prompt* using the Wan-AI text-to-video model.

        Args:
            prompt: Text description of the desired video.
            session_id: Unused here; kept for interface parity with callers.
            token: Optional Gradio OAuth token forwarded to the uploader.

        Returns:
            A ``<video>`` HTML snippet, or an "Error..." string.
        """
        try:
            print("[Text2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."
            # text_to_video only exists in recent huggingface_hub releases.
            text_to_video_method = getattr(self.hf_client, "text_to_video", None)
            if not callable(text_to_video_method):
                return ("Error: Your huggingface_hub version does not support text_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")
            model_id = "Wan-AI/Wan2.2-T2V-A14B"
            prompt_str = (prompt or "").strip()
            print(f"[Text2Video] Using model: {model_id}, prompt length: {len(prompt_str)}")
            video_bytes = text_to_video_method(
                prompt_str,
                model=model_id,
            )
            print(f"[Text2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")
            return self._publish_video(video_bytes, "text_to_video_result.mp4", token, "Text2Video")
        except Exception as e:
            print(f"[Text2Video] Error: {str(e)}")
            return f"Error generating video (text-to-video): {str(e)}"

    def generate_music_from_text(self, prompt: str, music_length_ms: int = 30000,
                                 session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate music from *prompt* using the ElevenLabs Music API.

        Args:
            prompt: Description of the desired music; an epic-orchestral
                default is substituted when empty.
            music_length_ms: Track length in milliseconds (default 30 s).
            session_id: Unused here; kept for interface parity with callers.
            token: Optional Gradio OAuth token forwarded to the uploader.

        Returns:
            An ``<audio>`` HTML snippet, or an "Error..." string.
        """
        try:
            api_key = os.getenv('ELEVENLABS_API_KEY')
            if not api_key:
                return "Error: ELEVENLABS_API_KEY environment variable is not set."
            print(f"[MusicGen] Generating music: {prompt}")
            headers = {
                'Content-Type': 'application/json',
                'xi-api-key': api_key,
            }
            payload = {
                'prompt': prompt or 'Epic orchestral theme with soaring strings and powerful brass',
                'music_length_ms': int(music_length_ms) if music_length_ms else 30000,
            }
            resp = requests.post(
                'https://api.elevenlabs.io/v1/music/compose',
                headers=headers,
                json=payload,
                timeout=60
            )
            if not resp.ok:
                # Surface the API's own error body; equivalent to (and much
                # simpler than) the old raise_for_status/except/getattr dance.
                return f"Error generating music: {resp.text}"
            temp_url = self._upload_media_to_hf(resp.content, "generated_music.mp3", "audio", token, use_temp=True)
            if temp_url.startswith("Error"):
                return temp_url
            audio_html = self._create_audio_html(temp_url)
            print(f"[MusicGen] Successfully generated music: {temp_url}")
            return audio_html
        except Exception as e:
            print(f"[MusicGen] Error: {str(e)}")
            return f"Error generating music: {str(e)}"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _process_input_image(self, input_image_data) -> Image.Image:
        """Coerce a file-like object, PIL image, numpy array, or raw bytes
        into an RGB PIL Image."""
        if hasattr(input_image_data, 'read'):
            raw = input_image_data.read()
            pil_image = Image.open(io.BytesIO(raw))
        elif hasattr(input_image_data, 'mode') and hasattr(input_image_data, 'size'):
            # Duck-typed PIL image check (works across PIL image subclasses).
            pil_image = input_image_data
        elif isinstance(input_image_data, np.ndarray):
            pil_image = Image.fromarray(input_image_data)
        elif isinstance(input_image_data, (bytes, bytearray)):
            pil_image = Image.open(io.BytesIO(input_image_data))
        else:
            # Last resort: anything convertible to bytes (e.g. memoryview).
            pil_image = Image.open(io.BytesIO(bytes(input_image_data)))
        # Ensure RGB so downstream JPEG encoding never fails on mode.
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        return pil_image

    def _compress_image_for_video(self, pil_image: Image.Image, max_size_mb: float = 3.9) -> bytes:
        """Encode *pil_image* as a JPEG no larger than *max_size_mb* MB.

        Downscales to at most 1024 px first, then progressively lowers JPEG
        quality (floor 40) and, failing that, shrinks dimensions by 15% per
        pass until the payload fits or the longest side reaches 640 px.
        """
        MAX_BYTES = int(max_size_mb * 1024 * 1024)
        max_dim = 1024
        quality = 90

        def encode_current(pil: Image.Image, q: int) -> bytes:
            tmp = io.BytesIO()
            pil.save(tmp, format='JPEG', quality=q, optimize=True)
            return tmp.getvalue()

        # A single proportional resize brings the longest side to <= max_dim
        # (the original used a `while`, which could only ever run once).
        if max(pil_image.size) > max_dim:
            ratio = max_dim / float(max(pil_image.size))
            new_size = (max(1, int(pil_image.size[0] * ratio)), max(1, int(pil_image.size[1] * ratio)))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
        encoded = encode_current(pil_image, quality)
        # Trade quality first, then dimensions, until within the byte budget.
        while len(encoded) > MAX_BYTES and (quality > 40 or max(pil_image.size) > 640):
            if quality > 40:
                quality -= 10
            else:
                new_w = max(1, int(pil_image.size[0] * 0.85))
                new_h = max(1, int(pil_image.size[1] * 0.85))
                pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
            encoded = encode_current(pil_image, quality)
        return encoded

    def _encode_image_jpeg(self, image: Image.Image, quality: int = 90) -> bytes:
        """Downscale *image* to at most 1024 px and encode as optimized JPEG.

        Shared by both image generators (the logic used to be duplicated).
        """
        max_size = 1024
        if image.width > max_size or image.height > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        buffer = io.BytesIO()
        image.convert('RGB').save(buffer, format='JPEG', quality=quality, optimize=True)
        return buffer.getvalue()

    def _publish_video(self, video_bytes, filename: str,
                       token: Optional[gr.OAuthToken], log_tag: str) -> str:
        """Upload *video_bytes* and wrap the URL in validated <video> HTML.

        Shared tail of both video generators (the logic used to be duplicated).
        """
        temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)
        if temp_url.startswith("Error"):
            return temp_url
        video_html = self._create_video_html(temp_url)
        if not validate_video_html(video_html):
            return "Error: Generated video HTML is malformed"
        print(f"[{log_tag}] Successfully generated video: {temp_url}")
        return video_html

    def _upload_media_to_hf(self, media_bytes: bytes, filename: str, media_type: str,
                            token: Optional[gr.OAuthToken] = None, use_temp: bool = True) -> str:
        """Persist *media_bytes* and return a URL (temp-file backed for now)."""
        if use_temp:
            return create_temp_media_url(media_bytes, filename, media_type)
        # Permanent HF upload is not implemented yet; fall back to temp files.
        return create_temp_media_url(media_bytes, filename, media_type)

    def _create_image_html(self, image_url: str, alt_text: str) -> str:
        """Build the <img> snippet shared by both image generators."""
        return (f'<img src="{image_url}" alt="{alt_text}" '
                f'style="max-width: 100%; height: auto; border-radius: 8px; margin: 10px 0;" '
                f'loading="lazy" />')

    def _create_video_html(self, video_url: str) -> str:
        """Build the <video> snippet (must satisfy validate_video_html)."""
        return f'''<video controls autoplay muted loop playsinline
    style="max-width: 100%; height: auto; border-radius: 8px; margin: 10px 0; display: block;"
    onloadstart="this.style.backgroundColor='#f0f0f0'"
    onerror="this.style.display='none'; console.error('Video failed to load')">
    <source src="{video_url}" type="video/mp4" />
    <p style="text-align: center; color: #666;">Your browser does not support the video tag.</p>
</video>'''

    def _create_audio_html(self, audio_url: str) -> str:
        """Build the styled <audio> player snippet."""
        return f'''<div class="anycoder-music" style="max-width:420px;margin:16px auto;padding:12px 16px;border:1px solid #e5e7eb;border-radius:12px;background:linear-gradient(180deg,#fafafa,#f3f4f6);box-shadow:0 2px 8px rgba(0,0,0,0.06)">
  <div style="font-size:13px;color:#374151;margin-bottom:8px;display:flex;align-items:center;gap:6px">
    <span>🎵 Generated music</span>
  </div>
  <audio controls autoplay loop style="width:100%;outline:none;">
    <source src="{audio_url}" type="audio/mpeg" />
    Your browser does not support the audio element.
  </audio>
</div>'''
# Global media generator instance
# NOTE: constructed at import time; only creates an InferenceClient when
# HF_TOKEN is configured, otherwise hf_client stays None and each method
# returns a configuration error string.
media_generator = MediaGenerator()
# Export main functions
def generate_image_with_qwen(prompt: str, image_index: int = 0, token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper around the shared MediaGenerator."""
    return media_generator.generate_image_with_qwen(
        prompt=prompt,
        image_index=image_index,
        token=token,
    )
def generate_image_to_image(input_image_data, prompt: str, token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper around the shared MediaGenerator."""
    return media_generator.generate_image_to_image(
        input_image_data,
        prompt=prompt,
        token=token,
    )
def generate_video_from_image(input_image_data, prompt: str, session_id: Optional[str] = None,
                              token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper around the shared MediaGenerator."""
    return media_generator.generate_video_from_image(
        input_image_data,
        prompt=prompt,
        session_id=session_id,
        token=token,
    )
def generate_video_from_text(prompt: str, session_id: Optional[str] = None,
                             token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper around the shared MediaGenerator."""
    return media_generator.generate_video_from_text(
        prompt=prompt,
        session_id=session_id,
        token=token,
    )
def generate_music_from_text(prompt: str, music_length_ms: int = 30000, session_id: Optional[str] = None,
                             token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper around the shared MediaGenerator."""
    return media_generator.generate_music_from_text(
        prompt=prompt,
        music_length_ms=music_length_ms,
        session_id=session_id,
        token=token,
    )