"""Gradio front-end for the Kling AI multi-image-to-image endpoint.

Uploads 2-4 subject images plus a prompt, creates a generation task, polls
it until it finishes, and downloads the resulting image for display.
"""
import base64
import logging
import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import jwt
import requests

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== API CONFIGURATION =====
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"

# Credentials are read from the environment so secrets are never committed
# to source. Both must be set before any API call is made.
ACCESS_KEY_ID = os.environ.get("KLING_ACCESS_KEY_ID", "")
ACCESS_KEY_SECRET = os.environ.get("KLING_ACCESS_KEY_SECRET", "")

REQUEST_TIMEOUT_S = 30  # per HTTP call; without it a stalled socket blocks forever
MAX_IMAGE_MB = 10  # client-side size limit, also shown in the UI
POLL_INTERVAL_S = 10  # seconds between status checks
MAX_POLLS = 60  # 60 polls x 10 s ~= 10 minutes of waiting


# ===== AUTHENTICATION =====
def generate_jwt_token():
    """Generate a short-lived HS256 JWT for API authentication.

    Returns:
        The encoded token as a ``str`` (valid for 30 minutes).

    Raises:
        RuntimeError: If the access key ID or secret is not configured.
    """
    if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
        raise RuntimeError(
            "Kling API credentials missing: set the KLING_ACCESS_KEY_ID and "
            "KLING_ACCESS_KEY_SECRET environment variables"
        )
    now = int(time.time())
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + 1800,  # 30 minutes expiration
        "nbf": now - 5,  # valid from 5 s ago, presumably to absorb clock skew
    }
    token = jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256")
    # PyJWT < 2.0 returns bytes; the Authorization header needs a str.
    return token.decode("utf-8") if isinstance(token, bytes) else token


# ===== IMAGE PROCESSING =====
def prepare_image_base64(image_path):
    """Read an image file and return its base64 encoding (no data-URI prefix).

    Returns:
        The base64 text, or ``None`` if the file could not be read (the
        error is logged).
    """
    try:
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")
    except OSError as e:
        logger.error("Image processing failed: %s", e)
        return None


def validate_image(image_path, max_size_mb=MAX_IMAGE_MB):
    """Check an image against the client-side size limit.

    Args:
        image_path: Path to the local image file.
        max_size_mb: Maximum allowed file size in megabytes.

    Returns:
        ``(True, "")`` if the image passes, else ``(False, reason)``.
    """
    try:
        size_mb = os.path.getsize(image_path) / (1024 * 1024)
    except OSError as e:
        return False, f"Image validation error: {str(e)}"
    if size_mb > max_size_mb:
        return False, f"Image too large (max {max_size_mb}MB)"
    # TODO: format (JPG/PNG) and minimum dimensions (300x300) advertised in
    # the UI are not checked here; that would need an image library.
    return True, ""


# ===== API FUNCTIONS =====
def create_multi_image_task(subject_images, prompt):
    """Submit a multi-image-to-image generation task.

    Args:
        subject_images: Iterable of local image paths; falsy entries are skipped.
        prompt: Text describing how the images should be combined.

    Returns:
        ``(response_json, None)`` on success or ``(None, error_message)``.
    """
    subject_image_list = []
    for img_path in subject_images:
        if not img_path:  # Skip empty/None images
            continue
        base64_img = prepare_image_base64(img_path)
        if base64_img:
            subject_image_list.append({"subject_image": base64_img})

    if len(subject_image_list) < 2:
        return None, "At least 2 subject images required"

    payload = {
        "model_name": "kling-v2",
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": 1,
        "aspect_ratio": "1:1",
    }

    try:
        headers = {
            "Authorization": f"Bearer {generate_jwt_token()}",
            "Content-Type": "application/json",
        }
    except RuntimeError as e:
        return None, str(e)

    try:
        response = requests.post(
            CREATE_TASK_ENDPOINT,
            json=payload,
            headers=headers,
            timeout=REQUEST_TIMEOUT_S,
        )
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        logger.error("API request failed: %s", e)
        # Response.__bool__ is False for 4xx/5xx, so test identity, not truthiness.
        if getattr(e, "response", None) is not None:
            logger.error("API response: %s", e.response.text)
        return None, f"API Error: {str(e)}"


def check_task_status(task_id):
    """Fetch the current state of a generation task.

    Returns:
        ``(response_json, None)`` on success or ``(None, error_message)``.
    """
    status_url = f"{CREATE_TASK_ENDPOINT}/{task_id}"
    try:
        headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    except RuntimeError as e:
        return None, str(e)

    try:
        response = requests.get(status_url, headers=headers, timeout=REQUEST_TIMEOUT_S)
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        if getattr(e, "response", None) is not None:
            logger.error("Status response: %s", e.response.text)
        return None, f"Status check failed: {str(e)}"


# ===== MAIN PROCESSING =====
def _download_result(image_url, task_id):
    """Download the generated image into the temp dir; return ``(path, error)``."""
    try:
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT_S)
        response.raise_for_status()
        output_path = Path(tempfile.gettempdir()) / f"kling_output_{task_id}.png"
        output_path.write_bytes(response.content)
    except (requests.exceptions.RequestException, OSError) as e:
        return None, f"Failed to download result: {str(e)}"
    return str(output_path), None


def generate_image(subject_images, prompt):
    """Run the full workflow: validate, create task, poll, download.

    Args:
        subject_images: Sequence of local image paths (falsy entries ignored).
        prompt: Text prompt for the generation.

    Returns:
        ``(output_path, None)`` on success or ``(None, error_message)``.
    """
    for img in subject_images:
        if img:  # Only validate non-empty images
            is_valid, error_msg = validate_image(img)
            if not is_valid:
                return None, error_msg

    task_response, error = create_multi_image_task(subject_images, prompt)
    if error:
        return None, error
    if task_response.get("code") != 0:
        return None, f"API error: {task_response.get('message', 'Unknown error')}"

    task_id = (task_response.get("data") or {}).get("task_id")
    if not task_id:
        return None, "API error: response did not contain a task_id"
    logger.info("Task created: %s", task_id)

    for _ in range(MAX_POLLS):
        task_data, error = check_task_status(task_id)
        if error:
            return None, error
        if task_data.get("code") != 0:
            return None, f"API error: {task_data.get('message', 'Unknown error')}"

        data = task_data.get("data") or {}
        status = data.get("task_status")

        if status == "succeed":
            images = (data.get("task_result") or {}).get("images") or []
            if not images or not images[0].get("url"):
                return None, "Task succeeded but returned no image URL"
            return _download_result(images[0]["url"], task_id)

        if status in ("failed", "canceled"):
            error_msg = data.get("task_status_msg", "Unknown error")
            return None, f"Task failed: {error_msg}"

        # Any other status (e.g. still running): wait before polling again.
        time.sleep(POLL_INTERVAL_S)

    timeout_min = MAX_POLLS * POLL_INTERVAL_S // 60
    return None, f"Task timed out after {timeout_min} minutes"


# ===== GRADIO INTERFACE =====
def process_interface(subject_image1, subject_image2, subject_image3, subject_image4, prompt):
    """Gradio click handler: returns ``(image, file, status_text)``."""
    # Filter out None/empty images
    subject_images = [
        img
        for img in (subject_image1, subject_image2, subject_image3, subject_image4)
        if img
    ]
    if len(subject_images) < 2:
        return None, None, "Please upload at least 2 subject images"

    output_path, error = generate_image(subject_images, prompt)
    if error:
        return None, None, error
    # The same path feeds both the image preview and the download widget.
    return output_path, output_path, "Generation successful!"


with gr.Blocks(title="Kling AI Multi-Image Generator") as app:
    gr.Markdown("## 🖼️ Kling AI Multi-Image to Image")
    gr.Markdown("Combine features from multiple images into one result")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input Settings")
            with gr.Row():
                subject_image1 = gr.Image(type="filepath", label="Subject Image 1 *")
                subject_image2 = gr.Image(type="filepath", label="Subject Image 2 *")
            with gr.Row():
                subject_image3 = gr.Image(type="filepath", label="Subject Image 3 (Optional)")
                subject_image4 = gr.Image(type="filepath", label="Subject Image 4 (Optional)")

            prompt_input = gr.Textbox(
                label="Transformation Prompt",
                placeholder="Describe how to combine these images",
            )
            generate_btn = gr.Button("Generate", variant="primary")

            gr.Markdown("### Requirements (* = required)")
            # Only the size limit is enforced client-side (see validate_image).
            gr.Markdown(
                "- **At least 2 subject images** (marked with *)\n"
                "- Max 4 images total\n"
                f"- Max size: {MAX_IMAGE_MB}MB per image\n"
                "- Formats: JPG, PNG\n"
                "- Min dimensions: 300x300px"
            )

        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(label="Generated Image", interactive=False, height=400)
            output_file = gr.File(label="Download Result")
            status_output = gr.Textbox(label="Status", interactive=False)

    # Each image slot is passed as its own input so process_interface can
    # drop the empty optional ones.
    generate_btn.click(
        fn=process_interface,
        inputs=[subject_image1, subject_image2, subject_image3, subject_image4, prompt_input],
        outputs=[output_image, output_file, status_output],
    )

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )