|
|
|
import os |
|
import io |
|
import base64 |
|
import requests |
|
import tempfile |
|
from typing import Optional, Dict, Any |
|
from PIL import Image |
|
import numpy as np |
|
import gradio as gr |
|
|
|
from huggingface_hub import InferenceClient |
|
from utils import create_temp_media_url, compress_media_for_data_uri, validate_video_html |
|
from config import HF_TOKEN |
|
|
|
class MediaGenerator:
    """Handles generation of images, videos, and music.

    Images and videos are produced through the Hugging Face Inference API
    (``InferenceClient``); music goes through the ElevenLabs compose API.
    Every public method returns an HTML snippet (``<img>``, ``<video>``,
    or an audio-player ``<div>``) on success, or a string beginning with
    ``"Error"`` on failure — callers branch on that prefix rather than
    catching exceptions.
    """

    def __init__(self):
        # The client is only constructed when a token is configured; every
        # public method guards on ``self.hf_client`` being None and returns
        # an error string instead of raising.
        self.hf_client = None
        if HF_TOKEN:
            self.hf_client = InferenceClient(
                provider="auto",
                api_key=HF_TOKEN,
                bill_to="huggingface",
            )

    def generate_image_with_qwen(self, prompt: str, image_index: int = 0,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate image using Qwen image model.

        Args:
            prompt: Text description of the image to generate.
            image_index: Only used to build a distinct output filename when
                several images are generated in one session.
            token: Optional OAuth token, forwarded to the upload helper.

        Returns:
            An ``<img>`` tag referencing the stored image, or an
            ``"Error..."`` string.
        """
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            print(f"[ImageGen] Generating image with prompt: {prompt}")

            image = self.hf_client.text_to_image(
                prompt,
                model="Qwen/Qwen-Image",
            )

            # Cap resolution at 1024px and re-encode as JPEG so the payload
            # handed to the upload helper stays small.
            image_bytes = self._encode_jpeg(image, quality=90)

            filename = f"generated_image_{image_index}.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)

            if temp_url.startswith("Error"):
                return temp_url

            return f'<img src="{temp_url}" alt="{prompt}" style="max-width: 100%; height: auto; border-radius: 8px; margin: 10px 0;" loading="lazy" />'

        except Exception as e:
            print(f"Image generation error: {str(e)}")
            return f"Error generating image: {str(e)}"

    def generate_image_to_image(self, input_image_data, prompt: str,
                                token: Optional[gr.OAuthToken] = None) -> str:
        """Generate image using image-to-image with Qwen-Image-Edit.

        Args:
            input_image_data: Source image in any form accepted by
                :meth:`_process_input_image` (file-like, PIL image, numpy
                array, or raw bytes).
            prompt: Edit instruction applied to the input image.
            token: Optional OAuth token, forwarded to the upload helper.

        Returns:
            An ``<img>`` tag referencing the result, or an ``"Error..."``
            string.
        """
        try:
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            print(f"[Image2Image] Processing with prompt: {prompt}")

            pil_image = self._process_input_image(input_image_data)

            # Shrink and re-encode the input before sending it to the API
            # (quality 85 — the input only needs to be "good enough").
            input_bytes = self._encode_jpeg(pil_image, quality=85)

            image = self.hf_client.image_to_image(
                input_bytes,
                prompt=prompt,
                model="Qwen/Qwen-Image-Edit",
            )

            # Same post-processing as the text-to-image path.
            image_bytes = self._encode_jpeg(image, quality=90)

            filename = "image_to_image_result.jpg"
            temp_url = self._upload_media_to_hf(image_bytes, filename, "image", token, use_temp=True)

            if temp_url.startswith("Error"):
                return temp_url

            return f'<img src="{temp_url}" alt="{prompt}" style="max-width: 100%; height: auto; border-radius: 8px; margin: 10px 0;" loading="lazy" />'

        except Exception as e:
            print(f"Image-to-image generation error: {str(e)}")
            return f"Error generating image (image-to-image): {str(e)}"

    def generate_video_from_image(self, input_image_data, prompt: str,
                                  session_id: Optional[str] = None,
                                  token: Optional[gr.OAuthToken] = None) -> str:
        """Generate video from input image using Lightricks LTX-Video.

        Args:
            input_image_data: Source frame in any form accepted by
                :meth:`_process_input_image`.
            prompt: Motion/content description for the video.
            session_id: Accepted for interface parity; not used here.
            token: Optional OAuth token, forwarded to the upload helper.

        Returns:
            A ``<video>`` HTML element, or an ``"Error..."`` string.
        """
        try:
            print("[Image2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            pil_image = self._process_input_image(input_image_data)
            print(f"[Image2Video] Input image size: {pil_image.size}")

            # The video endpoint rejects large payloads, so squeeze the
            # input frame under ~3.9 MB before uploading.
            input_bytes = self._compress_image_for_video(pil_image, max_size_mb=3.9)

            # image_to_video only exists in recent huggingface_hub
            # releases; fail with guidance rather than an AttributeError.
            image_to_video_method = getattr(self.hf_client, "image_to_video", None)
            if not callable(image_to_video_method):
                return ("Error: Your huggingface_hub version does not support image_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")

            model_id = "Lightricks/LTX-Video-0.9.8-13B-distilled"
            print(f"[Image2Video] Calling API with model: {model_id}")

            video_bytes = image_to_video_method(
                input_bytes,
                prompt=prompt,
                model=model_id,
            )

            print(f"[Image2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")

            filename = "image_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)

            if temp_url.startswith("Error"):
                return temp_url

            video_html = self._create_video_html(temp_url)

            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"

            print(f"[Image2Video] Successfully generated video: {temp_url}")
            return video_html

        except Exception as e:
            print(f"[Image2Video] Error: {str(e)}")
            return f"Error generating video (image-to-video): {str(e)}"

    def generate_video_from_text(self, prompt: str, session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate video from text prompt using Wan-AI text-to-video model.

        Args:
            prompt: Scene description; whitespace is stripped and ``None``
                is treated as an empty prompt.
            session_id: Accepted for interface parity; not used here.
            token: Optional OAuth token, forwarded to the upload helper.

        Returns:
            A ``<video>`` HTML element, or an ``"Error..."`` string.
        """
        try:
            print("[Text2Video] Starting video generation")
            if not self.hf_client:
                return "Error: HF_TOKEN environment variable is not set."

            # text_to_video only exists in recent huggingface_hub
            # releases; fail with guidance rather than an AttributeError.
            text_to_video_method = getattr(self.hf_client, "text_to_video", None)
            if not callable(text_to_video_method):
                return ("Error: Your huggingface_hub version does not support text_to_video. "
                        "Please upgrade with `pip install -U huggingface_hub`")

            model_id = "Wan-AI/Wan2.2-T2V-A14B"
            prompt_str = (prompt or "").strip()
            print(f"[Text2Video] Using model: {model_id}, prompt length: {len(prompt_str)}")

            video_bytes = text_to_video_method(
                prompt_str,
                model=model_id,
            )

            print(f"[Text2Video] Received video bytes: {len(video_bytes) if hasattr(video_bytes, '__len__') else 'unknown'}")

            filename = "text_to_video_result.mp4"
            temp_url = self._upload_media_to_hf(video_bytes, filename, "video", token, use_temp=True)

            if temp_url.startswith("Error"):
                return temp_url

            video_html = self._create_video_html(temp_url)

            if not validate_video_html(video_html):
                return "Error: Generated video HTML is malformed"

            print(f"[Text2Video] Successfully generated video: {temp_url}")
            return video_html

        except Exception as e:
            print(f"[Text2Video] Error: {str(e)}")
            return f"Error generating video (text-to-video): {str(e)}"

    def generate_music_from_text(self, prompt: str, music_length_ms: int = 30000,
                                 session_id: Optional[str] = None,
                                 token: Optional[gr.OAuthToken] = None) -> str:
        """Generate music using ElevenLabs Music API.

        Args:
            prompt: Style/mood description; a default epic-orchestral
                prompt is substituted when empty.
            music_length_ms: Requested track length in milliseconds
                (defaults to 30s; falsy values also fall back to 30s).
            session_id: Accepted for interface parity; not used here.
            token: Optional OAuth token, forwarded to the upload helper.

        Returns:
            An audio-player HTML snippet, or an ``"Error..."`` string.
        """
        try:
            api_key = os.getenv('ELEVENLABS_API_KEY')
            if not api_key:
                return "Error: ELEVENLABS_API_KEY environment variable is not set."

            print(f"[MusicGen] Generating music: {prompt}")

            headers = {
                'Content-Type': 'application/json',
                'xi-api-key': api_key,
            }

            payload = {
                'prompt': prompt or 'Epic orchestral theme with soaring strings and powerful brass',
                'music_length_ms': int(music_length_ms) if music_length_ms else 30000,
            }

            resp = requests.post(
                'https://api.elevenlabs.io/v1/music/compose',
                headers=headers,
                json=payload,
                timeout=60
            )

            try:
                resp.raise_for_status()
            except requests.HTTPError:
                # Surface the API's error body — far more actionable than
                # the bare status line. (raise_for_status raises HTTPError
                # on this very response, so resp.text is the right body;
                # no getattr/hasattr indirection needed.)
                return f"Error generating music: {resp.text}"

            filename = "generated_music.mp3"
            temp_url = self._upload_media_to_hf(resp.content, filename, "audio", token, use_temp=True)

            if temp_url.startswith("Error"):
                return temp_url

            audio_html = self._create_audio_html(temp_url)
            print(f"[MusicGen] Successfully generated music: {temp_url}")
            return audio_html

        except Exception as e:
            print(f"[MusicGen] Error: {str(e)}")
            return f"Error generating music: {str(e)}"

    def _encode_jpeg(self, image: Image.Image, quality: int = 90,
                     max_size: int = 1024) -> bytes:
        """Downscale *image* (in place) to at most ``max_size`` px on its
        longest edge and return RGB JPEG bytes.

        Shared by the text-to-image and image-to-image paths, which
        previously duplicated this resize/encode sequence verbatim.
        """
        if image.width > max_size or image.height > max_size:
            # thumbnail() preserves aspect ratio and never upscales.
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        buffer = io.BytesIO()
        image.convert('RGB').save(buffer, format='JPEG', quality=quality, optimize=True)
        return buffer.getvalue()

    def _process_input_image(self, input_image_data) -> Image.Image:
        """Convert various image formats to PIL Image.

        Accepts a file-like object, a PIL image (duck-typed via
        ``mode``/``size``), a numpy array, raw bytes, or anything
        convertible to bytes. Always returns an RGB image.
        """
        if hasattr(input_image_data, 'read'):
            raw = input_image_data.read()
            pil_image = Image.open(io.BytesIO(raw))
        elif hasattr(input_image_data, 'mode') and hasattr(input_image_data, 'size'):
            # Already a PIL image (duck-typed to avoid isinstance on
            # PIL's many Image subclasses).
            pil_image = input_image_data
        elif isinstance(input_image_data, np.ndarray):
            pil_image = Image.fromarray(input_image_data)
        elif isinstance(input_image_data, (bytes, bytearray)):
            pil_image = Image.open(io.BytesIO(input_image_data))
        else:
            # Last resort: let bytes() raise if the type is unsupported.
            pil_image = Image.open(io.BytesIO(bytes(input_image_data)))

        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        return pil_image

    def _compress_image_for_video(self, pil_image: Image.Image, max_size_mb: float = 3.9) -> bytes:
        """Compress image for video generation API limits.

        Iteratively lowers JPEG quality (down to 40) and then shrinks the
        image (never below a 640px longest edge bound in the loop guard)
        until the encoded size fits under ``max_size_mb``.
        """
        MAX_BYTES = int(max_size_mb * 1024 * 1024)
        max_dim = 1024
        quality = 90

        def encode_current(pil: Image.Image, q: int) -> bytes:
            tmp = io.BytesIO()
            pil.save(tmp, format='JPEG', quality=q, optimize=True)
            return tmp.getvalue()

        # First pass: bring the longest edge down to max_dim.
        while max(pil_image.size) > max_dim:
            ratio = max_dim / float(max(pil_image.size))
            new_size = (max(1, int(pil_image.size[0] * ratio)), max(1, int(pil_image.size[1] * ratio)))
            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)

        encoded = encode_current(pil_image, quality)

        # Second pass: trade quality first, then dimensions, until the
        # payload fits or both knobs are exhausted.
        while len(encoded) > MAX_BYTES and (quality > 40 or max(pil_image.size) > 640):
            if quality > 40:
                quality -= 10
            else:
                new_w = max(1, int(pil_image.size[0] * 0.85))
                new_h = max(1, int(pil_image.size[1] * 0.85))
                pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
            encoded = encode_current(pil_image, quality)

        return encoded

    def _upload_media_to_hf(self, media_bytes: bytes, filename: str, media_type: str,
                            token: Optional[gr.OAuthToken] = None, use_temp: bool = True) -> str:
        """Persist media bytes and return a URL for embedding.

        NOTE: currently every path writes a temp file via
        ``create_temp_media_url`` regardless of ``use_temp`` — the original
        implementation had two branches returning the identical call. The
        ``token``/``use_temp`` parameters are kept for interface stability
        and for a future permanent-upload path.
        """
        return create_temp_media_url(media_bytes, filename, media_type)

    def _create_video_html(self, video_url: str) -> str:
        """Create HTML video element (autoplaying, muted, looped)."""
        return f'''<video controls autoplay muted loop playsinline
style="max-width: 100%; height: auto; border-radius: 8px; margin: 10px 0; display: block;"
onloadstart="this.style.backgroundColor='#f0f0f0'"
onerror="this.style.display='none'; console.error('Video failed to load')">
<source src="{video_url}" type="video/mp4" />
<p style="text-align: center; color: #666;">Your browser does not support the video tag.</p>
</video>'''

    def _create_audio_html(self, audio_url: str) -> str:
        """Create HTML audio player (styled card with autoplaying loop)."""
        return f'''<div class="anycoder-music" style="max-width:420px;margin:16px auto;padding:12px 16px;border:1px solid #e5e7eb;border-radius:12px;background:linear-gradient(180deg,#fafafa,#f3f4f6);box-shadow:0 2px 8px rgba(0,0,0,0.06)">
<div style="font-size:13px;color:#374151;margin-bottom:8px;display:flex;align-items:center;gap:6px">
<span>🎵 Generated music</span>
</div>
<audio controls autoplay loop style="width:100%;outline:none;">
<source src="{audio_url}" type="audio/mpeg" />
Your browser does not support the audio element.
</audio>
</div>'''
|
|
|
|
|
# Shared singleton used by the module-level convenience functions below.
media_generator = MediaGenerator()
|
|
|
|
|
def generate_image_with_qwen(prompt: str, image_index: int = 0, token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper: delegate to the shared MediaGenerator."""
    return media_generator.generate_image_with_qwen(prompt, image_index=image_index, token=token)
|
|
|
def generate_image_to_image(input_image_data, prompt: str, token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper: delegate to the shared MediaGenerator."""
    return media_generator.generate_image_to_image(input_image_data, prompt, token=token)
|
|
|
def generate_video_from_image(input_image_data, prompt: str, session_id: Optional[str] = None,
                              token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper: delegate to the shared MediaGenerator."""
    return media_generator.generate_video_from_image(
        input_image_data, prompt, session_id=session_id, token=token
    )
|
|
|
def generate_video_from_text(prompt: str, session_id: Optional[str] = None,
                             token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper: delegate to the shared MediaGenerator."""
    return media_generator.generate_video_from_text(prompt, session_id=session_id, token=token)
|
|
|
def generate_music_from_text(prompt: str, music_length_ms: int = 30000, session_id: Optional[str] = None,
                             token: Optional[gr.OAuthToken] = None) -> str:
    """Module-level convenience wrapper: delegate to the shared MediaGenerator."""
    return media_generator.generate_music_from_text(
        prompt, music_length_ms=music_length_ms, session_id=session_id, token=token
    )