# QwenStoryteller / app.py
import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch
# Load the model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"daniel3303/QwenStoryteller",
torch_dtype=torch.float16,
device_map="auto"
)
processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")
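
# @spaces.GPU() requests a GPU slot from Hugging Face ZeroGPU Spaces for the
# duration of each call; torch.no_grad() disables gradient tracking since this
# function only runs inference.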
@spaces.GPU()
@torch.no_grad()
def generate_story(file_paths, progress=gr.Progress(track_tqdm=True)):
    # Load images from the uploaded file paths
    images = [Image.open(file_path) for file_path in file_paths]

    # Build the multimodal message content: up to 10 images followed by the text prompt
    image_content = []
    for img in images[:10]:  # Limit to 10 images
        image_content.append({
            "type": "image",
            "image": img,
        })
    image_content.append({"type": "text", "text": "Generate a story based on these images."})
    messages = [
        {
            "role": "system",
            "content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use 🧠 tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]
    # Render the chat template and extract the vision inputs for the processor
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(model.device)
    # Sample with moderate temperature for creative but coherent output
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )

    # Strip the prompt tokens so only the newly generated story is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]
    return story
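
# Minimal local smoke test (hypothetical image paths, assuming the files exist):
#   story = generate_story(["frame1.jpg", "frame2.jpg"])
#   print(story)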
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Qwen Storyteller \n## Upload up to 10 images to generate a creative story.")
    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton("Upload up to 10 images", file_types=["image"], file_count="multiple")
            output_file = gr.File(label="Uploaded Files")
            gen_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            outputs = gr.Markdown(label="Generated Story", show_copy_button=True)
    with gr.Row():
        gr.Markdown(
            """
            ### Key Features
            * Cross-Frame Consistency: Maintains consistent character and object identity across multiple frames through visual similarity and face recognition techniques
            * Structured Reasoning: Employs chain-of-thought reasoning to analyze scenes with explicit modeling of characters, objects, settings, and narrative structure
            * Grounded Storytelling: Uses specialized XML tags to link narrative elements directly to visual entities
            * Reduced Hallucinations: Achieves 12.3% fewer hallucinations compared to the non-fine-tuned base model

            Model trained by daniel3303, [repository here](https://huggingface.co/daniel3303/QwenStoryteller).
            """
        )
    gr.Markdown(
        """
        ```
        @misc{oliveira2025storyreasoningdatasetusingchainofthought,
            title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation},
            author={Daniel A. P. Oliveira and David Martins de Matos},
            year={2025},
            eprint={2505.10292},
            archivePrefix={arXiv},
            primaryClass={cs.CV},
            url={https://arxiv.org/abs/2505.10292},
        }
        ```
        """
    )
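
    # Wire up events: list the uploaded file names, then generate the story on click.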
    upload_button.upload(lambda files: [f.name for f in files], upload_button, output_file)
    gen_button.click(generate_story, upload_button, outputs)
if __name__ == "__main__":
    demo.queue().launch(show_error=True)