import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch

# Load the model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "daniel3303/QwenStoryteller",
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")


@spaces.GPU()
@torch.no_grad()
def generate_story(file_paths, progress=gr.Progress(track_tqdm=True)):
    # Load images from the uploaded file paths
    images = [Image.open(file_path) for file_path in file_paths]

    image_content = []
    for img in images[:10]:  # Limit to 10 images
        image_content.append({
            "type": "image",
            "image": img,
        })

    image_content.append({"type": "text", "text": "Generate a story based on these images."})

    messages = [
        {
            "role": "system",
            "content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use <think></think> tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]

    # Build the chat-formatted prompt and collect the vision inputs
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Strip the prompt tokens so only the newly generated story is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return story
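
# For reference, a minimal sketch of calling the generation step directly,
# without the Gradio UI. The image paths below are hypothetical placeholders;
# any local image files would do:
#
#   story = generate_story(["frame1.jpg", "frame2.jpg", "frame3.jpg"])
#   print(story)
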
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# Qwen Storyteller \n## Upload up to 10 images to generate a creative story.")
    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton(
                "Upload up to 10 images", file_types=["image"], file_count="multiple"
            )
            output_file = gr.File(label="Uploaded Files")
            gen_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            outputs = gr.Markdown(label="Generated Story", show_copy_button=True)
    with gr.Row():
        gr.Markdown(
            """
            ### Key Features
            * Cross-Frame Consistency: Maintains consistent character and object identity across multiple frames through visual similarity and face recognition techniques
            * Structured Reasoning: Employs chain-of-thought reasoning to analyze scenes with explicit modeling of characters, objects, settings, and narrative structure
            * Grounded Storytelling: Uses specialized XML tags to link narrative elements directly to visual entities
            * Reduced Hallucinations: Achieves 12.3% fewer hallucinations compared to the non-fine-tuned base model

            Model trained by daniel3303, [repository here](https://huggingface.co/daniel3303/QwenStoryteller).
            """
        )
        gr.Markdown(
            """
            ```
            @misc{oliveira2025storyreasoningdatasetusingchainofthought,
                title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation},
                author={Daniel A. P. Oliveira and David Martins de Matos},
                year={2025},
                eprint={2505.10292},
                archivePrefix={arXiv},
                primaryClass={cs.CV},
                url={https://arxiv.org/abs/2505.10292},
            }
            ```
            """
        )

    upload_button.upload(lambda files: [f.name for f in files], upload_button, output_file)
    gen_button.click(generate_story, upload_button, outputs)

if __name__ == "__main__":
    demo.queue().launch(show_error=True)
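
# Note: a sketch of alternative launch settings (the values here are
# illustrative assumptions, not requirements), e.g. to cap the request
# queue or bind a fixed port:
#
#   demo.queue(max_size=16).launch(show_error=True, server_port=7860)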