# Hugging Face Space "Image to Story" (runs on ZeroGPU)
import os
import re

import spaces
import gradio as gr
from gradio_client import Client, handle_file
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_token = os.environ.get('HF_TOKEN')

# Client for the CLIP Interrogator Space used to caption the uploaded image
clipi_client = Client("fffiloni/CLIP-Interrogator-2")

# Llama 2 chat model used to turn the caption into a story
model_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_path, token=hf_token).half().cuda()
# ZeroGPU Spaces allocate a GPU per call, so the generation function is decorated with spaces.GPU
@spaces.GPU
def llama_gen_story(prompt):
    """Generate a fictional story with the Llama 2 chat model.

    Args:
        prompt: A string prompt containing an image description and story generation instructions.

    Returns:
        The generated fictional story, with special formatting and tokens removed.
    """
    instruction = """[INST] <<SYS>>\nYou are a storyteller. You'll be given an image description and some keywords about the image.
    Given that, you'll be asked to generate a story that you think could fit very well with the image provided.
    Always answer with a cool story, while being as safe as possible. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
    If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""
    prompt = instruction.format(prompt)
    generate_ids = model.generate(tokenizer(prompt, return_tensors='pt').input_ids.cuda(), max_new_tokens=4096)
    output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
    # Strip the echoed instruction block so only the generated story remains
    pattern = r'\[INST\].*?\[/INST\]'
    cleaned_text = re.sub(pattern, '', output_text, flags=re.DOTALL)
    return cleaned_text
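# Note on the cleanup above: even with skip_special_tokens=True, the decoded output still
# echoes the prompt, e.g. "[INST] <<SYS>> ... <</SYS>> <caption + request> [/INST] <story>",
# so the regex removes the "[INST] ... [/INST]" block and leaves only the story text.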
def get_text_after_colon(input_text):
    # Find the first occurrence of ":"
    colon_index = input_text.find(":")
    # Check if ":" exists in the input_text
    if colon_index != -1:
        # Extract the text after the colon
        result_text = input_text[colon_index + 1:].strip()
        return result_text
    else:
        # Return the original text if ":" is not found
        return input_text
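# Llama 2 often prefixes its reply with a label such as "Story:"; the helper above strips it.
# Illustrative example (the sample string is hypothetical, not from the original source):
#   get_text_after_colon("Story: Once upon a time...")  ->  "Once upon a time..."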
def infer(image_input, audience):
    """Generate a fictional story based on an image using CLIP Interrogator and Llama 2.

    Args:
        image_input: A file path to the input image to analyze.
        audience: A string indicating the target audience, such as 'Children' or 'Adult'.

    Returns:
        A formatted, multi-paragraph fictional story string related to the image content.

    Steps:
        1. Use the CLIP Interrogator model to generate a semantic caption from the image.
        2. Format a prompt asking the Llama 2 model to write a story based on the caption.
        3. Clean and format the story output for readability.
    """
    gr.Info('Calling CLIP Interrogator ...')
    clipi_result = clipi_client.predict(
        image=handle_file(image_input),
        mode="best",
        best_max_flavors=4,
        api_name="/clipi2"
    )
    print(clipi_result)

    llama_q = f"""
    I'll give you a simple image caption, please provide a fictional story for a {audience} audience that would fit well with the image. Please be creative, do not worry and only generate a cool fictional story.
    Here's the image description:
    '{clipi_result}'
    """

    gr.Info('Calling Llama2 ...')
    result = llama_gen_story(llama_q)
    print(f"Llama2 result: {result}")

    result = get_text_after_colon(result)

    # Split the text into paragraphs based on actual line breaks
    paragraphs = result.split('\n')
    # Join the paragraphs back with an extra empty line between each paragraph
    formatted_text = '\n\n'.join(paragraphs)

    return formatted_text
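# Minimal sketch of calling the pipeline outside the Gradio UI (assumes the example
# image bundled with the Space exists at this path):
#   story_text = infer("./examples/crabby.png", "Children")
#   print(story_text)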
css=""" | |
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;} | |
div#story textarea { | |
font-size: 1.5em; | |
line-height: 1.4em; | |
} | |
""" | |
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center">Image to Story</h1>
            <p style="text-align: center">Upload an image, get a story made by Llama 2!</p>
            """
        )
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in")
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                story = gr.Textbox(label="Generated Story", elem_id="story")

        gr.Examples(examples=[["./examples/crabby.png", "Children"], ["./examples/hopper.jpeg", "Adult"]],
                    fn=infer,
                    inputs=[image_in, audience],
                    outputs=[story],
                    cache_examples=False
                    )

    submit_btn.click(fn=infer, inputs=[image_in, audience], outputs=[story])

demo.queue(max_size=12).launch(ssr_mode=False, mcp_server=True)