import spaces
import gradio as gr
import re
import os
hf_token = os.environ.get('HF_TOKEN')
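
# Note: meta-llama/Llama-2-7b-chat-hf is a gated model on the Hugging Face Hub;
# HF_TOKEN must belong to an account that has been granted access to it.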

from gradio_client import Client, handle_file

# Remote CLIP Interrogator Space used to caption the uploaded image
clipi_client = Client("fffiloni/CLIP-Interrogator-2")

from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_path, token=hf_token).half().cuda()
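
# In fp16, the 7B model's weights alone take roughly 14 GB of GPU memory,
# so this expects an A10G-class GPU or larger.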


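# @spaces.GPU requests GPU time for the decorated function when running on a
# Hugging Face ZeroGPU Space; on dedicated-GPU hardware it is a no-op.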
@spaces.GPU
def llama_gen_story(prompt):
    """Generate a fictional story using the LLaMA 2 model based on a prompt.
    
    Args:
        prompt: A string prompt containing an image description and story generation instructions.
        
    Returns:
        A generated fictional story string with special formatting and tokens removed.
    """

    instruction = """[INST] <<SYS>>\nYou are a storyteller. You'll be given an image description and some keywords about the image.
            Given those, you'll be asked to generate a story that you think could fit very well with the image provided.
            Always answer with a cool story, while being as safe as possible. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
            If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""

    
    prompt = instruction.format(prompt)
    
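    # Tokenize the wrapped prompt and generate on the GPU. Llama-2's context
    # window is 4096 tokens in total, so prompt length plus max_new_tokens
    # should stay within that budget.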
    generate_ids = model.generate(tokenizer(prompt, return_tensors='pt').input_ids.cuda(), max_new_tokens=4096)
    output_text = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
    # Strip the echoed "[INST] ... [/INST]" prompt block from the decoded output
    pattern = r'\[INST\].*?\[/INST\]'
    cleaned_text = re.sub(pattern, '', output_text, flags=re.DOTALL)
    return cleaned_text
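
# Illustrative call (output varies with model sampling):
#   story = llama_gen_story("Here's the image description: 'a lighthouse at dusk'. Write a short story.")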

def get_text_after_colon(input_text):
    """Return the text after the first ":" in input_text, or the text unchanged if none."""
    # Find the first occurrence of ":"
    colon_index = input_text.find(":")
    
    # Check if ":" exists in the input_text
    if colon_index != -1:
        # Extract the text after the colon
        result_text = input_text[colon_index + 1:].strip()
        return result_text
    else:
        # Return the original text if ":" is not found
        return input_text
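
# e.g. get_text_after_colon("Story: Once upon a time...") -> "Once upon a time..."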

def infer(image_input, audience):
    """Generate a fictional story based on an image using CLIP Interrogator and LLaMA2.
    
    Args:
        image_input: A file path to the input image to analyze.
        audience: A string indicating the target audience, such as 'Children' or 'Adult'.
    
    Returns:
        A formatted, multi-paragraph fictional story string related to the image content.
        
    Steps:
        1. Use the CLIP Interrogator model to generate a semantic caption from the image.
        2. Format a prompt asking the LLaMA2 model to write a story based on the caption.
        3. Clean and format the story output for readability.
    """
    gr.Info('Calling CLIP Interrogator ...')

    clipi_result = clipi_client.predict(
        image=handle_file(image_input),
        mode="best",
        best_max_flavors=4,
        api_name="/clipi2"
    )
    print(clipi_result)
   

    llama_q = f"""
    I'll give you a simple image caption. Please write a fictional story for a {audience} audience that fits well with the image. Be creative, and generate only the story itself.
    Here's the image description:
    '{clipi_result}'
    """
    gr.Info('Calling Llama2 ...')
    result = llama_gen_story(llama_q)

    print(f"Llama2 result: {result}")

    result = get_text_after_colon(result)

    # Split the text into paragraphs based on actual line breaks
    paragraphs = result.split('\n')
    
    # Join the paragraphs back with an extra empty line between each paragraph
    formatted_text = '\n\n'.join(paragraphs)


    return formatted_text

css="""
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;}


div#story textarea {
    font-size: 1.5em;
    line-height: 1.4em;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center">Image to Story</h1>
            <p style="text-align: center">Upload an image, get a story made by Llama2!</p>
            """
        )
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in")
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                story = gr.Textbox(label="Generated Story", elem_id="story")
        
        gr.Examples(examples=[["./examples/crabby.png", "Children"],["./examples/hopper.jpeg", "Adult"]],
                    fn=infer,
                    inputs=[image_in, audience],
                    outputs=[story],
                    cache_examples=False
                   )
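        # The example images are assumed to ship with the repo under ./examples/;
        # adjust the paths if running outside the original Space.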
        
    submit_btn.click(fn=infer, inputs=[image_in, audience], outputs=[story])

demo.queue(max_size=12).launch(ssr_mode=False, mcp_server=True)
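
# To run locally (assumes a CUDA GPU and an HF_TOKEN with Llama-2 access):
#   HF_TOKEN=hf_... python app.py
# then open the printed local URL in a browser.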