from typing import Dict
import torch
from diffusers import FluxKontextPipeline
from io import BytesIO
import base64
from PIL import Image, ImageOps
import numpy as np  # Added import

class EndpointHandler:
    """Hugging Face Inference Endpoint handler for FLUX.1 Kontext image editing.

    Loads the base ``FluxKontextPipeline`` plus custom LoRA weights once at
    startup, then serves requests mapping a (prompt, base64 image) payload to
    a base64-encoded PNG of the edited image.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline and LoRA weights; move to GPU when available.

        Args:
            path: Local model path supplied by the HF endpoint runtime
                (unused here; the model is pulled from the Hub instead).
        """
        print("πŸš€ Initializing Flux Kontext pipeline...")

        # Load base model from Hugging Face
        self.pipe = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
        )

        # Debug available methods on pipeline
        print("πŸ” Available methods on pipeline:", dir(self.pipe))

        # Best-effort LoRA load: if the weights cannot be fetched the
        # endpoint still serves the base model rather than failing startup.
        try:
            self.pipe.load_lora_weights(
                "Texttra/BhoriKontext",
                weight_name="Bh0r12.safetensors"
            )
            print("βœ… LoRA weights loaded from Texttra/BhoriKontext/Bh0r12.safetensors.")
        except Exception as e:
            print(f"⚠️ Failed to load LoRA weights: {str(e)}")

        # Move pipeline to GPU if available
        self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        print("βœ… Model ready with LoRA applied.")

    @staticmethod
    def _extract_fields(data: Dict):
        """Pull (prompt, image, error) out of a flat or 'inputs'-wrapped payload.

        Supports both ``{"prompt": ..., "image": ...}`` and the HF Inference
        schema ``{"inputs": {"prompt": ..., "image": ...}}``. Returns a
        ``(prompt, image_input, error_message)`` triple where ``error_message``
        is ``None`` on success.
        """
        prompt = data.get("prompt")
        image_input = data.get("image")

        # If 'inputs' key is used (HF Inference schema)
        if prompt is None and image_input is None:
            inputs = data.get("inputs")
            if isinstance(inputs, dict):
                prompt = inputs.get("prompt")
                image_input = inputs.get("image")
            else:
                return None, None, "Expected 'inputs' to be a JSON object containing 'prompt' and 'image'."
        return prompt, image_input, None

    @staticmethod
    def _decode_image(image_input):
        """Decode a base64 string (optionally a data URI) into an RGB PIL image.

        Returns ``(image, error_message)`` where ``error_message`` is ``None``
        on success.
        """
        try:
            # Tolerate browser-style 'data:image/...;base64,....' payloads by
            # stripping everything up to and including the first comma.
            if (
                isinstance(image_input, str)
                and image_input.lstrip().startswith("data:")
                and "," in image_input
            ):
                image_input = image_input.split(",", 1)[1]
            image_bytes = base64.b64decode(image_input)
            image = Image.open(BytesIO(image_bytes))
            # Correct EXIF orientation BEFORE convert(): converting first can
            # drop the EXIF metadata the transpose relies on.
            image = ImageOps.exif_transpose(image)
            return image.convert("RGB"), None
        except Exception as e:
            return None, f"Failed to decode 'image' as base64: {str(e)}"

    def __call__(self, data: Dict) -> Dict:
        """Handle one inference request.

        Args:
            data: JSON payload containing 'prompt' and a base64 'image',
                either at the top level or nested under 'inputs'.

        Returns:
            ``{"image": <base64 PNG>}`` on success, ``{"error": <message>}``
            on any failure (parsing, decoding, inference, or encoding).
        """
        print("πŸ”§ Received raw data type:", type(data))
        print("πŸ”§ Received raw data content:", data)

        # Defensive parsing
        if not isinstance(data, dict):
            return {"error": "Input payload must be a JSON object."}

        prompt, image_input, parse_error = self._extract_fields(data)
        if parse_error:
            return {"error": parse_error}

        if not prompt:
            return {"error": "Missing 'prompt' in input data."}
        if not image_input:
            return {"error": "Missing 'image' (base64) in input data."}

        image, decode_error = self._decode_image(image_input)
        if decode_error:
            return {"error": decode_error}

        # Debug prints for prompt and image size
        print(f"πŸ“ Final prompt: {prompt}")
        print(f"πŸ–ΌοΈ Image size: {image.size}")

        # Generate edited image with Kontext
        try:
            output = self.pipe(
                prompt=prompt,
                image=image,
                num_inference_steps=35,
                guidance_scale=4.0
            ).images[0]
            print("🎨 Image generated.")

            # πŸ’‘ HARD CLAMP pixel values to [0, 255] to prevent NaN/black outputs
            output_array = np.array(output)
            output_array = np.clip(output_array, 0, 255).astype(np.uint8)
            output = Image.fromarray(output_array)
            print("πŸ›‘ Hard clamped output pixel values to [0, 255].")

        except Exception as e:
            return {"error": f"Model inference failed: {str(e)}"}

        # Encode output image to base64
        try:
            buffer = BytesIO()
            output.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
            print("βœ… Returning image.")
            return {"image": base64_image}
        except Exception as e:
            return {"error": f"Failed to encode output image: {str(e)}"}