import torch
import os
import numpy as np
import math
import decord
from tqdm import tqdm
import pathlib
from PIL import Image

from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.hunyuan import vae_encode, vae_decode
from .video_base_generator import VideoBaseModelGenerator

class VideoModelGenerator(VideoBaseModelGenerator):
    """
    Generator for the Video (backward) extension of the Original HunyuanVideo model.
    These generators accept video input instead of a single image.
    """
    
    def __init__(self, **kwargs):
        """
        Initialize the Video model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "Video"
        self.model_path = 'lllyasviel/FramePackI2V_HY'  # Same as Original
        self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"
    
    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the Video model.
        
        Args:
            total_latent_sections: The total number of latent sections
            
        Returns:
            A list of latent paddings
        """
        # Video model uses reversed latent paddings like Original
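        # e.g. total_latent_sections == 6 -> [3, 2, 2, 2, 1, 0]; total_latent_sections == 3 -> [2, 1, 0]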
        if total_latent_sections > 4:
            return [3] + [2] * (total_latent_sections - 3) + [1, 0]
        else:
            return list(reversed(range(total_latent_sections)))

    def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
        """
        Combined method to prepare clean latents and indices for the Video model.
        
        Args:
            end_frame_output_dimensions_latent: VAE-encoded latent of the optional end frame (already at the
                output dimensions), or None if no end frame was provided.
            end_frame_weight: Blend weight between the input video's last-frame CLIP embedding and the end
                frame's CLIP embedding (higher values pull the generation toward the end frame).
            end_clip_embedding: CLIP vision embedding of the optional end frame.
            end_of_input_video_embedding: CLIP vision embedding of the last frame of the input video.
            latent_paddings: The full padding schedule (see get_latent_paddings). Work in progress - ideally
                this and latent_padding would not be passed in directly.
            latent_padding: The padding value for the current section.
            latent_padding_size: Number of blank latent frames before the generated section (backward generation).
            latent_window_size: Number of new latent frames generated per section.
            video_latents: VAE-encoded latents of the input video.
            history_latents: Latents of the sections generated so far.
            num_cleaned_frames: Number of context frames to use from the video (adherence to video)
            
        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
        """
        # Use the provided num_cleaned_frames if given, otherwise fall back to the default of 5
        num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5


        # HACK SOME STUFF IN THAT SHOULD NOT BE HERE
        # Placeholders for end frame processing
        # Colin, I'm only leaving them for the moment in case you want separate models for
        # Video-backward and Video-backward-Endframe.
        # end_latent = None
        # end_of_input_video_embedding = None # Placeholder for end frame's CLIP embedding. SEE: 20250507 pftq: Process end frame if provided
        # end_clip_embedding = None # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
        # end_frame_weight = 0.0 # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
        # HACK MORE STUFF IN THAT PROBABLY SHOULD BE ARGUMENTS OR OTHERWISE MADE AVAILABLE
        end_of_input_video_latent = video_latents[:, :, -1:] # Last frame of the input video (produced by video_encode in the PR)
        is_start_of_video = latent_padding == 0 # This refers to the start of the *generated* video part
        is_end_of_video = latent_padding == latent_paddings[0] # This refers to the end of the *generated* video part (closest to input video) (better not to pass in latent_paddings[])
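        # For illustration: with latent_paddings == [3, 2, 2, 1, 0], the iteration with latent_padding == 3
        # is flagged is_end_of_video and the final iteration with latent_padding == 0 is flagged is_start_of_video.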
        # End of HACK STUFF

        # Dynamic frame allocation for context frames (clean latents)
        # This determines which frames from history_latents are used as input for the transformer.
        available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2] # Use input video frames for first segment, else previously generated history
        effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
        if is_start_of_video:
            effective_clean_frames = 1 # Avoid jumpcuts if input video is too different

        clean_latent_pre_frames = effective_clean_frames 
        num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
        num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
        total_context_frames = num_2x_frames + num_4x_frames
        total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)
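        # Worked example (assumed values): num_clean_frames == 5 and available_frames == 20 (not start of video)
        # give effective_clean_frames == 4, clean_latent_pre_frames == 4, num_2x_frames == 2,
        # num_4x_frames == 14 and total_context_frames == 16.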

        # Prepare indices for the transformer's input (these define the *relative positions* of different frame types in the input tensor)
        # The total length is the sum of various frame types:
        # clean_latent_pre_frames: frames before the blank/generated section
        # latent_padding_size: blank frames before the generated section (for backward generation)
        # latent_window_size: the new frames to be generated
        # post_frames: frames after the generated section
        # num_2x_frames, num_4x_frames: frames for lower resolution context
        # 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post
        post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames  # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image
        indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
        clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
            [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
        )
        clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) # Combined indices for 1x clean latents
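        # Index layout example (hypothetical sizes): clean_latent_pre_frames == 4, latent_padding_size == 6,
        # latent_window_size == 9, post_frames == 4, num_2x_frames == 2, num_4x_frames == 14 give a position
        # tensor of length 39 split into blocks [0..3 | 4..9 | 10..18 | 19..22 | 23..24 | 25..38].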

        # Prepare the *actual latent data* for the transformer's context inputs
        # These are extracted from history_latents (or video_latents for the first segment)
        context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]
        # clean_latents_4x: 4x downsampled context frames. From history_latents (or video_latents).
        # clean_latents_2x: 2x downsampled context frames. From history_latents (or video_latents).
        split_sizes = [num_4x_frames, num_2x_frames]
        split_sizes = [s for s in split_sizes if s > 0]
        if split_sizes and context_frames.shape[2] >= sum(split_sizes):
            splits = context_frames.split(split_sizes, dim=2)
            split_idx = 0
            clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
            split_idx += 1 if num_4x_frames > 0 else 0
            clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
        else:
            clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]
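        # Continuing the example above: the 16 context_frames split into the first 14 frames (clean_latents_4x)
        # and the last 2 frames (clean_latents_2x), matching split_sizes == [num_4x_frames, num_2x_frames].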

        # clean_latents_pre: Latents from the *end* of the input video (if is_start_of_video), or previously generated history.
        # Its purpose is to provide a smooth transition *from* the input video.
        clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents)
        
        # clean_latents_post: Latents from the *beginning* of the previously generated video segments.
        # Its purpose is to provide a smooth transition *to* the existing generated content.
        clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :]
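        # e.g. with effective_clean_frames == 4: clean_latents_pre is the last 4 frames of video_latents and
        # clean_latents_post is the first 4 frames of history_latents (fewer if that many are not available).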

        # Special handling for the end frame:
        # If this is the very first segment being generated (is_end_of_video in terms of generation order),
        # zero out clean_latents_post so stale history does not bias the end of the clip; if an end_latent
        # was provided, it replaces clean_latents_post below.
        if is_end_of_video:
            clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents) # Initialize to zero
        
        # RT_BORG: end_of_input_video_embedding and end_clip_embedding shouldn't need to be checked, since they should
        # always be provided if end_latent is provided. But bulletproofing before the release since test time will be short.
        if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None:
            # image_encoder_last_hidden_state: Weighted average of CLIP embedding of first input frame and end frame's CLIP embedding
            # This guides the overall content to transition towards the end frame.
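            # For illustration: end_frame_weight == 0.3 yields 0.7 * end_of_input_video_embedding + 0.3 * end_clip_embedding.
            # Note: as currently written, the blended image_encoder_last_hidden_state is computed here but not returned.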
            image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight
            image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)
            
            if is_end_of_video:
                # For the very first generated segment, the "post" part is the end_latent itself.
                clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :] # Ensure single frame
            
        # Pad clean_latents_pre/post if they have fewer frames than specified by clean_latent_pre_frames/post_frames
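        # e.g. 2 available frames but clean_latent_pre_frames == 5: repeat 3x along the time dim (6 frames), then truncate to 5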
        if clean_latents_pre.shape[2] < clean_latent_pre_frames:
            clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:,:,:clean_latent_pre_frames]
        if clean_latents_post.shape[2] < post_frames:
            clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:,:,:post_frames]
            
        # clean_latents: Concatenation of pre and post clean latents. These are the 1x resolution context frames.
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x
    
    def update_history_latents(self, history_latents, generated_latents):
        """
        Backward Generation: Update the history latents with the generated latents for the Video model.
        
        Args:
            history_latents: The history latents
            generated_latents: The generated latents
            
        Returns:
            The updated history latents
        """
        # For Video model, we prepend the generated latents to the front of history latents
        # This matches the original implementation in video-example.py
        # It generates new sections backwards in time, chunk by chunk
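        # e.g. generated (B, C, T_new, H, W) prepended to history (B, C, T_hist, H, W) -> (B, C, T_new + T_hist, H, W),
        # so the newest-generated (earliest-in-time) frames always sit at index 0 along the time dimension.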
        return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
    
    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the backward Video model. For Video, this is the first
        `total_generated_latent_frames` frames of the history latents.
        
        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames
            
        Returns:
            The real history latents
        """
        # Generated frames at the front.
        return history_latents[:, :, :total_generated_latent_frames, :, :]
    
    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the Video model.
        
        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames
            
        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw
        # For Video model, we prepend the current pixels to the history pixels
        # This matches the original implementation in video-example.py
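        # Assumed behavior of soft_append_bcthw (FramePack helper): it crossfades the last `overlapped_frames`
        # frames of current_pixels into the start of history_pixels, keeping the prepend seamless.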
        return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
    
    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the Video model.
        
        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model
            
        Returns:
            The current pixels
        """
        # For backward Video mode, current pixels are at the front of history.
        return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
    
    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the Video model.
        
        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds (includes input video time)
            original_pos: The original position in seconds
            current_prompt: The current prompt
            
        Returns:
            The formatted position description
        """
        # For Video model, current_pos already includes the input video time
        # We just need to display the total generated frames and the current position
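        # e.g. total_generated_latent_frames == 33 -> 33 * 4 - 3 = 129 decoded frames, ~4.30 s at 30 FPS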
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (at 30 FPS). '
                f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). '
                f'Using prompt: {current_prompt[:256]}...')
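
# Minimal sketch (hypothetical driver code, not part of this module): how the backward-generation hooks
# above are typically chained per section. `sample_section`, `vae`, and the loop-local variables are
# assumptions for illustration only.
#
#   generator = VideoModelGenerator(...)
#   for latent_padding in generator.get_latent_paddings(total_latent_sections):
#       generated_latents = sample_section(...)  # (B, C, T_new, H, W) latents for this section
#       total_generated_latent_frames += generated_latents.shape[2]
#       history_latents = generator.update_history_latents(history_latents, generated_latents)
#       real_history = generator.get_real_history_latents(history_latents, total_generated_latent_frames)
#       current_pixels = generator.get_current_pixels(real_history, section_latent_frames, vae)
#       history_pixels = generator.update_history_pixels(history_pixels, current_pixels, overlapped_frames)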