import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
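# process_vision_info comes from the qwen-vl-utils helper package published
# alongside Qwen2-VL/Qwen2.5-VL (pip install qwen-vl-utils)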
from PIL import Image
import torch
# Load the model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "daniel3303/QwenStoryteller",
    torch_dtype=torch.float16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")
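# Note (assumption about the deployment): device_map="auto" relies on the
# accelerate package to place the weights, and @spaces.GPU() below requests a
# GPU for each call when running on a ZeroGPU Space.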
@spaces.GPU()
@torch.no_grad()
def generate_story(file_paths, progress=gr.Progress(track_tqdm=True)):
    # Load images from the file paths
    images = [Image.open(file_path) for file_path in file_paths]
    image_content = []
    for img in images[:10]:  # Limit to 10 images
        image_content.append({
            "type": "image",
            "image": img,
        })
    image_content.append({"type": "text", "text": "Generate a story based on these images."})
    messages = [
        {
            "role": "system",
"content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use 🧠 tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]
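    # Render the chat template into a prompt string, then extract the
    # image/video inputs in the format the processor expects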
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(model.device)
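    # Sampled decoding: temperature/top-p trade determinism for narrative variety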
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
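    # generate() returns prompt + completion; strip the prompt tokens so only
    # the newly generated story is decoded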
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]
    return story
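# Gradio UI: uploads and controls in the left column, generated story on the right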
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Qwen Storyteller\n## Upload up to 10 images to generate a creative story.")
    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton("Upload up to 10 images", file_types=["image"], file_count="multiple")
            output_file = gr.File(label="Uploaded Files")
            gen_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            outputs = gr.Markdown(label="Generated Story", show_copy_button=True)
    with gr.Row():
        gr.Markdown(
            """
            ### Key Features
            * Cross-Frame Consistency: Maintains consistent character and object identity across multiple frames through visual similarity and face recognition techniques
            * Structured Reasoning: Employs chain-of-thought reasoning to analyze scenes with explicit modeling of characters, objects, settings, and narrative structure
            * Grounded Storytelling: Uses specialized XML tags to link narrative elements directly to visual entities
            * Reduced Hallucinations: Achieves 12.3% fewer hallucinations compared to the non-fine-tuned base model

            Model trained by daniel3303; [repository here](https://huggingface.co/daniel3303/QwenStoryteller).
            """
        )
        gr.Markdown(
            """
            ```
            @misc{oliveira2025storyreasoningdatasetusingchainofthought,
                title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation},
                author={Daniel A. P. Oliveira and David Martins de Matos},
                year={2025},
                eprint={2505.10292},
                archivePrefix={arXiv},
                primaryClass={cs.CV},
                url={https://arxiv.org/abs/2505.10292},
            }
            ```
            """
        )
    # Show the uploaded file names, then wire the button to story generation
    upload_button.upload(lambda files: [f.name for f in files], upload_button, output_file)
    gen_button.click(generate_story, upload_button, outputs)

if __name__ == "__main__":
    demo.queue().launch(show_error=True)