import os
import tempfile

import fal_client
import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

load_dotenv()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

TOKEN = None
FAL_KEY = None

def download_locally(url: str, local_path: str = "downloaded_file.png") -> str:
    """Download an image or a video from a URL to a local path.

    Args:
        url (str): The URL of the file to download. Must be an http(s) URL.
        local_path (str, optional): The path (including filename) where the file
            should be saved. Defaults to "downloaded_file.png".

    Returns:
        str: The filesystem path of the saved file, suitable for returning to a
            **gr.File** output or as an MCP tool response.
    """
    if local_path == "":
        local_path = "downloaded_file.png"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # If the caller passed only a filename, save into a temporary directory to avoid permission issues.
    if os.path.dirname(local_path) == "":
        local_path = os.path.join(tempfile.gettempdir(), local_path)
    with open(local_path, "wb") as f:
        f.write(response.content)
    return local_path
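
# A minimal usage sketch of download_locally (the URL below is hypothetical):
#     path = download_locally("https://example.com/sample.png", "sample.png")
#     # Bare filenames land under tempfile.gettempdir(), per the branch above,
#     # so this typically resolves to something like "/tmp/sample.png".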

def login_hf(oauth_token: gr.OAuthToken | None):
    """
    Login to Hugging Face and check initial key statuses.

    Args:
        oauth_token (gr.OAuthToken | None): The OAuth token from Hugging Face.
    """
    global TOKEN
    if oauth_token and oauth_token.token:
        print("Received OAuth token, logging in...")
        TOKEN = oauth_token.token
    else:
        print("No OAuth token provided, using environment variable HF_TOKEN.")
        TOKEN = os.environ.get("HF_TOKEN")
    # Avoid printing the raw token to the logs; only report whether it is set.
    print("TOKEN set:", bool(TOKEN))

def login_fal(fal_key_from_ui: str | None = None):
    """
    Sets the FAL API key from the UI.

    The default of None lets demo.load() call this with no inputs and fall back
    to the FAL_KEY environment variable.

    Args:
        fal_key_from_ui (str | None): The FAL key from the UI textbox.
    """
    global FAL_KEY
    if fal_key_from_ui and fal_key_from_ui.strip():
        FAL_KEY = fal_key_from_ui.strip()
        os.environ["FAL_KEY"] = FAL_KEY
        print("FAL_KEY has been set from UI input.")
    else:
        FAL_KEY = os.environ.get("FAL_KEY")
        print("FAL_KEY is configured from the FAL_KEY environment variable.")
    # Avoid printing the raw key to the logs; only report whether it is set.
    print("FAL_KEY set:", bool(FAL_KEY))

def generate_image(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024, num_inference_steps: int = 25):
    """
    Generate an image from a prompt.

    Args:
        prompt (str):
            The prompt to generate an image from.
        seed (int, default=42):
            Seed for the random number generator.
        width (int, default=1024):
            The width in pixels of the output image.
        height (int, default=1024):
            The height in pixels of the output image.
        num_inference_steps (int, default=25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
    """
    client = InferenceClient(provider="fal-ai", token=TOKEN)
    image = client.text_to_image(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        seed=seed,
        model="black-forest-labs/FLUX.1-dev",
    )
    return image, seed
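
# Sketch of a direct call outside the UI (assumes TOKEN holds a valid HF token
# with Inference Providers access; text_to_image returns a PIL image):
#     image, seed = generate_image("a tiny astronaut hatching from an egg on the moon")
#     image.save("astronaut.png")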

def generate_video_from_image(
    image_filepath: str,  # Path to the image produced by the gr.Image output
    video_prompt: str,
    duration: str,  # "5" or "10"
    aspect_ratio: str,  # "16:9", "9:16", "1:1"
    video_negative_prompt: str,
    cfg_scale_video: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generates a video from an image using the fal-ai/kling-video API.
    """
    if not FAL_KEY:
        raise gr.Error("FAL_KEY is not set. Cannot generate video.")
    if not image_filepath:
        gr.Warning("No image provided to generate video from.")
        return None
    if not os.path.exists(image_filepath):
        raise gr.Error(f"Image file not found at: {image_filepath}")

    print(f"Video generation started for image: {image_filepath}")
    progress(0, desc="Preparing for video generation...")
    try:
        progress(0.1, desc="Uploading image...")
        print("Uploading image to fal.ai storage...")
        # Avoid printing the raw key to the logs; only report whether it is set.
        print("FAL_KEY set:", bool(os.environ.get("FAL_KEY")))
        image_url = fal_client.upload_file(image_filepath)
        print(f"Image uploaded, URL: {image_url}")
        progress(0.3, desc="Image uploaded. Submitting video request...")

        def on_queue_update(update):
            if isinstance(update, fal_client.InProgress):
                if update.logs:
                    for log in update.logs:
                        print(f"[fal-ai log] {log['message']}")
                        # Try to update the progress description if logs are available, e.g.:
                        # progress(progress.current_progress_value, desc=f"Video processing: {log['message'][:50]}...")

        print("Subscribing to fal-ai/kling-video/v2.1/master/image-to-video...")
        api_result = fal_client.subscribe(
            "fal-ai/kling-video/v2.1/master/image-to-video",
            arguments={
                "prompt": video_prompt,
                "image_url": image_url,
                "duration": duration,
                "aspect_ratio": aspect_ratio,
                "negative_prompt": video_negative_prompt,
                "cfg_scale": cfg_scale_video,
            },
            with_logs=True,  # Stream logs from the queue
            on_queue_update=on_queue_update,  # Callback for queue/log updates
        )
        progress(0.9, desc="Video processing complete.")

        video_output_url = api_result.get("video", {}).get("url")
        if video_output_url:
            print(f"Video generated successfully: {video_output_url}")
            progress(1, desc="Video ready!")
            return video_output_url
        print(f"Video generation failed or no URL in response. API result: {api_result}")
        raise gr.Error("Video generation failed or no video URL returned.")
    except gr.Error:
        raise
    except Exception as e:
        print(f"Error during video generation: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
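
# Note: the .get("video", {}).get("url") chain above assumes a success payload
# shaped like {"video": {"url": "https://..."}}; any other shape is treated as
# a failure and surfaced to the UI via gr.Error.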

examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    demo.load(login_hf, inputs=None, outputs=None)
    demo.load(login_fal, inputs=None, outputs=None)

    with gr.Sidebar():
        gr.Markdown("# Authentication")
        gr.Markdown(
            "Sign in with Hugging Face for image generation. Separately, set your fal.ai API key for image-to-video generation."
        )
        gr.Markdown("### Hugging Face Login")
        hf_login_button = gr.LoginButton("Sign in with Hugging Face")
        # On click, Gradio injects an OAuthToken (or None) into login_hf via its
        # gr.OAuthToken type annotation, so no explicit inputs are needed.
        hf_login_button.click(fn=login_hf, inputs=None, outputs=None)

        gr.Markdown("### FAL Login (for Image to Video)")
        fal_key_input = gr.Textbox(
            label="FAL API Key",
            placeholder="Enter your FAL API Key here",
            type="password",
            value=os.environ.get("FAL_KEY", ""),  # Pre-fill if loaded from env
        )
        set_fal_key_button = gr.Button("Set FAL Key")
        set_fal_key_button.click(fn=login_fal, inputs=[fal_key_input], outputs=None)

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# Text to Image to Video with fal-ai through HF Inference Providers ⚡\n"""
            """Learn more about HF Inference Providers [here](https://huggingface.co/docs/inference-providers/index)\n"""
            """## Text to Image uses [FLUX.1 [dev]](https://fal.ai/models/fal-ai/flux/dev) with fal-ai through HF Inference Providers\n"""
            """## Image to Video uses [kling-video v2.1](https://fal.ai/models/fal-ai/kling-video/v2.1/master/image-to-video/playground) with fal-ai directly (you will need to set your `FAL_KEY`)."""
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Generated Image", show_label=False, format="png", type="filepath")
        download_btn = gr.DownloadButton(
            label="Download result image",
            visible=False,
            value=None,
            variant="primary",
        )
        seed_number = gr.Number(label="Seed", precision=0, value=42, interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            with gr.Row():
                width_slider = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height_slider = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            steps_slider = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=25,
            )

        gr.Examples(
            examples=examples,
            fn=generate_image,
            inputs=[prompt],
            outputs=[result, seed_number],
            cache_examples="lazy",
        )

        def update_image_outputs(image_pil, seed_val):
            # Helper mirroring the post-generation UI updates; currently not wired
            # to any event (run_button's .then below performs these updates inline).
            return {
                result: image_pil,
                seed_number: seed_val,
                download_btn: gr.DownloadButton(value=image_pil, visible=True)
                if image_pil
                else gr.DownloadButton(visible=False),
            }

        video_result_output = gr.Video(label="Generated Video", show_label=False)

        with gr.Accordion("Video Generation from Image", open=False) as video_gen_accordion:
            video_prompt_input = gr.Text(
                label="Prompt for Video",
                placeholder="Describe the animation or changes for the video (e.g., 'camera zooms out slowly')",
                value="A gentle breeze rustles the leaves, subtle camera movement.",  # Default prompt
            )
            with gr.Row():
                video_duration_input = gr.Dropdown(label="Duration (seconds)", choices=["5", "10"], value="5")
                video_aspect_ratio_input = gr.Dropdown(
                    label="Aspect Ratio",
                    choices=["16:9", "9:16", "1:1"],
                    value="16:9",  # Default from API
                )
            video_negative_prompt_input = gr.Text(
                label="Negative Prompt for Video",
                value="blur, distort, low quality",  # Default from API
            )
            video_cfg_scale_input = gr.Slider(
                label="CFG Scale for Video",
                minimum=0.0,
                maximum=10.0,
                value=0.5,
                step=0.1,
            )
            generate_video_btn = gr.Button("Generate Video", interactive=False)

            generate_video_btn.click(
                fn=generate_video_from_image,
                inputs=[
                    result,
                    video_prompt_input,
                    video_duration_input,
                    video_aspect_ratio_input,
                    video_negative_prompt_input,
                    video_cfg_scale_input,
                ],
                outputs=[video_result_output],
            )

        run_button.click(
            fn=generate_image,
            inputs=[prompt, seed_slider, width_slider, height_slider, steps_slider],
            outputs=[result, seed_number],
        ).then(
            lambda image_filepath: {
                video_gen_accordion: gr.Accordion(open=True),
                generate_video_btn: gr.Button(interactive=bool(image_filepath)),
                download_btn: gr.DownloadButton(value=image_filepath, visible=bool(image_filepath)),
            },
            inputs=[result],
            outputs=[video_gen_accordion, generate_video_btn, download_btn],
        )

        with gr.Accordion("Download Image from URL", open=False):
            image_url_input = gr.Text(label="Image URL", placeholder="Enter image URL (e.g., http://.../image.png)")
            filename_input = gr.Text(
                label="Filename (optional)",
                placeholder="Filename",
            )
            download_from_url_btn = gr.DownloadButton(label="Download Image")
            download_from_url_btn.click(
                fn=download_locally,
                inputs=[image_url_input, filename_input],
                outputs=[download_from_url_btn],
            )


if __name__ == "__main__":
    demo.launch(mcp_server=True)