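"""Gradio demo: Text to Image to Video with fal-ai through HF Inference Providers.

Text-to-image runs FLUX.1 [dev] through Hugging Face Inference Providers (fal-ai provider);
image-to-video calls the fal-ai kling-video API directly and therefore requires a FAL_KEY.
The app is launched with mcp_server=True, so its functions are also exposed as MCP tools.
"""
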
import os
import tempfile

import fal_client
import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv

from huggingface_hub import InferenceClient


load_dotenv()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
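# Credentials, populated at runtime by login_hf() / login_fal() below.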
TOKEN = None
FAL_KEY = None


def download_locally(url: str, local_path: str = "downloaded_file.png") -> str:
    """Download an image or a video from a URL to a local path.
    Args:
        url (str): The URL of the image to download. Must be an http(s) URL.
        local_path (str, optional): The path (including filename) where the file should be saved. Defaults to "downloaded_file.png".
    Returns:
        str: The filesystem path of the saved file – suitable for returning to a **gr.File** output, or as an MCP tool response.
    """
    if not local_path:
        local_path = "downloaded_file.png"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # If the caller passed only a filename, save into a temporary directory to avoid permission issues
    if os.path.dirname(local_path) == "":
        tmp_dir = tempfile.gettempdir()
        local_path = os.path.join(tmp_dir, local_path)
    with open(local_path, "wb") as f:
        f.write(response.content)
    return local_path
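
# Example usage (hypothetical URL, shown for illustration only):
#   saved_path = download_locally("https://example.com/picture.png", "picture.png")
#   print(saved_path)  # e.g. "/tmp/picture.png" when only a filename is given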


def login_hf(oauth_token: gr.OAuthToken | None):
    """
    Login to Hugging Face and check initial key statuses.
    Args:
        oauth_token (gr.OAuthToken | None): The OAuth token from Hugging Face.
    """
    global TOKEN
    if oauth_token and oauth_token.token:
        print("Received OAuth token, logging in...")
        TOKEN = oauth_token.token
    else:
        print("No OAuth token provided, using environment variable HF_TOKEN.")
        TOKEN = os.environ.get("HF_TOKEN")
        # Avoid logging the raw token; just report whether it is set.
        print("HF_TOKEN is set." if TOKEN else "HF_TOKEN is not set.")


def login_fal(fal_key_from_ui: str | None = None):
    """
    Sets the FAL API key from the UI.
    Args:
        fal_key_from_ui (str | None): The FAL key from the UI textbox.
    """
    global FAL_KEY
    if fal_key_from_ui and fal_key_from_ui.strip():
        FAL_KEY = fal_key_from_ui.strip()
        os.environ["FAL_KEY"] = FAL_KEY
        print("FAL_KEY has been set from UI input.")
    else:
        FAL_KEY = os.environ.get("FAL_KEY")
        # Avoid logging the raw key; just report whether it is set.
        print("FAL_KEY is set from environment variable." if FAL_KEY else "FAL_KEY is not set.")


def generate_image(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024, num_inference_steps: int = 25):
    """
    Generate an image from a prompt.
    Args:
        prompt (str):
            The prompt to generate an image from.
        seed (int, default=42):
            Seed for the random number generator.
        width (int, default=1024):
            The width in pixels of the output image.
        height (int, default=1024):
            The height in pixels of the output image.
        num_inference_steps (int, default=25):
            The number of denoising steps. More denoising steps usually lead to a higher-quality image at the
            expense of slower inference.
    Returns:
        (PIL.Image.Image, int): The generated image and the seed used.
    """
    client = InferenceClient(provider="fal-ai", token=TOKEN)
    image = client.text_to_image(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        seed=seed,
        model="black-forest-labs/FLUX.1-dev",
    )
    return image, seed
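
# Example usage (requires an HF token with access to the fal-ai provider):
#   image, seed = generate_image("a cat holding a sign that says hello world")
#   image.save("cat.png")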


def generate_video_from_image(
    image_filepath: str,  # This will be the path to the image from gr.Image output
    video_prompt: str,
    duration: str,  # "5" or "10"
    aspect_ratio: str,  # "16:9", "9:16", "1:1"
    video_negative_prompt: str,
    cfg_scale_video: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generates a video from an image using fal-ai/kling-video API.
    """
    if not FAL_KEY:
        gr.Error("FAL_KEY is not set. Cannot generate video.")
        return None
    if not image_filepath:
        gr.Warning("No image provided to generate video from.")
        return None
    if not os.path.exists(image_filepath):
        gr.Error(f"Image file not found at: {image_filepath}")
        return None

    print(f"Video generation started for image: {image_filepath}")
    progress(0, desc="Preparing for video generation...")

    try:
        progress(0.1, desc="Uploading image...")
        print("Uploading image to fal.ai storage...")
        print("FAL_KEY: ", os.environ.get("FAL_KEY"))
        image_url = fal_client.upload_file(image_filepath)
        print(f"Image uploaded, URL: {image_url}")
        progress(0.3, desc="Image uploaded. Submitting video request...")

        def on_queue_update(update):
            if isinstance(update, fal_client.InProgress):
                if update.logs:
                    for log in update.logs:
                        print(f"[fal-ai log] {log['message']}")
                        # Try to update progress description if logs are available
                        # progress(progress.current_progress_value, desc=f"Video processing: {log['message'][:50]}...")

        print("Subscribing to fal-ai/kling-video/v2.1/master/image-to-video...")
        api_result = fal_client.subscribe(
            "fal-ai/kling-video/v2.1/master/image-to-video",
            arguments={
                "prompt": video_prompt,
                "image_url": image_url,
                "duration": duration,
                "aspect_ratio": aspect_ratio,
                "negative_prompt": video_negative_prompt,
                "cfg_scale": cfg_scale_video,
            },
            with_logs=True,  # Get logs
            on_queue_update=on_queue_update,  # Callback for logs
        )

        progress(0.9, desc="Video processing complete.")
        video_output_url = api_result.get("video", {}).get("url")

        if video_output_url:
            print(f"Video generated successfully: {video_output_url}")
            progress(1, desc="Video ready!")
            return video_output_url
        else:
            print(f"Video generation failed or no URL in response. API Result: {api_result}")
            raise gr.Error("Video generation failed or no video URL returned.")

    except gr.Error:
        # Let the Gradio errors raised above propagate unchanged.
        raise
    except Exception as e:
        print(f"Error during video generation: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
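
# Example usage (assumes FAL_KEY is set and "frame.png" exists locally):
#   video_url = generate_video_from_image("frame.png", "slow camera zoom out", "5", "16:9", "blur, distort", 0.5)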


examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
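    # On page load, pick up credentials automatically (OAuth session or environment variables).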
    demo.load(login_hf, inputs=None, outputs=None)
    demo.load(login_fal, inputs=None, outputs=None)
    with gr.Sidebar():
        gr.Markdown("# Authentication")
        gr.Markdown(
            "Sign in with Hugging Face for image generation. Separately, set your fal.ai API Key for image to video generation."
        )

        gr.Markdown("### Hugging Face Login")
        hf_login_button = gr.LoginButton("Sign in with Hugging Face")
        # Gradio injects the gr.OAuthToken parameter automatically on click; the button itself is not an input.
        hf_login_button.click(fn=login_hf, inputs=None, outputs=None)

        gr.Markdown("### FAL Login (for Image to Video)")
        fal_key_input = gr.Textbox(
            label="FAL API Key",
            placeholder="Enter your FAL API Key here",
            type="password",
            value="",  # Never pre-fill with the server's key; it would be visible to any visitor.
        )
        set_fal_key_button = gr.Button("Set FAL Key")
        set_fal_key_button.click(fn=login_fal, inputs=[fal_key_input], outputs=None)

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# Text to Image to Video with fal-ai through HF Inference Providers ⚡
Learn more about HF Inference Providers [here](https://huggingface.co/docs/inference-providers/index)
## Text to Image uses [FLUX.1 [dev]](https://fal.ai/models/fal-ai/flux/dev) with fal-ai through HF Inference Providers
## Image to Video uses [kling-video v2.1](https://fal.ai/models/fal-ai/kling-video/v2.1/master/image-to-video/playground) with fal-ai directly (you will need to set your `FAL_KEY`)."""
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Generated Image", show_label=False, format="png", type="filepath")
        download_btn = gr.DownloadButton(
            label="Download result image",
            visible=False,
            value=None,
            variant="primary",
        )

        seed_number = gr.Number(label="Seed", precision=0, value=42, interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            with gr.Row():
                width_slider = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height_slider = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            steps_slider = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=25,
            )

        gr.Examples(
            examples=examples,
            fn=generate_image,
            inputs=[prompt],
            outputs=[result, seed_number],
            cache_examples="lazy",
        )

        video_result_output = gr.Video(label="Generated Video", show_label=False)

        with gr.Accordion("Video Generation from Image", open=False) as video_gen_accordion:
            video_prompt_input = gr.Text(
                label="Prompt for Video",
                placeholder="Describe the animation or changes for the video (e.g., 'camera zooms out slowly')",
                value="A gentle breeze rustles the leaves, subtle camera movement.",  # Default prompt
            )
            with gr.Row():
                video_duration_input = gr.Dropdown(label="Duration (seconds)", choices=["5", "10"], value="5")
                video_aspect_ratio_input = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=["16:9", "9:16", "1:1"],
                    value="16:9",  # Default from API
                )
            video_negative_prompt_input = gr.Text(
                label="Negative Prompt for Video",
                value="blur, distort, low quality",  # Default from API
            )
            video_cfg_scale_input = gr.Slider(
                label="CFG Scale for Video",
                minimum=0.0,
                maximum=10.0,
                value=0.5,
                step=0.1,
            )
            generate_video_btn = gr.Button("Generate Video", interactive=False)

        generate_video_btn.click(
            fn=generate_video_from_image,
            inputs=[
                result,
                video_prompt_input,
                video_duration_input,
                video_aspect_ratio_input,
                video_negative_prompt_input,
                video_cfg_scale_input,
            ],
            outputs=[video_result_output],
        )

        run_button.click(
            fn=generate_image,
            inputs=[prompt, seed_slider, width_slider, height_slider, steps_slider],
            outputs=[result, seed_number],
        ).then(
            lambda image_filepath: {
                video_gen_accordion: gr.Accordion(open=True),
                generate_video_btn: gr.Button(interactive=bool(image_filepath)),
                download_btn: gr.DownloadButton(value=image_filepath, visible=bool(image_filepath)),
            },
            inputs=[result],
            outputs=[video_gen_accordion, generate_video_btn, download_btn],
        )
        with gr.Accordion("Download Image from URL", open=False):
            image_url_input = gr.Text(label="Image URL", placeholder="Enter image URL (e.g., http://.../image.png)")
            filename_input = gr.Text(
                label="Filename (optional)",
                placeholder=" Filename",
            )
            download_from_url_btn = gr.DownloadButton(label="Download Image")

            download_from_url_btn.click(
                fn=download_locally,
                inputs=[image_url_input, filename_input],
                outputs=[download_from_url_btn],
            )


if __name__ == "__main__":
    demo.launch(mcp_server=True)