# editableweb / app.py
# Last commit: "Update app.py" (33ba29f, verified), author AkashKumarave
import gradio as gr
import requests
import base64
import os
import time
import jwt
import logging
from pathlib import Path
# Logging: emit INFO-and-above records through the root logger configuration,
# and expose a module-level logger for the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== API CONFIGURATION =====
# Host of the Kling API and the route used to create multi-image tasks.
# The task-status route is built from API_BASE_URL plus a task id.
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = API_BASE_URL + "/v1/images/multi-image2image"
# ===== AUTHENTICATION =====
def generate_jwt_token():
    """Generate a short-lived HS256 JWT for API authentication.

    The credentials are looked up under the names ``ACCESS_KEY_ID`` and
    ``ACCESS_KEY_SECRET``: first as module-level variables (if the module
    defines them), then as environment variables of the same name.

    Returns:
        str: The encoded token, ready for an ``Authorization: Bearer`` header.

    Raises:
        RuntimeError: If either credential is missing or empty.
    """
    module_vars = globals()
    access_key_id = (
        module_vars.get("ACCESS_KEY_ID") or os.environ.get("ACCESS_KEY_ID")
    )
    access_key_secret = (
        module_vars.get("ACCESS_KEY_SECRET")
        or os.environ.get("ACCESS_KEY_SECRET")
    )
    if not access_key_id or not access_key_secret:
        # Fail with an actionable message instead of a NameError on an
        # undefined global.
        raise RuntimeError(
            "Kling API credentials are not configured: set ACCESS_KEY_ID and "
            "ACCESS_KEY_SECRET (for example as environment variables)"
        )

    # Read the clock once so "exp" and "nbf" are relative to the same instant.
    now = int(time.time())
    payload = {
        "iss": access_key_id,
        "exp": now + 1800,  # 30 minutes expiration
        "nbf": now - 5,  # Not before 5 seconds ago
    }
    token = jwt.encode(payload, access_key_secret, algorithm="HS256")
    # PyJWT < 2.0 returns bytes; formatting those into a header would yield
    # "Bearer b'...'", so normalise to str.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
# ===== IMAGE PROCESSING =====
def prepare_image_base64(image_path):
    """Return the file at *image_path* as a bare base64 string.

    The result carries no ``data:`` URI prefix. Any failure while reading or
    encoding is logged and reported to the caller as ``None``.
    """
    try:
        with open(image_path, "rb") as source:
            raw_bytes = source.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as exc:
        logger.error(f"Image processing failed: {str(exc)}")
        return None
def validate_image(image_path):
    """Check that an image file is acceptable for upload.

    Only the file size is inspected (limit: 10 MB); pixel dimensions are not
    checked here.

    Returns:
        tuple: ``(True, "")`` when the file passes, otherwise ``(False, msg)``
        where ``msg`` describes the problem.
    """
    max_bytes = 10 * 1024 * 1024
    try:
        file_bytes = os.path.getsize(image_path)
    except Exception as exc:
        return False, f"Image validation error: {str(exc)}"

    if file_bytes > max_bytes:
        return False, "Image too large (max 10MB)"
    # Dimensions are not verified (that would need an imaging library).
    return True, ""
# ===== API FUNCTIONS =====
def create_multi_image_task(subject_images, prompt, timeout=60):
    """Create a multi-image generation task.

    Args:
        subject_images: Iterable of image file paths; falsy entries and
            images that fail to encode are skipped.
        prompt: Text prompt sent along with the images.
        timeout: Seconds passed to ``requests.post`` as its timeout.

    Returns:
        tuple: ``(response_json, None)`` on success, or ``(None, message)``
        when fewer than two usable images remain or the request fails.
    """
    # Encode the inputs first, so an unusable set is rejected before a token
    # is generated or any network traffic happens.
    subject_image_list = []
    for img_path in subject_images:
        if not img_path:  # Skip empty/None images
            continue
        base64_img = prepare_image_base64(img_path)
        if base64_img:
            subject_image_list.append({"subject_image": base64_img})
    if len(subject_image_list) < 2:
        return None, "At least 2 subject images required"

    headers = {
        "Authorization": f"Bearer {generate_jwt_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "model_name": "kling-v2",
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": 1,
        "aspect_ratio": "1:1",
    }
    try:
        # requests applies no timeout by default; without one a stalled
        # connection would block the caller indefinitely.
        response = requests.post(
            CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json(), None
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        error_response = getattr(e, "response", None)
        # A Response is falsy for 4xx/5xx statuses, so compare against None
        # explicitly; a truthiness test would never log the error body.
        if error_response is not None:
            logger.error(f"API response: {error_response.text}")
        return None, f"API Error: {str(e)}"
    except ValueError as e:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        logger.error(f"API returned a non-JSON response: {str(e)}")
        return None, f"API Error: {str(e)}"
def check_task_status(task_id, timeout=30):
    """Fetch the current state of a previously created task.

    Args:
        task_id: Identifier appended to the multi-image2image route.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Returns:
        tuple: ``(response_json, None)`` on success, or ``(None, message)``
        when the request fails or the body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    status_url = f"{API_BASE_URL}/v1/images/multi-image2image/{task_id}"
    try:
        # requests applies no timeout by default; bound the call so a stalled
        # connection cannot hang the polling loop.
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json(), None
    except requests.exceptions.RequestException as e:
        return None, f"Status check failed: {str(e)}"
    except ValueError as e:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        return None, f"Status check failed: {str(e)}"
# ===== MAIN PROCESSING =====
def generate_image(subject_images, prompt):
    """Run the complete generation workflow.

    Validates every non-empty input image, creates the task, polls its status
    (up to 60 polls, 10 seconds apart) and downloads the first result image
    to ``/tmp/kling_output_<task_id>.png``.

    Args:
        subject_images: Iterable of image file paths (falsy entries ignored).
        prompt: Text prompt describing the desired result.

    Returns:
        tuple: ``(output_path, None)`` on success, or ``(None, message)`` on
        any validation, API, download or timeout failure.
    """
    max_polls = 60
    poll_interval_seconds = 10
    download_timeout_seconds = 60

    # Validate images
    for img in subject_images:
        if img:  # Only validate non-empty images
            is_valid, error_msg = validate_image(img)
            if not is_valid:
                return None, error_msg

    # Create task
    task_response, error = create_multi_image_task(subject_images, prompt)
    if error:
        return None, error
    if task_response.get("code") != 0:
        return None, f"API error: {task_response.get('message', 'Unknown error')}"
    # Guard the lookup: a response without "data"/"task_id" should produce an
    # error message, not a KeyError.
    task_id = (task_response.get("data") or {}).get("task_id")
    if not task_id:
        return None, "API error: response did not include a task id"
    logger.info(f"Task created: {task_id}")

    # Poll for results (max 10 minutes)
    for attempt in range(max_polls):
        task_data, error = check_task_status(task_id)
        if error:
            return None, error

        data = task_data.get("data") or {}
        status = data.get("task_status")
        if status is None:
            # Without a status there is nothing to wait for; report it now
            # instead of crashing or polling blindly.
            return None, "Status check failed: response did not include a task status"

        if status == "succeed":
            try:
                image_url = data["task_result"]["images"][0]["url"]
            except (KeyError, IndexError, TypeError):
                return None, "Task succeeded but no image URL was returned"
            output_path = Path(f"/tmp/kling_output_{task_id}.png")
            try:
                # Bounded download: requests applies no timeout by default.
                response = requests.get(image_url, timeout=download_timeout_seconds)
                response.raise_for_status()
                with open(output_path, "wb") as f:
                    f.write(response.content)
            except (requests.exceptions.RequestException, OSError) as e:
                return None, f"Failed to download result: {str(e)}"
            return str(output_path), None

        if status in ("failed", "canceled"):
            error_msg = data.get("task_status_msg", "Unknown error")
            return None, f"Task failed: {error_msg}"

        # Do not sleep after the final poll; the timeout is reported at once.
        if attempt < max_polls - 1:
            time.sleep(poll_interval_seconds)

    return None, "Task timed out after 10 minutes"
# ===== GRADIO INTERFACE =====
def process_interface(subject_image1, subject_image2, subject_image3, subject_image4, prompt):
    """Gradio click handler: collect the uploads, run generation, report status.

    Returns a 3-tuple ``(image, file, status_text)``. On success the first two
    entries are the same output path; on failure both are ``None`` and the
    third entry carries the error message.
    """
    uploaded = (subject_image1, subject_image2, subject_image3, subject_image4)
    # Drop slots the user left empty.
    subject_images = list(filter(None, uploaded))

    if len(subject_images) < 2:
        return None, None, "Please upload at least 2 subject images"

    result_path, failure = generate_image(subject_images, prompt)
    if not failure:
        return result_path, result_path, "Generation successful!"
    return None, None, failure
# Build the Gradio UI: an input column (four image slots, prompt, button and a
# requirements note) beside an output column (image preview, downloadable file
# and status text).
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost — confirm the layout against the running app.
with gr.Blocks(title="Kling AI Multi-Image Generator") as app:
    gr.Markdown("## 🖼️ Kling AI Multi-Image to Image")
    gr.Markdown("Combine features from multiple images into one result")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            gr.Markdown("### Input Settings")
            # type="filepath" so process_interface receives values it can hand
            # to the path-based helpers (validate_image, prepare_image_base64).
            with gr.Row():
                subject_image1 = gr.Image(type="filepath", label="Subject Image 1 *")
                subject_image2 = gr.Image(type="filepath", label="Subject Image 2 *")
            with gr.Row():
                subject_image3 = gr.Image(type="filepath", label="Subject Image 3 (Optional)")
                subject_image4 = gr.Image(type="filepath", label="Subject Image 4 (Optional)")
            prompt_input = gr.Textbox(
                label="Transformation Prompt",
                placeholder="Describe how to combine these images"
            )
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### Requirements (* = required)")
            # NOTE(review): of the requirements listed below, validate_image
            # only enforces the 10MB size limit; format and minimum dimensions
            # are not checked by this file.
            gr.Markdown("""
- **At least 2 subject images** (marked with *)
- Max 4 images total
- Max size: 10MB per image
- Formats: JPG, PNG
- Min dimensions: 300x300px
""")
        # Right column: outputs. process_interface returns
        # (image, file, status) in this same order.
        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(label="Generated Image", interactive=False, height=400)
            output_file = gr.File(label="Download Result")
            status_output = gr.Textbox(label="Status", interactive=False)
    # Modified inputs to accept individual components
    generate_btn.click(
        fn=process_interface,
        inputs=[subject_image1, subject_image2, subject_image3, subject_image4, prompt_input],
        outputs=[output_image, output_file, status_output]
    )
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False
    )