File size: 4,486 Bytes
7e4ab32
754d2f6
24db381
754d2f6
 
 
c541413
754d2f6
 
24db381
754d2f6
24db381
754d2f6
 
 
54fce6e
7e4ab32
7fe0752
fb0d27d
7fe0752
 
 
754d2f6
24db381
754d2f6
 
 
 
 
 
 
 
 
 
24db381
754d2f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fe0752
754d2f6
 
 
 
 
 
 
 
 
 
 
0378032
 
24db381
7fe0752
f8c6098
 
24db381
7fe0752
8caf237
f8c6098
 
24db381
f8c6098
0378032
24db381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0378032
7fe0752
 
 
754d2f6
 
24db381
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch
import os, time

# Hub repository that supplies both the fine-tuned weights and the processor.
MODEL_ID = "daniel3303/QwenStoryteller"

# Load the weights once at import time, in half precision, letting the loader
# choose device placement.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)

# The processor builds the chat prompt and the image tensors for the model.
processor = AutoProcessor.from_pretrained(MODEL_ID)

@spaces.GPU()
@torch.no_grad()
def generate_story(file_paths, progress=gr.Progress(track_tqdm=True)):
    """Generate a story from a sequence of uploaded images.

    Args:
        file_paths: Paths of the uploaded image files, in the order they
            should appear in the story. Only the first 10 are used.
        progress: Gradio progress tracker (with tqdm tracking enabled). It is
            not referenced in the body; it is declared in the signature only.

    Returns:
        The decoded model output (prompt tokens removed) as a single string.

    Raises:
        gr.Error: If no images were provided or an uploaded file cannot be
            opened as an image.
    """
    # Must match the "up to 10 images" wording shown in the UI.
    max_images = 10

    # Clicking "Generate" before uploading yields no paths; fail with a
    # message the user can act on instead of an opaque TypeError.
    if not file_paths:
        raise gr.Error("Please upload at least one image before generating a story.")

    # Apply the cap BEFORE opening anything so surplus uploads are never read.
    image_content = []
    for file_path in list(file_paths)[:max_images]:
        try:
            # copy() forces the pixel data to be read so the file handle can
            # be released as soon as the with-block exits.
            with Image.open(file_path) as img:
                image_content.append({"type": "image", "image": img.copy()})
        except OSError as err:
            raise gr.Error(f"Could not read image file: {file_path}") from err

    image_content.append({"type": "text", "text": "Generate a story based on these images."})

    # NOTE(review): the "🧠 tags" wording in the system prompt may be a garbled
    # tag literal -- confirm against the model card before changing it.
    messages = [
        {
            "role": "system",
            "content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use 🧠 tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]

    # Render the chat prompt as text, then build the model inputs from the
    # prompt plus the vision inputs extracted from the same messages.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )

    # Each output sequence starts with its prompt tokens; keep only the newly
    # generated continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    return story


# Page layout: upload/generate controls on the left, the generated story on
# the right, and a footer row with the model summary and the citation.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Qwen Storyteller \n## Upload up to 10 images to generate a creative story.")

    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton(
                "Upload up to 10 images",
                file_types=["image"],
                file_count="multiple",
            )
            output_file = gr.File(label="Uploaded Files")
            gen_button = gr.Button("Generate", variant="primary")

        with gr.Column():
            outputs = gr.Markdown(label="Generated Story", show_copy_button=True)

    with gr.Row():
        gr.Markdown(
            """
            ### Key Features
            * Cross-Frame Consistency: Maintains consistent character and object identity across multiple frames through visual similarity and face recognition techniques
            * Structured Reasoning: Employs chain-of-thought reasoning to analyze scenes with explicit modeling of characters, objects, settings, and narrative structure
            * Grounded Storytelling: Uses specialized XML tags to link narrative elements directly to visual entities
            * Reduced Hallucinations: Achieves 12.3% fewer hallucinations compared to the non-fine-tuned base model
            
            Model trained by daniel3303, [repository here.](https://huggingface.co/daniel3303/QwenStoryteller)
            """
        )
        gr.Markdown(
            """
            ```
            @misc{oliveira2025storyreasoningdatasetusingchainofthought,
                  title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation}, 
                  author={Daniel A. P. Oliveira and David Martins de Matos},
                  year={2025},
                  eprint={2505.10292},
                  archivePrefix={arXiv},
                  primaryClass={cs.CV},
                  url={https://arxiv.org/abs/2505.10292}, 
            }
            ```
            """
        )

    # After an upload finishes, list the uploaded files' names in the file
    # component so the user can see what will be sent to the model.
    upload_button.upload(
        fn=lambda uploads: [item.name for item in uploads],
        inputs=upload_button,
        outputs=output_file,
    )

    # "Generate" passes the upload button's current files to the model and
    # renders the returned story in the Markdown panel.
    gen_button.click(fn=generate_story, inputs=upload_button, outputs=outputs)

if __name__ == "__main__":
    # Enable the request queue, then start the server with errors surfaced in the UI.
    demo.queue()
    demo.launch(show_error=True)