File size: 7,627 Bytes
85a7fa6
a594839
1a8f2c7
dc04565
85a7fa6
ae39b5f
3a53c8d
ae39b5f
1a8f2c7
3a53c8d
 
 
 
 
33ba29f
dc04565
c245dea
1a8f2c7
3a53c8d
1a8f2c7
3a53c8d
3a257f2
90ebabe
3a53c8d
 
3a257f2
90ebabe
 
3a53c8d
 
 
 
 
 
 
 
 
 
c245dea
 
dc04565
c245dea
 
 
3a53c8d
c245dea
 
3a53c8d
 
c245dea
3a53c8d
 
c245dea
 
3a53c8d
 
 
 
 
c245dea
 
 
 
 
 
 
 
 
 
 
3a53c8d
c245dea
3a53c8d
c245dea
 
 
3a53c8d
 
 
 
 
c245dea
 
 
 
 
 
3a53c8d
 
c245dea
3a53c8d
c245dea
 
3a53c8d
c245dea
3a53c8d
c245dea
 
 
3a53c8d
c245dea
 
 
 
 
 
 
 
 
3a53c8d
c245dea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc04565
c245dea
 
 
 
 
 
c3f22c6
c245dea
 
 
 
ae39b5f
c245dea
 
 
 
 
 
 
85a7fa6
3a53c8d
c245dea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae39b5f
85a7fa6
 
c245dea
 
 
 
 
 
 
 
3a53c8d
c245dea
 
3a53c8d
c0c3ada
c245dea
 
 
ae39b5f
c245dea
 
 
 
 
ae39b5f
 
85a7fa6
c245dea
 
013dbb5
 
ae39b5f
c245dea
ae39b5f
c245dea
 
dc04565
85a7fa6
1a8f2c7
045423f
c245dea
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import gradio as gr
import requests
import base64
import os
import time
import jwt
import logging
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== API CONFIGURATION =====
# Credentials come from the environment so secrets never live in source code.
# They are used to sign the JWT sent with every API request.
ACCESS_KEY_ID = os.environ.get("KLING_ACCESS_KEY_ID", "")
ACCESS_KEY_SECRET = os.environ.get("KLING_ACCESS_KEY_SECRET", "")
if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
    logger.warning(
        "KLING_ACCESS_KEY_ID / KLING_ACCESS_KEY_SECRET are not set; "
        "API requests will fail until they are configured"
    )

# Base URL can be overridden (e.g. for another region); trailing "/" is
# stripped so endpoint paths join cleanly.
API_BASE_URL = os.environ.get(
    "KLING_API_BASE_URL", "https://api-singapore.klingai.com"
).rstrip("/")
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"

# ===== AUTHENTICATION =====
def generate_jwt_token(access_key_id=None, access_key_secret=None, ttl_seconds=1800):
    """Build a short-lived HS256 JWT for Kling API authentication.

    Args:
        access_key_id: value for the "iss" claim; defaults to the module-level
            ACCESS_KEY_ID when None.
        access_key_secret: HMAC signing secret; defaults to the module-level
            ACCESS_KEY_SECRET when None.
        ttl_seconds: token lifetime in seconds (default 1800 = 30 minutes).

    Returns:
        The encoded token as a ``str``, ready for a ``Bearer`` header.

    Raises:
        ValueError: if the key id or the secret is empty.
    """
    key_id = ACCESS_KEY_ID if access_key_id is None else access_key_id
    secret = ACCESS_KEY_SECRET if access_key_secret is None else access_key_secret
    if not key_id or not secret:
        raise ValueError("Kling API credentials are missing (access key id/secret)")

    now = int(time.time())  # one clock read so exp and nbf are consistent
    payload = {
        "iss": key_id,
        "exp": now + ttl_seconds,
        "nbf": now - 5,  # not-before slightly in the past to tolerate clock skew
    }
    token = jwt.encode(payload, secret, algorithm="HS256")
    # PyJWT < 2.0 returns bytes; the Authorization header needs text.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token

# ===== IMAGE PROCESSING =====
def prepare_image_base64(image_path):
    """Return the raw bytes of *image_path* encoded as plain base64 text.

    No ``data:...;base64,`` prefix is added. If reading or encoding fails,
    the problem is logged and None is returned instead of raising.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as exc:
        logger.error(f"Image processing failed: {str(exc)}")
        return None

def validate_image(image_path, max_size_mb=10):
    """Check that an image file satisfies the API's upload size limits.

    Only the file size is checked: the file must be non-empty and no larger
    than *max_size_mb* MiB. Image format and pixel dimensions are NOT
    verified here.

    Args:
        image_path: path to the image file on disk.
        max_size_mb: maximum allowed size in MiB (default 10).

    Returns:
        ``(True, "")`` if the file passes, otherwise ``(False, reason)``.
    """
    try:
        size_bytes = os.path.getsize(image_path)
    except (OSError, TypeError, ValueError) as e:
        # Missing/unreadable file, or a path value getsize cannot accept.
        return False, f"Image validation error: {str(e)}"

    if size_bytes == 0:
        return False, "Image file is empty"
    if size_bytes / (1024 * 1024) > max_size_mb:
        return False, f"Image too large (max {max_size_mb}MB)"
    return True, ""

# ===== API FUNCTIONS =====
def create_multi_image_task(subject_images, prompt, model_name="kling-v2",
                            aspect_ratio="1:1", n=1, timeout=120):
    """Submit a multi-image-to-image generation task to the Kling API.

    Args:
        subject_images: iterable of image file paths; falsy entries are skipped.
        prompt: text prompt describing how to combine the images.
        model_name: API model identifier (default "kling-v2").
        aspect_ratio: requested output aspect ratio (default "1:1").
        n: number of images to request (default 1).
        timeout: seconds passed to ``requests.post`` as its timeout.

    Returns:
        ``(response_json, None)`` on HTTP success, otherwise
        ``(None, error_message)``.
    """
    # Encode images first so we can bail out before doing any auth/network work.
    subject_image_list = []
    for img_path in subject_images:
        if not img_path:  # skip empty/None slots
            continue
        base64_img = prepare_image_base64(img_path)
        if base64_img:
            subject_image_list.append({"subject_image": base64_img})

    if len(subject_image_list) < 2:
        return None, "At least 2 subject images required"

    try:
        token = generate_jwt_token()
    except ValueError as e:  # e.g. credentials not configured
        logger.error(f"Authentication setup failed: {str(e)}")
        return None, f"API Error: {str(e)}"

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = {
        "model_name": model_name,
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": n,
        "aspect_ratio": aspect_ratio,
    }

    try:
        response = requests.post(
            CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json(), None
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        # Response.__bool__ is False for 4xx/5xx, so test identity, not truthiness;
        # otherwise the server's error body is never logged.
        if getattr(e, "response", None) is not None:
            logger.error(f"API response: {e.response.text}")
        return None, f"API Error: {str(e)}"
    except ValueError as e:
        # Older requests versions raise a plain ValueError for non-JSON bodies.
        logger.error(f"API returned invalid JSON: {str(e)}")
        return None, f"API Error: invalid JSON response ({str(e)})"

def check_task_status(task_id, timeout=30):
    """Fetch the current state of a multi-image2image task.

    Args:
        task_id: task identifier returned when the task was created.
        timeout: seconds passed to ``requests.get`` as its timeout.

    Returns:
        ``(response_json, None)`` on success, otherwise ``(None, error_message)``.
        Credential problems and non-JSON bodies are reported the same way
        instead of raising.
    """
    status_url = f"{API_BASE_URL}/v1/images/multi-image2image/{task_id}"

    try:
        headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers missing credentials and (older requests) bad JSON.
        logger.error(f"Status check failed for task {task_id}: {str(e)}")
        return None, f"Status check failed: {str(e)}"

# ===== MAIN PROCESSING =====
def _download_result(image_url, task_id, timeout=60):
    """Download a generated image into the system temp dir; return (path, error)."""
    import tempfile

    # Keep only filename-safe characters from the server-provided task id.
    safe_id = "".join(ch for ch in str(task_id) if ch.isalnum() or ch in "-_")
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        output_path = Path(tempfile.gettempdir()) / f"kling_output_{safe_id}.png"
        output_path.write_bytes(response.content)
        return str(output_path), None
    except (requests.exceptions.RequestException, OSError) as e:
        return None, f"Failed to download result: {str(e)}"


def generate_image(subject_images, prompt, poll_interval=10, max_polls=60):
    """Run the full workflow: validate, submit, poll, and download the result.

    Args:
        subject_images: list of image file paths (falsy entries are ignored).
        prompt: text prompt forwarded to the API.
        poll_interval: seconds to wait between status checks (default 10).
        max_polls: maximum number of status checks (default 60, i.e. ~10 min).

    Returns:
        ``(local_image_path, None)`` on success, otherwise ``(None, error_message)``.
    """
    for img in subject_images:
        if img:  # only validate non-empty slots
            is_valid, error_msg = validate_image(img)
            if not is_valid:
                return None, error_msg

    task_response, error = create_multi_image_task(subject_images, prompt)
    if error:
        return None, error

    if task_response.get("code") != 0:
        return None, f"API error: {task_response.get('message', 'Unknown error')}"

    task_id = (task_response.get("data") or {}).get("task_id")
    if not task_id:
        return None, "API error: response did not include a task_id"
    logger.info(f"Task created: {task_id}")

    for attempt in range(max_polls):
        task_data, error = check_task_status(task_id)
        if error:
            return None, error
        if task_data.get("code") != 0:
            return None, f"API error: {task_data.get('message', 'Unknown error')}"

        data = task_data.get("data") or {}
        status = data.get("task_status")

        if status == "succeed":
            images = (data.get("task_result") or {}).get("images") or []
            if not images or not images[0].get("url"):
                return None, "Task succeeded but returned no image"
            return _download_result(images[0]["url"], task_id)

        if status in ("failed", "canceled"):
            error_msg = data.get("task_status_msg", "Unknown error")
            return None, f"Task failed: {error_msg}"

        # Don't sleep after the final poll; we're about to give up anyway.
        if attempt < max_polls - 1:
            time.sleep(poll_interval)

    total_minutes = max_polls * poll_interval / 60
    return None, f"Task timed out after {total_minutes:g} minutes"

# ===== GRADIO INTERFACE =====
def process_interface(subject_image1, subject_image2, subject_image3, subject_image4, prompt):
    """Gradio click handler for the Generate button.

    Collects the non-empty image slots, and if at least two are present runs
    generate_image(). Returns a 3-tuple for (preview image, download file,
    status text); on failure the first two are None.
    """
    candidates = (subject_image1, subject_image2, subject_image3, subject_image4)
    subject_images = list(filter(None, candidates))

    if len(subject_images) >= 2:
        output_path, error = generate_image(subject_images, prompt)
        if not error:
            # The same file feeds both the preview and the download widget.
            return output_path, output_path, "Generation successful!"
        return None, None, error

    return None, None, "Please upload at least 2 subject images"

# Build the Gradio UI. Components are laid out in the order they are created,
# so statement order here defines the page layout.
with gr.Blocks(title="Kling AI Multi-Image Generator") as app:
    gr.Markdown("## 🖼️ Kling AI Multi-Image to Image")
    gr.Markdown("Combine features from multiple images into one result")
    
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            gr.Markdown("### Input Settings")
            # type="filepath" makes each component hand process_interface a
            # path string (or None when the slot is empty).
            with gr.Row():
                subject_image1 = gr.Image(type="filepath", label="Subject Image 1 *")
                subject_image2 = gr.Image(type="filepath", label="Subject Image 2 *")
            with gr.Row():
                subject_image3 = gr.Image(type="filepath", label="Subject Image 3 (Optional)")
                subject_image4 = gr.Image(type="filepath", label="Subject Image 4 (Optional)")
            
            prompt_input = gr.Textbox(
                label="Transformation Prompt",
                placeholder="Describe how to combine these images"
            )
            
            generate_btn = gr.Button("Generate", variant="primary")
            
            gr.Markdown("### Requirements (* = required)")
            # NOTE(review): only the image count and the 10MB size limit are
            # checked in code; the format and minimum-dimension requirements
            # listed here are not enforced by validate_image().
            gr.Markdown("""
            - **At least 2 subject images** (marked with *)
            - Max 4 images total
            - Max size: 10MB per image
            - Formats: JPG, PNG
            - Min dimensions: 300x300px
            """)
            
        # Right column: outputs.
        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(label="Generated Image", interactive=False, height=400)
            output_file = gr.File(label="Download Result")
            status_output = gr.Textbox(label="Status", interactive=False)
    
    # Each image slot is passed as its own argument; the outputs order must
    # match the (image, file, status) tuple returned by process_interface.
    generate_btn.click(
        fn=process_interface,
        inputs=[subject_image1, subject_image2, subject_image3, subject_image4, prompt_input],
        outputs=[output_image, output_file, status_output]
    )

if __name__ == "__main__":
    # Serve on every network interface at port 7860, without creating a
    # public Gradio share link. Binding to 0.0.0.0 exposes the app to the
    # local network; restrict server_name if that is not intended.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
    app.launch(**launch_options)