Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		zzzweakman
		
	commited on
		
		
					Commit 
							
							·
						
						0bc2c6f
	
1
								Parent(s):
							
							5379bd5
								
fix: retargeting feature leakage
Browse files- app.py +22 -11
- assets/gradio_description_retargeting.md +1 -1
- src/gradio_pipeline.py +22 -50
    	
        app.py
    CHANGED
    
    | @@ -72,7 +72,7 @@ data_examples = [ | |
| 72 | 
             
            # Define components first
         | 
| 73 | 
             
            eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
         | 
| 74 | 
             
            lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
         | 
| 75 | 
            -
            retargeting_input_image = gr.Image(type=" | 
| 76 | 
             
            output_image = gr.Image(type="numpy")
         | 
| 77 | 
             
            output_image_paste_back = gr.Image(type="numpy")
         | 
| 78 | 
             
            output_video = gr.Video()
         | 
| @@ -144,11 +144,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| 144 | 
             
                        examples_per_page=5,
         | 
| 145 | 
             
                        cache_examples=False,
         | 
| 146 | 
             
                    )
         | 
| 147 | 
            -
                gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible= | 
| 148 | 
            -
                with gr.Row(visible= | 
| 149 | 
             
                    eye_retargeting_slider.render()
         | 
| 150 | 
             
                    lip_retargeting_slider.render()
         | 
| 151 | 
            -
                with gr.Row(visible= | 
| 152 | 
             
                    process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
         | 
| 153 | 
             
                    process_button_reset_retargeting = gr.ClearButton(
         | 
| 154 | 
             
                        [
         | 
| @@ -160,10 +160,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| 160 | 
             
                        ],
         | 
| 161 | 
             
                        value="🧹 Clear"
         | 
| 162 | 
             
                    )
         | 
| 163 | 
            -
                with gr.Row(visible= | 
| 164 | 
             
                    with gr.Column():
         | 
| 165 | 
             
                        with gr.Accordion(open=True, label="Retargeting Input"):
         | 
| 166 | 
             
                            retargeting_input_image.render()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 167 | 
             
                    with gr.Column():
         | 
| 168 | 
             
                        with gr.Accordion(open=True, label="Retargeting Result"):
         | 
| 169 | 
             
                            output_image.render()
         | 
| @@ -174,7 +185,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| 174 | 
             
                process_button_retargeting.click(
         | 
| 175 | 
             
                    # fn=gradio_pipeline.execute_image,
         | 
| 176 | 
             
                    fn=gpu_wrapped_execute_image,
         | 
| 177 | 
            -
                    inputs=[eye_retargeting_slider, lip_retargeting_slider],
         | 
| 178 | 
             
                    outputs=[output_image, output_image_paste_back],
         | 
| 179 | 
             
                    show_progress=True
         | 
| 180 | 
             
                )
         | 
| @@ -190,11 +201,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| 190 | 
             
                    outputs=[output_video, output_video_concat],
         | 
| 191 | 
             
                    show_progress=True
         | 
| 192 | 
             
                )
         | 
| 193 | 
            -
                image_input.change(
         | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
             | 
| 197 | 
            -
                )
         | 
| 198 | 
             
                video_input.upload(
         | 
| 199 | 
             
                    fn=is_square_video,
         | 
| 200 | 
             
                    inputs=video_input,
         | 
|  | |
| 72 | 
             
            # Define components first
         | 
| 73 | 
             
            eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
         | 
| 74 | 
             
            lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
         | 
| 75 | 
            +
            retargeting_input_image = gr.Image(type="filepath")
         | 
| 76 | 
             
            output_image = gr.Image(type="numpy")
         | 
| 77 | 
             
            output_image_paste_back = gr.Image(type="numpy")
         | 
| 78 | 
             
            output_video = gr.Video()
         | 
|  | |
| 144 | 
             
                        examples_per_page=5,
         | 
| 145 | 
             
                        cache_examples=False,
         | 
| 146 | 
             
                    )
         | 
| 147 | 
            +
                gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
         | 
| 148 | 
            +
                with gr.Row(visible=True):
         | 
| 149 | 
             
                    eye_retargeting_slider.render()
         | 
| 150 | 
             
                    lip_retargeting_slider.render()
         | 
| 151 | 
            +
                with gr.Row(visible=True):
         | 
| 152 | 
             
                    process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
         | 
| 153 | 
             
                    process_button_reset_retargeting = gr.ClearButton(
         | 
| 154 | 
             
                        [
         | 
|  | |
| 160 | 
             
                        ],
         | 
| 161 | 
             
                        value="🧹 Clear"
         | 
| 162 | 
             
                    )
         | 
| 163 | 
            +
                with gr.Row(visible=True):
         | 
| 164 | 
             
                    with gr.Column():
         | 
| 165 | 
             
                        with gr.Accordion(open=True, label="Retargeting Input"):
         | 
| 166 | 
             
                            retargeting_input_image.render()
         | 
| 167 | 
            +
                            gr.Examples(
         | 
| 168 | 
            +
                                examples=[
         | 
| 169 | 
            +
                                    [osp.join(example_portrait_dir, "s9.jpg")],
         | 
| 170 | 
            +
                                    [osp.join(example_portrait_dir, "s6.jpg")],
         | 
| 171 | 
            +
                                    [osp.join(example_portrait_dir, "s10.jpg")],
         | 
| 172 | 
            +
                                    [osp.join(example_portrait_dir, "s5.jpg")],
         | 
| 173 | 
            +
                                    [osp.join(example_portrait_dir, "s7.jpg")],
         | 
| 174 | 
            +
                                ],
         | 
| 175 | 
            +
                                inputs=[retargeting_input_image],
         | 
| 176 | 
            +
                                cache_examples=False,
         | 
| 177 | 
            +
                            )
         | 
| 178 | 
             
                    with gr.Column():
         | 
| 179 | 
             
                        with gr.Accordion(open=True, label="Retargeting Result"):
         | 
| 180 | 
             
                            output_image.render()
         | 
|  | |
| 185 | 
             
                process_button_retargeting.click(
         | 
| 186 | 
             
                    # fn=gradio_pipeline.execute_image,
         | 
| 187 | 
             
                    fn=gpu_wrapped_execute_image,
         | 
| 188 | 
            +
                    inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
         | 
| 189 | 
             
                    outputs=[output_image, output_image_paste_back],
         | 
| 190 | 
             
                    show_progress=True
         | 
| 191 | 
             
                )
         | 
|  | |
| 201 | 
             
                    outputs=[output_video, output_video_concat],
         | 
| 202 | 
             
                    show_progress=True
         | 
| 203 | 
             
                )
         | 
| 204 | 
            +
                # image_input.change(
         | 
| 205 | 
            +
                #     fn=gradio_pipeline.prepare_retargeting,
         | 
| 206 | 
            +
                #     inputs=image_input,
         | 
| 207 | 
            +
                #     outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
         | 
| 208 | 
            +
                # )
         | 
| 209 | 
             
                video_input.upload(
         | 
| 210 | 
             
                    fn=is_square_video,
         | 
| 211 | 
             
                    inputs=video_input,
         | 
    	
        assets/gradio_description_retargeting.md
    CHANGED
    
    | @@ -1 +1 @@ | |
| 1 | 
            -
            <span style="font-size: 1.2em;">🔥 To change the  | 
|  | |
| 1 | 
            +
            <span style="font-size: 1.2em;">🔥 To change the eyes and lip open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the blocks. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
         | 
    	
        src/gradio_pipeline.py
    CHANGED
    
    | @@ -26,16 +26,6 @@ class GradioPipeline(LivePortraitPipeline): | |
| 26 | 
             
                    super().__init__(inference_cfg, crop_cfg)
         | 
| 27 | 
             
                    # self.live_portrait_wrapper = self.live_portrait_wrapper
         | 
| 28 | 
             
                    self.args = args
         | 
| 29 | 
            -
                    # for single image retargeting
         | 
| 30 | 
            -
                    self.start_prepare = False
         | 
| 31 | 
            -
                    self.f_s_user = None
         | 
| 32 | 
            -
                    self.x_c_s_info_user = None
         | 
| 33 | 
            -
                    self.x_s_user = None
         | 
| 34 | 
            -
                    self.source_lmk_user = None
         | 
| 35 | 
            -
                    self.mask_ori = None
         | 
| 36 | 
            -
                    self.img_rgb = None
         | 
| 37 | 
            -
                    self.crop_M_c2o = None
         | 
| 38 | 
            -
             | 
| 39 |  | 
| 40 | 
             
                def execute_video(
         | 
| 41 | 
             
                    self,
         | 
| @@ -66,30 +56,23 @@ class GradioPipeline(LivePortraitPipeline): | |
| 66 | 
             
                    else:
         | 
| 67 | 
             
                        raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
         | 
| 68 |  | 
| 69 | 
            -
                def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         | 
| 70 | 
             
                    """ for single image retargeting
         | 
| 71 | 
             
                    """
         | 
|  | |
|  | |
|  | |
|  | |
| 72 | 
             
                    if input_eye_ratio is None or input_eye_ratio is None:
         | 
| 73 | 
             
                        raise gr.Error("Invalid ratio input 💥!", duration=5)
         | 
| 74 | 
            -
                    elif self.f_s_user is None:
         | 
| 75 | 
            -
                        if self.start_prepare:
         | 
| 76 | 
            -
                            raise gr.Error(
         | 
| 77 | 
            -
                                "The source portrait is under processing 💥! Please wait for a second.",
         | 
| 78 | 
            -
                                duration=5
         | 
| 79 | 
            -
                            )
         | 
| 80 | 
            -
                        else:
         | 
| 81 | 
            -
                            raise gr.Error(
         | 
| 82 | 
            -
                                "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
         | 
| 83 | 
            -
                                duration=5
         | 
| 84 | 
            -
                            )
         | 
| 85 | 
             
                    else:
         | 
| 86 | 
            -
                        x_s_user =  | 
| 87 | 
            -
                        f_s_user =  | 
| 88 | 
             
                        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
         | 
| 89 | 
            -
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]],  | 
| 90 | 
             
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
         | 
| 91 | 
             
                        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
         | 
| 92 | 
            -
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]],  | 
| 93 | 
             
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
         | 
| 94 | 
             
                        num_kp = x_s_user.shape[1]
         | 
| 95 | 
             
                        # default: use x_s
         | 
| @@ -97,21 +80,20 @@ class GradioPipeline(LivePortraitPipeline): | |
| 97 | 
             
                        # D(W(f_s; x_s, x′_d))
         | 
| 98 | 
             
                        out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
         | 
| 99 | 
             
                        out = self.live_portrait_wrapper.parse_output(out['out'])[0]
         | 
| 100 | 
            -
                        out_to_ori_blend = paste_back(out,  | 
| 101 | 
             
                        # gr.Info("Run successfully!", duration=2)
         | 
| 102 | 
             
                        return out, out_to_ori_blend
         | 
| 103 |  | 
| 104 |  | 
| 105 | 
            -
                def prepare_retargeting(self,  | 
| 106 | 
             
                    """ for single image retargeting
         | 
| 107 | 
             
                    """
         | 
| 108 | 
            -
                    if  | 
| 109 | 
             
                        # gr.Info("Upload successfully!", duration=2)
         | 
| 110 | 
            -
                        self.start_prepare = True
         | 
| 111 | 
             
                        inference_cfg = self.live_portrait_wrapper.cfg
         | 
| 112 | 
             
                        ######## process source portrait ########
         | 
| 113 | 
            -
                        img_rgb = load_img_online( | 
| 114 | 
            -
                        log(f"Load source image from { | 
| 115 | 
             
                        crop_info = self.cropper.crop_single_image(img_rgb)
         | 
| 116 | 
             
                        if flag_do_crop:
         | 
| 117 | 
             
                            I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
         | 
| @@ -120,23 +102,13 @@ class GradioPipeline(LivePortraitPipeline): | |
| 120 | 
             
                        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
         | 
| 121 | 
             
                        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
         | 
| 122 | 
             
                        ############################################
         | 
| 123 | 
            -
             | 
| 124 | 
            -
                         | 
| 125 | 
            -
                         | 
| 126 | 
            -
                         | 
| 127 | 
            -
                         | 
| 128 | 
            -
                         | 
| 129 | 
            -
                        self.img_rgb = img_rgb
         | 
| 130 | 
            -
                        self.crop_M_c2o = crop_info['M_c2o']
         | 
| 131 | 
            -
                        self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
         | 
| 132 | 
            -
                        # update slider
         | 
| 133 | 
            -
                        eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
         | 
| 134 | 
            -
                        eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
         | 
| 135 | 
            -
                        lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
         | 
| 136 | 
            -
                        lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
         | 
| 137 | 
            -
                        # for vis
         | 
| 138 | 
            -
                        self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
         | 
| 139 | 
            -
                        return eye_close_ratio, lip_close_ratio, self.I_s_vis
         | 
| 140 | 
             
                    else:
         | 
| 141 | 
             
                        # when press the clear button, go here
         | 
| 142 | 
            -
                         | 
|  | 
|  | |
| 26 | 
             
                    super().__init__(inference_cfg, crop_cfg)
         | 
| 27 | 
             
                    # self.live_portrait_wrapper = self.live_portrait_wrapper
         | 
| 28 | 
             
                    self.args = args
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 29 |  | 
| 30 | 
             
                def execute_video(
         | 
| 31 | 
             
                    self,
         | 
|  | |
| 56 | 
             
                    else:
         | 
| 57 | 
             
                        raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
         | 
| 58 |  | 
| 59 | 
            +
                def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop = True):
         | 
| 60 | 
             
                    """ for single image retargeting
         | 
| 61 | 
             
                    """
         | 
| 62 | 
            +
                    # disposable feature
         | 
| 63 | 
            +
                    f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
         | 
| 64 | 
            +
                    self.prepare_retargeting(input_image, flag_do_crop)
         | 
| 65 | 
            +
             | 
| 66 | 
             
                    if input_eye_ratio is None or input_eye_ratio is None:
         | 
| 67 | 
             
                        raise gr.Error("Invalid ratio input 💥!", duration=5)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 | 
             
                    else:
         | 
| 69 | 
            +
                        x_s_user = x_s_user.to("cuda")
         | 
| 70 | 
            +
                        f_s_user = f_s_user.to("cuda")
         | 
| 71 | 
             
                        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
         | 
| 72 | 
            +
                        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
         | 
| 73 | 
             
                        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
         | 
| 74 | 
             
                        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
         | 
| 75 | 
            +
                        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
         | 
| 76 | 
             
                        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
         | 
| 77 | 
             
                        num_kp = x_s_user.shape[1]
         | 
| 78 | 
             
                        # default: use x_s
         | 
|  | |
| 80 | 
             
                        # D(W(f_s; x_s, x′_d))
         | 
| 81 | 
             
                        out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
         | 
| 82 | 
             
                        out = self.live_portrait_wrapper.parse_output(out['out'])[0]
         | 
| 83 | 
            +
                        out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
         | 
| 84 | 
             
                        # gr.Info("Run successfully!", duration=2)
         | 
| 85 | 
             
                        return out, out_to_ori_blend
         | 
| 86 |  | 
| 87 |  | 
| 88 | 
            +
                def prepare_retargeting(self, input_image, flag_do_crop = True):
         | 
| 89 | 
             
                    """ for single image retargeting
         | 
| 90 | 
             
                    """
         | 
| 91 | 
            +
                    if input_image is not None:
         | 
| 92 | 
             
                        # gr.Info("Upload successfully!", duration=2)
         | 
|  | |
| 93 | 
             
                        inference_cfg = self.live_portrait_wrapper.cfg
         | 
| 94 | 
             
                        ######## process source portrait ########
         | 
| 95 | 
            +
                        img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
         | 
| 96 | 
            +
                        log(f"Load source image from {input_image}.")
         | 
| 97 | 
             
                        crop_info = self.cropper.crop_single_image(img_rgb)
         | 
| 98 | 
             
                        if flag_do_crop:
         | 
| 99 | 
             
                            I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
         | 
|  | |
| 102 | 
             
                        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
         | 
| 103 | 
             
                        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
         | 
| 104 | 
             
                        ############################################
         | 
| 105 | 
            +
                        f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
         | 
| 106 | 
            +
                        x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
         | 
| 107 | 
            +
                        source_lmk_user = crop_info['lmk_crop']
         | 
| 108 | 
            +
                        crop_M_c2o = crop_info['M_c2o']
         | 
| 109 | 
            +
                        mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
         | 
| 110 | 
            +
                        return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 111 | 
             
                    else:
         | 
| 112 | 
             
                        # when press the clear button, go here
         | 
| 113 | 
            +
                        raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
         | 
| 114 | 
            +
             |