import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw

import utils.loss
import utils.samp
import utils.data
import utils.improc
import utils.misc
import utils.saveload
from nets.blocks import InputPadder

# Trade a little matmul precision for speed (TF32 on Ampere+ GPUs).
torch.set_float32_matmul_precision('medium')

def run_example(processor, model, task_prompt, image, text_input=None):
    """Run one prompted generation pass and return the parsed answer."""
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda', torch.float32)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height))

    return parsed_answer
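
# Sketch of how run_example() is typically driven (hypothetical checkpoint
# name and prompt; assumes a Florence-2-style processor/model pair):
#
#   from transformers import AutoModelForCausalLM, AutoProcessor
#   model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/Florence-2-large", trust_remote_code=True).to('cuda')
#   processor = AutoProcessor.from_pretrained(
#       "microsoft/Florence-2-large", trust_remote_code=True)
#   image = Image.open("frame0.jpg")
#   task = '<REFERRING_EXPRESSION_SEGMENTATION>'
#   parsed = run_example(processor, model, task, image, text_input="the red car")
#   mask = polygons_to_mask(image, parsed[task])  # helper defined below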


def polygons_to_mask(image, prediction, fill_value=255):
    """
    Converts polygons into a mask.
    
    Parameters:
    - image: A PIL Image instance whose size will be used for the mask.
    - prediction: Dictionary containing 'polygons' and 'labels'. 
                  'polygons' is a list where each element is a list of sub-polygons.
    - fill_value: The pixel value used to fill the polygon areas (default 255 for a binary mask).
    
    Returns:
    - A NumPy array representing the mask (same width and height as the input image).
    """
    # Create a blank grayscale mask image with the same size as the original image.
    mask = Image.new('L', image.size, 0)
    draw = ImageDraw.Draw(mask)
    
    # Iterate over each set of polygons
    for polygons in prediction['polygons']:
        # Each element in "polygons" can be a sub-polygon
        for poly in polygons:
            # Ensure the polygon is in the right shape and has at least 3 points.
            poly_arr = np.array(poly).reshape(-1, 2)
            if poly_arr.shape[0] < 3:
                print('Skipping invalid polygon:', poly_arr)
                continue
            # Convert the polygon vertices into a list for drawing.
            poly_list = poly_arr.reshape(-1).tolist()
            # Draw the polygon on the mask with the fill_value.
            draw.polygon(poly_list, fill=fill_value)
    
    # Convert the PIL mask image to a NumPy array and return it.
    return np.array(mask)
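
# Minimal sanity check for polygons_to_mask(), using a synthetic prediction
# dict in the flat [x0, y0, x1, y1, ...] vertex layout the helper accepts
# (the 'labels' key is unused by the helper but included for shape):
def _demo_polygons_to_mask():
    image = Image.new('RGB', (64, 64))
    prediction = {
        'polygons': [[[10, 10, 50, 10, 50, 50, 10, 50]]],  # one square sub-polygon
        'labels': ['example'],
    }
    mask = polygons_to_mask(image, prediction)
    assert mask.shape == (64, 64) and mask.max() == 255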


class Tracker:
    """Sliding-window tracker: feed frames one at a time via track(); dense
    flow and visibility confidence relative to the anchor (first) frame are
    emitted every `stride` frames once a full window of `S` frames is seen."""

    def __init__(self, model, mean, std, S, stride, inference_iters, target_res, device='cuda'):
        """
        Initializes the Tracker.

        Args:
            model: The model used to compute feature maps and forward window flow.
            mean: Tensor or value used for normalizing the input.
            std: Tensor or value used for normalizing the input.
            S: Window size for the tracker.
            stride: The stride used when updating the window.
            inference_iters: Number of inference iterations.
            device: Torch device, defaults to 'cuda'.
        """
        self.model = model
        self.mean = mean
        self.std = std
        self.S = S
        self.stride = stride
        self.inference_iters = inference_iters
        self.device = device
        self.target_res = target_res

        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []  # List to store computed flows
        self.visibs = []  # List to store visibility confidences
        self.rgbs = []  # List to store RGB frames

    def reset(self):
        """Reset the tracker state."""
        self.padder = None
        self.cnt = 0
        self.fmap_anchor = None
        self.fmaps2 = None
        self.flows8 = None
        self.visconfs8 = None
        self.flows = []
        self.visibs = []
        self.rgbs = []
        
    def preprocess(self, rgb_frame):
        # Resize so the larger dimension matches self.target_res.
        scale = min(self.target_res / rgb_frame.shape[0], self.target_res / rgb_frame.shape[1])
        rgb_resized = cv2.resize(rgb_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        # Convert to a (1, 3, H, W) float tensor in [0, 1] and move to device.
        rgb_tensor = torch.from_numpy(rgb_resized).permute(2, 0, 1).float().unsqueeze(0).to(self.device)
        rgb_tensor = rgb_tensor / 255.0

        self.rgbs.append(rgb_tensor.cpu())

        # Normalize with the statistics the model was trained with.
        rgb_tensor = (rgb_tensor - self.mean) / self.std
        return rgb_tensor

    @torch.no_grad()
    def track(self, rgb_frame):
        """
        Process a single RGB frame and return the computed flow when available.

        Args:
            rgb_frame: A NumPy array containing the RGB frame.
                       (Assumed to be in RGB; if coming from OpenCV, convert it before passing.)
        
        Returns:
            flow_predictions: The predicted flow for the current frame (or None if not enough frames have been processed).
        """
        torch.cuda.empty_cache()

        rgb_tensor = self.preprocess(rgb_frame)
        
        # Initialize padder on the first frame.
        if self.cnt == 0:
            self.padder = InputPadder(rgb_tensor.shape)
        rgb_padded = self.padder.pad(rgb_tensor)[0]
        _, _, H_pad, W_pad = rgb_padded.shape
        C = 256  # Feature map channel dimension (could be parameterized if needed)
        H8, W8 = H_pad // 8, W_pad // 8

        # On the first frame, compute the anchor feature map and start the window.
        if self.cnt == 0:
            self.fmap_anchor = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, C, H8, W8)
            self.fmaps2 = self.fmap_anchor[:, None]
            self.cnt += 1
            return None
            
        new_fmap = self.model.get_fmaps(rgb_padded, 1, 1, None, False, False).reshape(1, 1, C, H8, W8)
        # Slide the feature window: drop the oldest fmap once the window holds S frames.
        self.fmaps2 = torch.cat(
            [self.fmaps2[:, (1 if self.fmaps2.shape[1] >= self.S else 0):].detach().clone(), new_fmap],
            dim=1)

        # Run the tracker once the window is full and we hit a stride boundary.
        if self.cnt - self.S + 1 >= 0 and (self.cnt - self.S + 1) % self.stride == 0:
            # Initialize the 1/8-resolution flow/visibility buffers on the first
            # full window; afterwards warm-start them from the previous window.
            iter_num = self.inference_iters
            if self.flows8 is None:
                self.flows8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
                self.visconfs8 = torch.zeros((self.S, 2, H_pad // 8, W_pad // 8), device=self.device)
            else:
                self.flows8 = torch.cat([
                    self.flows8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.flows8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])
                self.visconfs8 = torch.cat([
                    self.visconfs8[self.stride:self.stride + self.S // 2].detach().clone(),
                    self.visconfs8[self.stride + self.S // 2 - 1:self.stride + self.S // 2].detach().clone().repeat(self.S // 2, 1, 1, 1)
                ])

            # Compute flow predictions using the model's forward window.
            flow_predictions, visconf_predictions, self.flows8, self.visconfs8, _ = self.model.forward_window(
                self.fmap_anchor,
                self.fmaps2,
                self.visconfs8,
                iters=iter_num,
                flowfeat=None,
                flows8=self.flows8,
                is_training=False
            )
            # On the first full window emit all S frames; afterwards emit only
            # the `stride` newest frames (earlier ones were already returned).
            flow_predictions = self.padder.unpad(flow_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:])
            visconf_predictions = self.padder.unpad(torch.sigmoid(visconf_predictions[-1][0 if self.cnt == self.S - 1 else -self.stride:]))

            self.cnt += 1
            self.flows.append(flow_predictions.cpu())
            self.visibs.append(visconf_predictions.cpu())

            return flow_predictions, visconf_predictions
        
        self.cnt += 1
        return None
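
# Sketch of a typical driving loop (hypothetical: `build_model`, the video
# path, window/stride values, and ImageNet mean/std are assumptions, not
# necessarily this repo's API):
#
#   model = build_model().cuda().eval()
#   mean = torch.tensor([0.485, 0.456, 0.406], device='cuda').reshape(1, 3, 1, 1)
#   std = torch.tensor([0.229, 0.224, 0.225], device='cuda').reshape(1, 3, 1, 1)
#   tracker = Tracker(model, mean, std, S=16, stride=8, inference_iters=4, target_res=1024)
#
#   cap = cv2.VideoCapture('video.mp4')
#   while True:
#       ok, frame_bgr = cap.read()
#       if not ok:
#           break
#       frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # track() expects RGB
#       out = tracker.track(frame_rgb)
#       if out is not None:
#           flow, visconf = out  # predictions for the newest frames of the window
#   cap.release()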