# Hugging Face Space (status when captured: Running)
| import spaces | |
| import gradio as gr | |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| from PIL import Image | |
| import torch | |
| import os, time | |
| import logging | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub repository providing both the model weights and the matching
# processor. Keeping it in one constant guarantees the two are always loaded
# from the same checkpoint. (An earlier revision of this app used
# "daniel3303/QwenStoryteller"; change the ID here to switch back.)
MODEL_ID = "daniel3303/QwenStoryteller2"

# Load the model and processor once at import time so every request reuses them.
# device_map="auto" delegates weight placement to accelerate.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Maximum number of uploaded images forwarded to the model; extra uploads are
# ignored. The UI text ("Upload up to 10 images") advertises the same limit.
MAX_IMAGES = 10

# System prompt: asks the model to reason inside <think></think> tags before
# writing the final story.
_SYSTEM_PROMPT = """You are an AI storyteller that can analyze sequences of images and create creative narratives.
First think step-by-step to analyze characters, objects, settings, and narrative structure.
Then create a grounded story that maintains consistent character identity and object references across frames.
Use <think></think> tags to show your reasoning process before writing the final story.
"""


def _load_images(file_paths, limit=MAX_IMAGES):
    """Open at most ``limit`` images from ``file_paths`` as in-memory copies.

    Each file is opened inside a ``with`` block and duplicated via
    ``Image.copy()``, which loads the pixel data, so the file handle is closed
    before this function returns. Paths beyond ``limit`` are never opened.
    """
    images = []
    for path in list(file_paths)[:limit]:
        with Image.open(path) as img:
            images.append(img.copy())
    return images


def _build_messages(images):
    """Return the chat messages: system prompt, then the images and an instruction."""
    user_content = [{"type": "image", "image": img} for img in images]
    user_content.append({"type": "text", "text": "Generate a story based on these images."})
    return [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]


# Requests GPU time on ZeroGPU hardware; the decorator has no effect elsewhere.
@spaces.GPU
def generate_story(file_paths, progress=gr.Progress(track_tqdm=True)):
    """Generate a story for the uploaded images.

    Args:
        file_paths: Paths of the uploaded image files, as passed by the
            ``gr.UploadButton`` input (presumably ``None`` before any upload).
            Only the first ``MAX_IMAGES`` are used.
        progress: Gradio progress tracker with ``track_tqdm=True``; kept for
            Gradio's progress integration and not referenced directly.

    Returns:
        The decoded model continuation with special tokens removed. Because the
        system prompt requests ``<think></think>`` reasoning, the text may
        contain that reasoning before the story itself.

    Raises:
        gr.Error: if no files were uploaded.
    """
    if not file_paths:
        raise gr.Error("Please upload at least one image.")

    images = _load_images(file_paths)
    logger.info("Generating story from %d image(s)", len(images))
    messages = _build_messages(images)

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# Page title and usage hint shown at the top of the app.
_HEADER_MD = "# Qwen Storyteller \n## Upload up to 10 images to generate a creative story."

# BibTeX citation for the StoryReasoning paper, rendered as a code fence.
_CITATION_MD = """
\n\n
```
@misc{oliveira2025storyreasoningdatasetusingchainofthought,
title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation},
author={Daniel A. P. Oliveira and David Martins de Matos},
year={2025},
eprint={2505.10292},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2505.10292},
}
```
"""

# Feature summary and model link shown below the main row.
_FEATURES_MD = """
### Key Features
* Cross-Frame Consistency: Maintains consistent character and object identity across multiple frames through visual similarity and face recognition techniques
* Structured Reasoning: Employs chain-of-thought reasoning to analyze scenes with explicit modeling of characters, objects, settings, and narrative structure
* Grounded Storytelling: Uses specialized XML tags to link narrative elements directly to visual entities
* Reduced Hallucinations: Achieves 12.3% fewer hallucinations compared to the non-fine-tuned base model
Model trained by daniel3303, [repository here.](https://huggingface.co/daniel3303/QwenStoryteller2)
"""

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Narrow column: upload control, list of uploaded files, trigger, citation.
        with gr.Column(scale=1):
            upload_button = gr.UploadButton(
                "Upload up to 10 images",
                file_types=["image"],
                file_count="multiple",
            )
            output_file = gr.File(label="Uploaded Files")
            gen_button = gr.Button("Generate", variant="primary")
            gr.Markdown(_CITATION_MD)
        # Wide column: the generated story.
        with gr.Column(scale=2):
            outputs = gr.Markdown(label="Generated Story", show_copy_button=True)

    with gr.Row():
        gr.Markdown(_FEATURES_MD)

    # Mirror the uploaded files into the File widget by their `.name` paths.
    upload_button.upload(
        fn=lambda files: [f.name for f in files],
        inputs=upload_button,
        outputs=output_file,
    )
    # The story generator reads the UploadButton's value directly.
    gen_button.click(fn=generate_story, inputs=upload_button, outputs=outputs)
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the server with show_error=True.
    queued_app = demo.queue()
    queued_app.launch(show_error=True)