# Spaces: Running
import base64
import logging
import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import jwt
import requests
# Module logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# ===== API CONFIGURATION =====
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"

# Kling access-key credentials consumed by generate_jwt_token(). They were
# referenced but never defined in this file (a guaranteed NameError); read them
# from the environment so keys never live in source. Unset -> None.
# NOTE(review): the environment variable names are an assumption - align them
# with however the deployment (e.g. Space secrets) actually provides the keys.
ACCESS_KEY_ID = os.getenv("ACCESS_KEY_ID")
ACCESS_KEY_SECRET = os.getenv("ACCESS_KEY_SECRET")
# ===== AUTHENTICATION =====
def generate_jwt_token():
    """Build a short-lived HS256 JWT for the Kling API ``Authorization`` header.

    Uses the module-level ``ACCESS_KEY_ID`` as the ``iss`` claim and
    ``ACCESS_KEY_SECRET`` as the HMAC signing key.

    Returns:
        The encoded token as ``str``. It is valid from 5 seconds ago (to tolerate
        small clock skew) until 30 minutes from now.

    Raises:
        RuntimeError: If either credential is unset or empty.
    """
    # Fail loudly with an actionable message instead of signing with a
    # missing key (or crashing later with an opaque error).
    if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
        raise RuntimeError(
            "Kling API credentials are not configured "
            "(ACCESS_KEY_ID / ACCESS_KEY_SECRET)"
        )
    now = int(time.time())  # read the clock once so exp/nbf are consistent
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + 1800,  # 30 minutes expiration
        "nbf": now - 5,  # not before 5 seconds ago
    }
    token = jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256")
    # PyJWT < 2.0 returns bytes; without decoding, the f-string in the request
    # header would produce "Bearer b'...'".
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
# ===== IMAGE PROCESSING =====
def prepare_image_base64(image_path):
    """Read an image file and return its bytes as a bare base64 string.

    No ``data:image/...;base64,`` prefix is added.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64 text of the file contents, or ``None`` when the file cannot
        be opened or read (the error is logged, not raised).
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
    except Exception as exc:
        logger.error(f"Image processing failed: {str(exc)}")
        return None
    # base64 output is pure ASCII, so this decode cannot fail.
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def validate_image(image_path):
    """Check an uploaded image against the API's file-size limit.

    Only the 10 MB size cap is enforced. Format and minimum-dimension rules are
    NOT checked here, because that would require decoding the image
    (e.g. with PIL).

    Args:
        image_path: Path of the image file on disk.

    Returns:
        ``(True, "")`` when the file passes, otherwise ``(False, reason)``.
        Any error while inspecting the file is reported as a failed validation
        rather than raised.
    """
    max_bytes = 10 * 1024 * 1024  # 10 MB, compared in exact bytes
    try:
        too_large = os.path.getsize(image_path) > max_bytes
    except Exception as exc:
        return False, f"Image validation error: {str(exc)}"
    if too_large:
        return False, "Image too large (max 10MB)"
    return True, ""
# ===== API FUNCTIONS =====
def create_multi_image_task(subject_images, prompt, *, timeout=(10, 120)):
    """Submit a multi-image-to-image generation task to the Kling API.

    Args:
        subject_images: Iterable of image file paths. Falsy entries (``None``,
            ``""``) and files that cannot be read are skipped. ``None`` is
            treated as an empty iterable.
        prompt: Text prompt describing how to combine the images.
        timeout: ``requests`` timeout, either seconds or a ``(connect, read)``
            tuple. Without one a stalled connection would block forever.

    Returns:
        ``(response_json, None)`` on an HTTP success, or ``(None, error_message)``
        if fewer than two images could be encoded or the request failed.
        A 2xx response may still carry a non-zero API ``code``, so callers must
        check it.
    """
    # Encode the images first, so no token is minted for a request that will
    # never be sent.
    subject_image_list = []
    for img_path in subject_images or ():
        if not img_path:  # skip empty/None slots
            continue
        base64_img = prepare_image_base64(img_path)
        if base64_img:
            subject_image_list.append({"subject_image": base64_img})
    if len(subject_image_list) < 2:
        return None, "At least 2 subject images required"

    headers = {
        "Authorization": f"Bearer {generate_jwt_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "model_name": "kling-v2",
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": 1,
        "aspect_ratio": "1:1",
    }
    try:
        response = requests.post(
            CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: older `requests` raise it for a non-JSON body
        # (newer versions raise a RequestException subclass instead).
        logger.error(f"API request failed: {str(e)}")
        # Compare against None: a Response is falsy for 4xx/5xx
        # (bool(resp) == resp.ok), so a truthiness test would skip exactly the
        # error bodies worth logging.
        error_response = getattr(e, "response", None)
        if error_response is not None:
            logger.error(f"API response: {error_response.text}")
        return None, f"API Error: {str(e)}"
def check_task_status(task_id, *, timeout=(10, 30)):
    """Fetch the current state of a multi-image-to-image task.

    Args:
        task_id: Task identifier returned by the create-task call.
        timeout: ``requests`` timeout, either seconds or a ``(connect, read)``
            tuple.

    Returns:
        ``(response_json, None)`` on an HTTP success, or
        ``(None, "Status check failed: ...")`` when the request fails or the
        body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    # Derive the URL from the create endpoint so the two cannot drift apart.
    status_url = f"{CREATE_TASK_ENDPOINT}/{task_id}"
    try:
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers non-JSON bodies on older `requests` versions.
        logger.error(f"Status check for task {task_id} failed: {str(e)}")
        return None, f"Status check failed: {str(e)}"
# ===== MAIN PROCESSING =====
def _download_result(image_url, task_id, *, timeout=(10, 60)):
    """Download the generated image into the system temp directory.

    Returns:
        ``(path_str, None)`` on success, or ``(None, error_message)`` on failure.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # tempfile.gettempdir() instead of a hard-coded "/tmp" so the app also
        # runs on platforms without /tmp.
        output_path = Path(tempfile.gettempdir()) / f"kling_output_{task_id}.png"
        output_path.write_bytes(response.content)
    except Exception as e:
        # Broad on purpose: any failure here becomes a UI status message.
        logger.exception("Result download failed")
        return None, f"Failed to download result: {str(e)}"
    return str(output_path), None


def generate_image(subject_images, prompt, *, poll_interval=10, max_polls=60):
    """Run the whole workflow: validate, create a task, poll, download.

    Args:
        subject_images: Sequence of image file paths. Falsy entries are ignored.
        prompt: Text prompt passed through to the API.
        poll_interval: Seconds to sleep between status checks.
        max_polls: Maximum number of status checks before giving up. The
            defaults give a budget of about 10 minutes.

    Returns:
        ``(output_path, None)`` on success, or ``(None, error_message)``.
    """
    images = [img for img in (subject_images or []) if img]

    # Validate every image before spending an API call.
    for img in images:
        is_valid, error_msg = validate_image(img)
        if not is_valid:
            return None, error_msg

    # Create the task.
    task_response, error = create_multi_image_task(images, prompt)
    if error:
        return None, error
    if task_response.get("code") != 0:
        return None, f"API error: {task_response.get('message', 'Unknown error')}"
    task_id = (task_response.get("data") or {}).get("task_id")
    if not task_id:
        return None, "API error: response did not contain a task_id"
    logger.info(f"Task created: {task_id}")

    # Poll for the result.
    for _ in range(max_polls):
        task_data, error = check_task_status(task_id)
        if error:
            return None, error
        # Check the envelope code before reading `data`, the same way the
        # create-task response is checked above.
        if task_data.get("code") != 0:
            return None, f"API error: {task_data.get('message', 'Unknown error')}"
        data = task_data.get("data") or {}
        status = data.get("task_status")
        if status == "succeed":
            result_images = (data.get("task_result") or {}).get("images") or []
            image_url = result_images[0].get("url") if result_images else None
            if not image_url:
                return None, "Task succeeded but no image URL was returned"
            return _download_result(image_url, task_id)
        if status in ("failed", "canceled"):
            error_msg = data.get("task_status_msg", "Unknown error")
            return None, f"Task failed: {error_msg}"
        time.sleep(poll_interval)
    # With the defaults this reads "Task timed out after 10 minutes".
    return None, f"Task timed out after {max_polls * poll_interval / 60:g} minutes"
# ===== GRADIO INTERFACE =====
def process_interface(subject_image1, subject_image2, subject_image3, subject_image4, prompt):
    """Gradio click handler: gather the uploads, generate, and fan out the result.

    Empty image slots are dropped. At least two uploads are required before
    anything is sent to the API.

    Returns:
        A 3-tuple ``(preview_path, download_path, status_message)``. Both paths
        are the same file on success and ``None`` on failure.
    """
    uploaded = list(filter(None, (subject_image1, subject_image2, subject_image3, subject_image4)))
    if len(uploaded) >= 2:
        result_path, failure = generate_image(uploaded, prompt)
        if not failure:
            return result_path, result_path, "Generation successful!"
        return None, None, failure
    return None, None, "Please upload at least 2 subject images"
# Build the Gradio UI. The left column holds the inputs (four image slots and
# a prompt); the right column shows the generated image, a downloadable file
# and a status line.
# NOTE(review): the source paste lost all indentation; the nesting below is
# reconstructed from the component order. Confirm the intended layout.
with gr.Blocks(title="Kling AI Multi-Image Generator") as app:
    gr.Markdown("## 🖼️ Kling AI Multi-Image to Image")
    gr.Markdown("Combine features from multiple images into one result")
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            gr.Markdown("### Input Settings")
            # type="filepath": each upload reaches the handler as a path,
            # which validate_image / prepare_image_base64 read from disk.
            with gr.Row():
                subject_image1 = gr.Image(type="filepath", label="Subject Image 1 *")
                subject_image2 = gr.Image(type="filepath", label="Subject Image 2 *")
            with gr.Row():
                subject_image3 = gr.Image(type="filepath", label="Subject Image 3 (Optional)")
                subject_image4 = gr.Image(type="filepath", label="Subject Image 4 (Optional)")
            prompt_input = gr.Textbox(
                label="Transformation Prompt",
                placeholder="Describe how to combine these images"
            )
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### Requirements (* = required)")
            # Informational text. Code enforces the 2-image minimum
            # (process_interface) and the 10MB cap (validate_image). The
            # 4-image maximum follows from there being four slots. Format and
            # minimum-dimension rules are not checked in code.
            gr.Markdown("""
- **At least 2 subject images** (marked with *)
- Max 4 images total
- Max size: 10MB per image
- Formats: JPG, PNG
- Min dimensions: 300x300px
""")
        # Right column: results.
        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(label="Generated Image", interactive=False, height=400)
            output_file = gr.File(label="Download Result")
            status_output = gr.Textbox(label="Status", interactive=False)
    # Wire the button. The four image slots and the prompt map positionally
    # onto process_interface's parameters. Its 3-tuple result maps onto
    # (preview image, download file, status text); the same saved path feeds
    # both of the first two.
    generate_btn.click(
        fn=process_interface,
        inputs=[subject_image1, subject_image2, subject_image3, subject_image4, prompt_input],
        outputs=[output_image, output_file, status_output]
    )
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860, with no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app.launch(**launch_options)