import json import os import cv2 import gradio as gr import imagehash import numpy as np import plotly.graph_objects as go from gradio_imageslider import ImageSlider from PIL import Image from scipy.stats import pearsonr from skimage.metrics import mean_squared_error as mse_skimage from skimage.metrics import peak_signal_noise_ratio as psnr_skimage from skimage.metrics import structural_similarity as ssim class FrameMetrics: """Class to compute and store frame-by-frame metrics""" def __init__(self): self.metrics = {} def compute_ssim(self, frame1, frame2): """Compute SSIM between two frames""" if frame1 is None or frame2 is None: return None try: # Convert to grayscale for SSIM computation gray1 = ( cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY) if len(frame1.shape) == 3 else frame1 ) gray2 = ( cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY) if len(frame2.shape) == 3 else frame2 ) # Ensure both frames have the same dimensions if gray1.shape != gray2.shape: # Resize to match the smaller dimension h = min(gray1.shape[0], gray2.shape[0]) w = min(gray1.shape[1], gray2.shape[1]) gray1 = cv2.resize(gray1, (w, h)) gray2 = cv2.resize(gray2, (w, h)) # Compute SSIM ssim_value = ssim(gray1, gray2, data_range=255) return ssim_value except Exception as e: print(f"SSIM computation failed: {e}") return None def compute_ms_ssim(self, frame1, frame2): """Compute Multi-Scale SSIM between two frames""" if frame1 is None or frame2 is None: return None try: # Convert to grayscale for MS-SSIM computation gray1 = ( cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY) if len(frame1.shape) == 3 else frame1 ) gray2 = ( cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY) if len(frame2.shape) == 3 else frame2 ) # Ensure both frames have the same dimensions if gray1.shape != gray2.shape: h = min(gray1.shape[0], gray2.shape[0]) w = min(gray1.shape[1], gray2.shape[1]) gray1 = cv2.resize(gray1, (w, h)) gray2 = cv2.resize(gray2, (w, h)) # Ensure minimum size for multi-scale analysis min_size = 32 if min(gray1.shape) < min_size: return None # Compute MS-SSIM using multiple scales from skimage.metrics import structural_similarity # Use win_size that works with image dimensions win_size = min(7, min(gray1.shape) // 4) if win_size < 3: win_size = 3 ms_ssim_val = structural_similarity( gray1, gray2, data_range=255, win_size=win_size, multichannel=False ) return ms_ssim_val except Exception as e: print(f"MS-SSIM computation failed: {e}") return None def compute_psnr(self, frame1, frame2): """Compute PSNR between two frames""" if frame1 is None or frame2 is None: return None try: # Ensure both frames have the same dimensions if frame1.shape != frame2.shape: h = min(frame1.shape[0], frame2.shape[0]) w = min(frame1.shape[1], frame2.shape[1]) c = ( min(frame1.shape[2], frame2.shape[2]) if len(frame1.shape) == 3 else 1 ) if len(frame1.shape) == 3: frame1 = cv2.resize(frame1, (w, h))[:, :, :c] frame2 = cv2.resize(frame2, (w, h))[:, :, :c] else: frame1 = cv2.resize(frame1, (w, h)) frame2 = cv2.resize(frame2, (w, h)) # Compute PSNR return psnr_skimage(frame1, frame2, data_range=255) except Exception as e: print(f"PSNR computation failed: {e}") return None def compute_mse(self, frame1, frame2): """Compute MSE between two frames""" if frame1 is None or frame2 is None: return None try: # Ensure both frames have the same dimensions if frame1.shape != frame2.shape: h = min(frame1.shape[0], frame2.shape[0]) w = min(frame1.shape[1], frame2.shape[1]) c = ( min(frame1.shape[2], frame2.shape[2]) if len(frame1.shape) == 3 else 1 ) if len(frame1.shape) == 3: frame1 = cv2.resize(frame1, (w, h))[:, :, :c] frame2 = cv2.resize(frame2, (w, h))[:, :, :c] else: frame1 = cv2.resize(frame1, (w, h)) frame2 = cv2.resize(frame2, (w, h)) # Compute MSE return mse_skimage(frame1, frame2) except Exception as e: print(f"MSE computation failed: {e}") return None def compute_phash(self, frame1, frame2): """Compute perceptual hash similarity between two frames""" if frame1 is None or frame2 is None: return None try: # Convert to PIL Images for imagehash pil1 = Image.fromarray(frame1) pil2 = Image.fromarray(frame2) # Compute perceptual hashes hash1 = imagehash.phash(pil1) hash2 = imagehash.phash(pil2) # Calculate similarity (lower hamming distance = more similar) hamming_distance = hash1 - hash2 # Convert to similarity score (0-1, where 1 is identical) max_distance = len(str(hash1)) * 4 # 4 bits per hex char similarity = 1 - (hamming_distance / max_distance) return similarity except Exception as e: print(f"pHash computation failed: {e}") return None def compute_color_histogram_correlation(self, frame1, frame2): """Compute color histogram correlation between two frames""" if frame1 is None or frame2 is None: return None try: # Ensure both frames have the same dimensions if frame1.shape != frame2.shape: h = min(frame1.shape[0], frame2.shape[0]) w = min(frame1.shape[1], frame2.shape[1]) frame1 = cv2.resize(frame1, (w, h)) frame2 = cv2.resize(frame2, (w, h)) # Compute histograms for each channel correlations = [] if len(frame1.shape) == 3: # Color image for i in range(3): # R, G, B channels hist1 = cv2.calcHist([frame1], [i], None, [256], [0, 256]) hist2 = cv2.calcHist([frame2], [i], None, [256], [0, 256]) # Flatten histograms hist1 = hist1.flatten() hist2 = hist2.flatten() # Compute correlation if np.std(hist1) > 0 and np.std(hist2) > 0: corr, _ = pearsonr(hist1, hist2) correlations.append(corr) # Return average correlation across channels return np.mean(correlations) if correlations else 0.0 else: # Grayscale hist1 = cv2.calcHist([frame1], [0], None, [256], [0, 256]).flatten() hist2 = cv2.calcHist([frame2], [0], None, [256], [0, 256]).flatten() if np.std(hist1) > 0 and np.std(hist2) > 0: corr, _ = pearsonr(hist1, hist2) return corr else: return 0.0 except Exception as e: print(f"Color histogram correlation computation failed: {e}") return None def compute_sharpness(self, frame): """Compute sharpness using Laplacian variance method""" if frame is None: return None # Convert to grayscale if needed gray = ( cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if len(frame.shape) == 3 else frame ) # Compute Laplacian variance (higher values = sharper) laplacian = cv2.Laplacian(gray, cv2.CV_64F) sharpness = laplacian.var() return sharpness def compute_frame_metrics(self, frame1, frame2, frame_idx): """Compute all metrics for a frame pair""" metrics = { "frame_index": frame_idx, "ssim": self.compute_ssim(frame1, frame2), "psnr": self.compute_psnr(frame1, frame2), "mse": self.compute_mse(frame1, frame2), "phash": self.compute_phash(frame1, frame2), "color_hist_corr": self.compute_color_histogram_correlation(frame1, frame2), "sharpness1": self.compute_sharpness(frame1), "sharpness2": self.compute_sharpness(frame2), } # Compute average sharpness for the pair if metrics["sharpness1"] is not None and metrics["sharpness2"] is not None: metrics["sharpness_avg"] = ( metrics["sharpness1"] + metrics["sharpness2"] ) / 2 metrics["sharpness_diff"] = abs( metrics["sharpness1"] - metrics["sharpness2"] ) else: metrics["sharpness_avg"] = None metrics["sharpness_diff"] = None return metrics def compute_all_metrics(self, frames1, frames2): """Compute metrics for all frame pairs""" all_metrics = [] max_frames = max(len(frames1), len(frames2)) for i in range(max_frames): frame1 = frames1[i] if i < len(frames1) else None frame2 = frames2[i] if i < len(frames2) else None if frame1 is not None or frame2 is not None: metrics = self.compute_frame_metrics(frame1, frame2, i) all_metrics.append(metrics) else: # Handle cases where both frames are missing all_metrics.append( { "frame_index": i, "ssim": None, "ms_ssim": None, "psnr": None, "mse": None, "phash": None, "color_hist_corr": None, "sharpness1": None, "sharpness2": None, "sharpness_avg": None, "sharpness_diff": None, } ) return all_metrics def get_metric_summary(self, metrics_list): """Compute summary statistics for all metrics""" metric_names = [ "ssim", "psnr", "mse", "phash", "color_hist_corr", "sharpness1", "sharpness2", "sharpness_avg", "sharpness_diff", ] summary = { "total_frames": len(metrics_list), "valid_frames": len([m for m in metrics_list if m.get("ssim") is not None]), } # Compute statistics for each metric for metric_name in metric_names: valid_values = [ m[metric_name] for m in metrics_list if m.get(metric_name) is not None ] if valid_values: summary.update( { f"{metric_name}_mean": np.mean(valid_values), f"{metric_name}_min": np.min(valid_values), f"{metric_name}_max": np.max(valid_values), f"{metric_name}_std": np.std(valid_values), } ) return summary def create_individual_metric_plots(self, metrics_list, current_frame=0): """Create individual plots for each metric with frame on x-axis""" if not metrics_list: return None # Extract frame indices frame_indices = [m["frame_index"] for m in metrics_list] # Helper function to get valid data def get_valid_data(metric_name): values = [m.get(metric_name) for m in metrics_list] valid_indices = [i for i, v in enumerate(values) if v is not None] valid_values = [values[i] for i in valid_indices] valid_frames = [frame_indices[i] for i in valid_indices] return valid_frames, valid_values # Create individual plots for each metric plots = {} # 1. SSIM Plot ssim_frames, ssim_values = get_valid_data("ssim") if ssim_values: # Calculate dynamic y-axis range for SSIM to highlight differences min_ssim = min(ssim_values) max_ssim = max(ssim_values) ssim_range = max_ssim - min_ssim # If there's very little variation, zoom in to show differences if ssim_range < 0.05: # For small variations, zoom in to show differences better center = (min_ssim + max_ssim) / 2 padding = max( 0.02, ssim_range * 2 ) # At least 0.02 range or 2x actual range y_min = max(0, center - padding) y_max = min(1, center + padding) else: # For larger variations, add some padding padding = ssim_range * 0.15 # 15% padding y_min = max(0, min_ssim - padding) y_max = min(1, max_ssim + padding) fig_ssim = go.Figure() # Add area fill to emphasize the curve fig_ssim.add_trace( go.Scatter( x=ssim_frames, y=[y_min] * len(ssim_frames), mode="lines", line=dict( color="rgba(0,0,255,0)" ), # Transparent line for area base showlegend=False, hoverinfo="skip", ) ) fig_ssim.add_trace( go.Scatter( x=ssim_frames, y=ssim_values, mode="lines+markers", name="SSIM", line=dict(color="blue", width=3), marker=dict( size=6, color="blue", line=dict(color="darkblue", width=1) ), hovertemplate="Frame %{x}
SSIM: %{y:.5f}", fill="tonexty", fillcolor="rgba(0,0,255,0.1)", # Light blue fill ) ) if current_frame is not None: fig_ssim.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_ssim.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_ssim.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_ssim.update_yaxes( title_text="SSIM", range=[y_min, y_max], gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["ssim"] = fig_ssim # 2. PSNR Plot psnr_frames, psnr_values = get_valid_data("psnr") if psnr_values: fig_psnr = go.Figure() fig_psnr.add_trace( go.Scatter( x=psnr_frames, y=psnr_values, mode="lines+markers", name="PSNR", line=dict(color="green", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
PSNR: %{y:.2f} dB", ) ) if current_frame is not None: fig_psnr.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_psnr.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_psnr.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_psnr.update_yaxes( title_text="PSNR (dB)", gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["psnr"] = fig_psnr # 3. MSE Plot mse_frames, mse_values = get_valid_data("mse") if mse_values: fig_mse = go.Figure() fig_mse.add_trace( go.Scatter( x=mse_frames, y=mse_values, mode="lines+markers", name="MSE", line=dict(color="red", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
MSE: %{y:.2f}", ) ) if current_frame is not None: fig_mse.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_mse.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_mse.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_mse.update_yaxes( title_text="MSE", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) plots["mse"] = fig_mse # 4. pHash Plot phash_frames, phash_values = get_valid_data("phash") if phash_values: fig_phash = go.Figure() fig_phash.add_trace( go.Scatter( x=phash_frames, y=phash_values, mode="lines+markers", name="pHash", line=dict(color="purple", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
pHash: %{y:.4f}", ) ) if current_frame is not None: fig_phash.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_phash.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_phash.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_phash.update_yaxes( title_text="pHash Similarity", gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["phash"] = fig_phash # 5. Color Histogram Correlation Plot hist_frames, hist_values = get_valid_data("color_hist_corr") if hist_values: fig_hist = go.Figure() fig_hist.add_trace( go.Scatter( x=hist_frames, y=hist_values, mode="lines+markers", name="Color Histogram", line=dict(color="orange", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
Color Histogram: %{y:.4f}", ) ) if current_frame is not None: fig_hist.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_hist.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_hist.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_hist.update_yaxes( title_text="Color Histogram Correlation", gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["color_hist"] = fig_hist # 6. Sharpness Comparison Plot sharp1_frames, sharp1_values = get_valid_data("sharpness1") sharp2_frames, sharp2_values = get_valid_data("sharpness2") if sharp1_values or sharp2_values: fig_sharp = go.Figure() if sharp1_values: fig_sharp.add_trace( go.Scatter( x=sharp1_frames, y=sharp1_values, mode="lines+markers", name="Video 1", line=dict(color="darkgreen", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
Video 1 Sharpness: %{y:.1f}", ) ) if sharp2_values: fig_sharp.add_trace( go.Scatter( x=sharp2_frames, y=sharp2_values, mode="lines+markers", name="Video 2", line=dict(color="darkblue", width=3), marker=dict(size=6), hovertemplate="Frame %{x}
Video 2 Sharpness: %{y:.1f}", ) ) if current_frame is not None: fig_sharp.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_sharp.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=True, legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5 ), dragmode=False, hovermode="x unified", ) fig_sharp.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True ) fig_sharp.update_yaxes( title_text="Sharpness", gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["sharpness"] = fig_sharp # 7. Overall Quality Score Plot (Combination of metrics) # Calculate overall quality score by combining normalized metrics if ssim_values and psnr_values and len(ssim_values) == len(psnr_values): # Get data for metrics that contribute to overall score phash_frames_overall, phash_values_overall = get_valid_data("phash") # Ensure we have the same frames for all metrics common_frames = set(ssim_frames) & set(psnr_frames) if phash_values_overall: common_frames = common_frames & set(phash_frames_overall) common_frames = sorted(list(common_frames)) if common_frames: # Extract values for common frames ssim_common = [ ssim_values[ssim_frames.index(f)] for f in common_frames if f in ssim_frames ] psnr_common = [ psnr_values[psnr_frames.index(f)] for f in common_frames if f in psnr_frames ] # Normalize PSNR to 0-1 scale using min-max normalization if psnr_common: psnr_min = min(psnr_common) psnr_max = max(psnr_common) if psnr_max > psnr_min: psnr_normalized = [ (p - psnr_min) / (psnr_max - psnr_min) for p in psnr_common ] else: psnr_normalized = [0.0 for _ in psnr_common] else: psnr_normalized = [] # Start with SSIM and normalized PSNR quality_components = [ssim_common, psnr_normalized] component_names = ["SSIM", "PSNR"] # Add pHash if available if phash_values_overall: phash_common = [ phash_values_overall[phash_frames_overall.index(f)] for f in common_frames if f in phash_frames_overall ] if len(phash_common) == len(ssim_common): quality_components.append(phash_common) component_names.append("pHash") # Calculate average across all components overall_quality = [] for i in range(len(common_frames)): frame_scores = [ component[i] for component in quality_components if i < len(component) ] overall_quality.append(sum(frame_scores) / len(frame_scores)) # Calculate dynamic y-axis range to emphasize differences min_quality = min(overall_quality) max_quality = max(overall_quality) quality_range = max_quality - min_quality # If there's very little variation, use a smaller range to emphasize small differences if quality_range < 0.08: # For small variations, zoom in to show differences better center = (min_quality + max_quality) / 2 padding = max( 0.04, quality_range * 2 ) # At least 0.04 range or 2x the actual range y_min = max(0, center - padding) y_max = min(1, center + padding) else: # For larger variations, add some padding padding = quality_range * 0.15 # 15% padding y_min = max(0, min_quality - padding) y_max = min(1, max_quality + padding) fig_overall = go.Figure() # Add area fill to emphasize the quality curve fig_overall.add_trace( go.Scatter( x=common_frames, y=[y_min] * len(common_frames), mode="lines", line=dict( color="rgba(255,215,0,0)" ), # Transparent line for area base showlegend=False, hoverinfo="skip", ) ) fig_overall.add_trace( go.Scatter( x=common_frames, y=overall_quality, mode="lines+markers", name="Overall Quality", line=dict(color="gold", width=4), marker=dict( size=8, color="gold", line=dict(color="orange", width=2) ), hovertemplate="Frame %{x}
Overall Quality: %{y:.5f}
Combined from: " + ", ".join(component_names) + "", fill="tonexty", fillcolor="rgba(255,215,0,0.15)", # Semi-transparent gold fill ) ) # Add quality threshold indicators if there are significant variations if current_frame is not None: fig_overall.add_vline( x=current_frame, line_dash="dash", line_color="red", line_width=2, ) fig_overall.update_layout( height=300, margin=dict(t=20, b=40, l=60, r=20), plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", showlegend=False, dragmode=False, hovermode="x unified", ) fig_overall.update_xaxes( title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) fig_overall.update_yaxes( title_text="Overall Quality Score", range=[y_min, y_max], gridcolor="rgba(128,128,128,0.4)", fixedrange=True, ) plots["overall"] = fig_overall return plots def create_modern_plot(self, metrics_list, current_frame=0): """Create individual metric plots instead of combined dashboard""" return self.create_individual_metric_plots(metrics_list, current_frame) class VideoFrameComparator: def __init__(self): self.video1_frames = [] self.video2_frames = [] self.max_frames = 0 self.frame_metrics = FrameMetrics() self.computed_metrics = [] self.metrics_summary = {} def extract_frames(self, video_path): """Extract all frames from a video file or URL""" if not video_path: return [] # Check if it's a URL or local file is_url = video_path.startswith(("http://", "https://")) if not is_url and not os.path.exists(video_path): print(f"Warning: Local video file not found: {video_path}") return [] frames = [] cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print( f"Error: Could not open video {'URL' if is_url else 'file'}: {video_path}" ) return [] try: frame_count = 0 while True: ret, frame = cap.read() if not ret: break # Convert BGR to RGB for display frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame_rgb) frame_count += 1 # Add progress feedback for URLs (which might be slower) if is_url and frame_count % 30 == 0: print(f"Processed {frame_count} frames from URL...") except Exception as e: print(f"Error processing video: {e}") finally: cap.release() print( f"Successfully extracted {len(frames)} frames from {'URL' if is_url else 'file'}: {video_path}" ) return frames def is_comparison_in_data_json( self, video1_path, video2_path, json_file_path="data.json" ): """Check if this video comparison exists in data.json""" try: with open(json_file_path, "r") as f: data = json.load(f) for comparison in data.get("comparisons", []): videos = comparison.get("videos", []) if len(videos) == 2: # Check both orders (works for both local files and URLs) if (videos[0] == video1_path and videos[1] == video2_path) or ( videos[0] == video2_path and videos[1] == video1_path ): return True return False except Exception: return False def load_videos(self, video1_path, video2_path): """Load both videos and extract frames""" if not video1_path and not video2_path: return "Please upload at least one video.", 0, None, None, "", None # Extract frames from both videos self.video1_frames = self.extract_frames(video1_path) if video1_path else [] self.video2_frames = self.extract_frames(video2_path) if video2_path else [] # Determine maximum number of frames self.max_frames = max(len(self.video1_frames), len(self.video2_frames)) if self.max_frames == 0: return ( "No valid frames found in the uploaded videos.", 0, None, None, "", None, ) # Compute metrics if both videos are present and not in data.json metrics_info = "" plots = None if ( video1_path and video2_path and not self.is_comparison_in_data_json(video1_path, video2_path) ): print("Computing comprehensive frame-by-frame metrics...") self.computed_metrics = self.frame_metrics.compute_all_metrics( self.video1_frames, self.video2_frames ) self.metrics_summary = self.frame_metrics.get_metric_summary( self.computed_metrics ) # Build metrics info string metrics_info = "\n\nšŸ“Š Computed Metrics Summary:\n" metric_display = { "ssim": ("SSIM", ".4f", "", "↑ Higher=Better"), "psnr": ("PSNR", ".2f", " dB", "↑ Higher=Better"), "mse": ("MSE", ".2f", "", "↓ Lower=Better"), "phash": ("pHash", ".4f", "", "↑ Higher=Better"), "color_hist_corr": ("Color Hist", ".4f", "", "↑ Higher=Better"), "sharpness_avg": ("Sharpness", ".1f", "", "↑ Higher=Better"), } for metric_key, ( display_name, format_str, unit, direction, ) in metric_display.items(): if self.metrics_summary.get(f"{metric_key}_mean") is not None: mean_val = self.metrics_summary[f"{metric_key}_mean"] std_val = self.metrics_summary[f"{metric_key}_std"] metrics_info += f"{display_name}: μ={mean_val:{format_str}}{unit}, σ={std_val:{format_str}}{unit} ({direction})\n" metrics_info += f"Valid Frames: {self.metrics_summary['valid_frames']}/{self.metrics_summary['total_frames']}" # Generate initial plot plots = self.frame_metrics.create_individual_metric_plots( self.computed_metrics, 0 ) else: self.computed_metrics = [] self.metrics_summary = {} if video1_path and video2_path: metrics_info = "\n\nšŸ“‹ Note: This comparison is predefined in data.json (metrics not computed)" # Get initial frames frame1 = ( self.video1_frames[0] if self.video1_frames else np.zeros((480, 640, 3), dtype=np.uint8) ) frame2 = ( self.video2_frames[0] if self.video2_frames else np.zeros((480, 640, 3), dtype=np.uint8) ) status_msg = "Videos loaded successfully!\n" status_msg += f"Video 1: {len(self.video1_frames)} frames\n" status_msg += f"Video 2: {len(self.video2_frames)} frames\n" status_msg += ( f"Use the slider to navigate through frames (0-{self.max_frames - 1})" ) status_msg += metrics_info return ( status_msg, self.max_frames - 1, frame1, frame2, self.get_current_frame_info(0), plots, ) def get_frames_at_index(self, frame_index): """Get frames at specific index from both videos""" frame_index = int(frame_index) # Get frame from video 1 if frame_index < len(self.video1_frames): frame1 = self.video1_frames[frame_index] else: # Create a placeholder if frame doesn't exist frame1 = np.zeros((480, 640, 3), dtype=np.uint8) cv2.putText( frame1, f"Frame {frame_index} not available", (50, 240), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, ) # Get frame from video 2 if frame_index < len(self.video2_frames): frame2 = self.video2_frames[frame_index] else: # Create a placeholder if frame doesn't exist frame2 = np.zeros((480, 640, 3), dtype=np.uint8) cv2.putText( frame2, f"Frame {frame_index} not available", (50, 240), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, ) return frame1, frame2 def get_current_frame_info(self, frame_index): """Get information about the current frame including metrics""" frame_index = int(frame_index) info = f"Current Frame: {frame_index} / {self.max_frames - 1}" # Add metrics info if available if self.computed_metrics and frame_index < len(self.computed_metrics): metrics = self.computed_metrics[frame_index] # === COMPARISON METRICS (Between Videos) === comparison_metrics = [] # SSIM with quality assessment if metrics.get("ssim") is not None: ssim_val = metrics["ssim"] if ssim_val >= 0.9: quality = "🟢 Excellent" elif ssim_val >= 0.8: quality = "šŸ”µ Good" elif ssim_val >= 0.6: quality = "🟔 Fair" else: quality = "šŸ”“ Poor" comparison_metrics.append( f"SSIM: {ssim_val:.4f} ({quality} similarity)" ) # PSNR with quality indicator if metrics.get("psnr") is not None: psnr_val = metrics["psnr"] if psnr_val >= 40: psnr_quality = "🟢 Excellent" elif psnr_val >= 30: psnr_quality = "šŸ”µ Good" elif psnr_val >= 20: psnr_quality = "🟔 Fair" else: psnr_quality = "šŸ”“ Poor" comparison_metrics.append( f"PSNR: {psnr_val:.1f}dB ({psnr_quality} signal quality)" ) # MSE with quality indicator (lower is better) if metrics.get("mse") is not None: mse_val = metrics["mse"] if mse_val <= 50: mse_quality = "🟢 Very Similar" elif mse_val <= 100: mse_quality = "šŸ”µ Similar" elif mse_val <= 200: mse_quality = "🟔 Moderately Different" else: mse_quality = "šŸ”“ Very Different" comparison_metrics.append(f"MSE: {mse_val:.1f} ({mse_quality})") # pHash with quality indicator if metrics.get("phash") is not None: phash_val = metrics["phash"] if phash_val >= 0.95: phash_quality = "🟢 Nearly Identical" elif phash_val >= 0.9: phash_quality = "šŸ”µ Very Similar" elif phash_val >= 0.8: phash_quality = "🟔 Somewhat Similar" else: phash_quality = "šŸ”“ Different" comparison_metrics.append( f"pHash: {phash_val:.3f} ({phash_quality} perceptually)" ) # Color Histogram Correlation if metrics.get("color_hist_corr") is not None: color_val = metrics["color_hist_corr"] if color_val >= 0.9: color_quality = "🟢 Very Similar Colors" elif color_val >= 0.8: color_quality = "šŸ”µ Similar Colors" elif color_val >= 0.6: color_quality = "🟔 Moderate Color Diff" else: color_quality = "šŸ”“ Different Colors" comparison_metrics.append(f"Color: {color_val:.3f} ({color_quality})") # Add comparison metrics to info if comparison_metrics: info += "\nšŸ“Š Comparison Analysis: " + " | ".join(comparison_metrics) # === INDIVIDUAL VIDEO QUALITY === individual_metrics = [] # Individual Sharpness for each video if metrics.get("sharpness1") is not None: sharp1 = metrics["sharpness1"] if sharp1 >= 200: sharp1_quality = "🟢 Sharp" elif sharp1 >= 100: sharp1_quality = "šŸ”µ Moderate" elif sharp1 >= 50: sharp1_quality = "🟔 Soft" else: sharp1_quality = "šŸ”“ Blurry" individual_metrics.append( f"V1 Sharpness: {sharp1:.0f} ({sharp1_quality})" ) if metrics.get("sharpness2") is not None: sharp2 = metrics["sharpness2"] if sharp2 >= 200: sharp2_quality = "🟢 Sharp" elif sharp2 >= 100: sharp2_quality = "šŸ”µ Moderate" elif sharp2 >= 50: sharp2_quality = "🟔 Soft" else: sharp2_quality = "šŸ”“ Blurry" individual_metrics.append( f"V2 Sharpness: {sharp2:.0f} ({sharp2_quality})" ) # Sharpness comparison if ( metrics.get("sharpness1") is not None and metrics.get("sharpness2") is not None ): sharp1 = metrics["sharpness1"] sharp2 = metrics["sharpness2"] # Calculate difference percentage diff_pct = abs(sharp1 - sharp2) / max(sharp1, sharp2) * 100 # Determine significance with clearer labels if diff_pct > 20: significance = "šŸ”“ MAJOR difference" elif diff_pct > 10: significance = "🟔 MODERATE difference" elif diff_pct > 5: significance = "šŸ”µ MINOR difference" else: significance = "🟢 NEGLIGIBLE difference" # Determine which is sharper if sharp1 > sharp2: comparison = "V1 is sharper" elif sharp2 > sharp1: comparison = "V2 is sharper" else: comparison = "Equal sharpness" individual_metrics.append(f"Sharpness: {comparison} ({significance})") # Add individual metrics to info if individual_metrics: info += "\nšŸŽÆ Individual Quality: " + " | ".join(individual_metrics) # === OVERALL QUALITY ASSESSMENT === # Calculate combined quality score from multiple metrics quality_score = 0 quality_count = 0 metric_contributions = [] # SSIM contribution if metrics.get("ssim") is not None: quality_score += metrics["ssim"] quality_count += 1 metric_contributions.append(f"SSIM({metrics['ssim']:.3f})") # PSNR contribution (normalized to 0-1 scale) if metrics.get("psnr") is not None: psnr_norm = min(metrics["psnr"] / 50, 1.0) quality_score += psnr_norm quality_count += 1 metric_contributions.append(f"PSNR({psnr_norm:.3f})") # pHash contribution if metrics.get("phash") is not None: quality_score += metrics["phash"] quality_count += 1 metric_contributions.append(f"pHash({metrics['phash']:.3f})") if quality_count > 0: avg_quality = quality_score / quality_count # Add overall assessment with formula explanation if avg_quality >= 0.9: overall = "✨ Excellent Overall" quality_indicator = "🟢" elif avg_quality >= 0.8: overall = "āœ… Good Overall" quality_indicator = "šŸ”µ" elif avg_quality >= 0.6: overall = "āš ļø Fair Overall" quality_indicator = "🟔" else: overall = "āŒ Poor Overall" quality_indicator = "šŸ”“" # Calculate quality variation across all frames to show differences quality_variation = "" if self.computed_metrics and len(self.computed_metrics) > 1: # Calculate overall quality for all frames to show variation all_quality_scores = [] for metric in self.computed_metrics: frame_quality = 0 frame_quality_count = 0 if metric.get("ssim") is not None: frame_quality += metric["ssim"] frame_quality_count += 1 if metric.get("psnr") is not None: frame_quality += min(metric["psnr"] / 50, 1.0) frame_quality_count += 1 if metric.get("phash") is not None: frame_quality += metric["phash"] frame_quality_count += 1 if frame_quality_count > 0: all_quality_scores.append( frame_quality / frame_quality_count ) if len(all_quality_scores) > 1: min_qual = min(all_quality_scores) max_qual = max(all_quality_scores) variation = max_qual - min_qual if variation > 0.08: quality_variation = ( f" | šŸ“Š High Variation (Ī”{variation:.4f})" ) elif variation > 0.04: quality_variation = ( f" | šŸ“Š Moderate Variation (Ī”{variation:.4f})" ) elif variation > 0.02: quality_variation = ( f" | šŸ“Š Low Variation (Ī”{variation:.4f})" ) else: quality_variation = ( f" | šŸ“Š Stable Quality (Ī”{variation:.4f})" ) info += f"\nšŸŽÆ Overall Quality: {quality_indicator} {avg_quality:.5f} ({overall}){quality_variation}" info += f"\n šŸ’” Formula: Average of {' + '.join(metric_contributions)} = {avg_quality:.5f}" return info def get_updated_plot(self, frame_index): """Get updated plot with current frame highlighted""" if self.computed_metrics: return self.frame_metrics.create_individual_metric_plots( self.computed_metrics, int(frame_index) ) return None def load_examples_from_json(json_file_path="data.json"): """Load example video pairs from JSON configuration file""" try: with open(json_file_path, "r") as f: data = json.load(f) examples = [] # Extract video pairs from the comparisons for comparison in data.get("comparisons", []): videos = comparison.get("videos", []) # Validate that video files/URLs exist or are accessible valid_videos = [] for video_path in videos: if video_path: # Check if not empty/None # Check if it's a URL if video_path.startswith(("http://", "https://")): # For URLs, we'll assume they're valid (can't easily check without downloading) # OpenCV will handle the validation during actual loading valid_videos.append(video_path) print(f"Added video URL: {video_path}") else: # Convert to absolute path for local files abs_path = os.path.abspath(video_path) if os.path.exists(abs_path): valid_videos.append(abs_path) print(f"Added local video file: {abs_path}") elif os.path.exists(video_path): # Try relative path as fallback valid_videos.append(video_path) print(f"Added local video file: {video_path}") else: print( f"Warning: Local video file not found: {video_path} (abs: {abs_path})" ) # Add to examples if we have valid videos if len(valid_videos) == 2: examples.append(valid_videos) elif len(valid_videos) == 1: # Single video example (compare with None) examples.append([valid_videos[0], None]) return examples except FileNotFoundError: print(f"Warning: {json_file_path} not found. No examples will be loaded.") return [] except json.JSONDecodeError as e: print(f"Error parsing {json_file_path}: {e}") return [] except Exception as e: print(f"Error loading examples: {e}") return [] def get_all_videos_from_json(json_file_path="data.json"): """Get list of all unique videos mentioned in the JSON file""" try: with open(json_file_path, "r") as f: data = json.load(f) all_videos = set() # Extract all unique video paths/URLs from comparisons for comparison in data.get("comparisons", []): videos = comparison.get("videos", []) for video_path in videos: if video_path: # Only add non-empty paths # Check if it's a URL or local file if video_path.startswith(("http://", "https://")): # For URLs, add them directly all_videos.add(video_path) elif os.path.exists(video_path): # For local files, check existence before adding all_videos.add(video_path) return sorted(list(all_videos)) except FileNotFoundError: print(f"Warning: {json_file_path} not found.") return [] except json.JSONDecodeError as e: print(f"Error parsing {json_file_path}: {e}") return [] except Exception as e: print(f"Error loading videos: {e}") return [] def create_app(): comparator = VideoFrameComparator() example_pairs = load_examples_from_json() print(f"DEBUG: Loaded {len(example_pairs)} example pairs") for i, pair in enumerate(example_pairs): print(f" Example {i + 1}: {pair}") with gr.Blocks( title="Frame Arena - Video Frame Comparator", # theme=gr.themes.Soft(), fill_width=True, css=""" /* Ensure plots adapt to theme */ .plotly .main-svg { color: var(--body-text-color, #000) !important; } /* Grid visibility for both themes */ .plotly .gridlayer .xgrid, .plotly .gridlayer .ygrid { stroke-opacity: 0.4 !important; } /* Axis text color adaptation */ .plotly .xtick text, .plotly .ytick text { fill: var(--body-text-color, #000) !important; } /* Axis title color adaptation - multiple selectors for better coverage */ .plotly .g-xtitle, .plotly .g-ytitle, .plotly .xtitle, .plotly .ytitle, .plotly text[class*="xtitle"], .plotly text[class*="ytitle"], .plotly .infolayer .g-xtitle, .plotly .infolayer .g-ytitle { fill: var(--body-text-color, #000) !important; } /* Additional axis title selectors */ .plotly .subplot .xtitle, .plotly .subplot .ytitle, .plotly .cartesianlayer .xtitle, .plotly .cartesianlayer .ytitle { fill: var(--body-text-color, #000) !important; } /* SVG text elements in plots */ .plotly svg text { fill: var(--body-text-color, #000) !important; } /* Legend text color */ .plotly .legendtext, .plotly .legend text { fill: var(--body-text-color, #000) !important; } /* Hover label adaptation */ .plotly .hoverlayer .hovertext, .plotly .hovertext { fill: var(--body-text-color, #000) !important; color: var(--body-text-color, #000) !important; } /* Annotation text */ .plotly .annotation-text, .plotly .annotation { fill: var(--body-text-color, #000) !important; } /* Disable plot interactions except hover */ .plotly .modebar { display: none !important; } .plotly .plot-container .plotly { pointer-events: none !important; } .plotly .plot-container .plotly .hoverlayer { pointer-events: auto !important; } .plotly .plot-container .plotly .hovertext { pointer-events: auto !important; } """, ) as app: gr.Markdown(""" # šŸŽ¬ Frame Arena: Frame by frame comparisons of any videos > šŸŽ‰ This tool has been created to celebrate our Wan 2.2 [text-to-video](https://replicate.com/wan-video/wan-2.2-t2v-480p-fast) and [image-to-video](https://replicate.com/wan-video/wan-2.2-i2v-a14b) endpoints on Replicate. Want to know more? Check out [our blog](https://www.wan22.com/blog/video-optimization-on-replicate)! - Upload videos in common formats with the same number of frames (MP4, AVI, MOV, etc.) or use URLs - **7 Quality Metrics**: SSIM, PSNR, MSE, pHash, Color Histogram, Sharpness + Overall Quality - **Individual Visualization**: Each metric gets its own dedicated plot - **Real-time Analysis**: Navigate frames with live metric updates - **Smart Comparisons**: Understand differences between videos per metric **Perfect for**: Analyzing compression effects, processing artifacts, visual quality assessment, and compression algorithm comparisons. """) with gr.Row(): with gr.Column(): gr.Markdown("### Video 1") video1_input = gr.File( label="Upload Video 1", file_types=[ ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ], type="filepath", ) with gr.Column(): gr.Markdown("### Video 2") video2_input = gr.File( label="Upload Video 2", file_types=[ ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ], type="filepath", ) # Add examples at the top for better UX if example_pairs: gr.Markdown("### šŸ“ Example Video Comparisons") gr.Examples( examples=example_pairs, inputs=[video1_input, video2_input], label="Click any example to load video pairs:", examples_per_page=10, run_on_click=False, # We'll handle this manually ) load_btn = gr.Button("šŸ”„ Load Videos", variant="primary", size="lg") # Frame comparison section (initially hidden) frame_display = gr.Row(visible=True) with frame_display: with gr.Column(): gr.Markdown("### Video 1 - Current Frame") frame1_output = gr.Image( label="Video 1 Frame", type="numpy", interactive=False, # height=400, ) with gr.Column(): gr.Markdown("### Frame Slider (Left: Video 1, Right: Video 2)") image_slider = ImageSlider( label="Drag to compare frames", type="numpy", interactive=True, # height=400, ) with gr.Column(): gr.Markdown("### Video 2 - Current Frame") frame2_output = gr.Image( label="Video 2 Frame", type="numpy", interactive=False, # height=400, ) # Frame navigation (initially hidden) - moved underneath frames frame_controls = gr.Row(visible=True) with frame_controls: frame_slider = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Frame Number", interactive=True, ) # Comprehensive metrics visualization (initially hidden) metrics_section = gr.Row(visible=True) with metrics_section: with gr.Column(): gr.Markdown("### šŸ“Š Metric Analysis") # Overall quality plot with gr.Row(): overall_plot = gr.Plot( label="Overall Quality (Combined Metric [SSIM + normalized_PSNR + pHash])", show_label=True, ) # Frame info moved below overall quality plot frame_info = gr.Textbox( label="Frame Information & Metrics", interactive=False, value="", lines=3, ) # Add comprehensive usage guide underneath frame information & metrics with gr.Accordion("šŸ“– Usage Guide & Metrics Reference", open=False): with gr.Row(): with gr.Column(): gr.Markdown(""" ### šŸ“Š Metrics Explained - **SSIM**: Structural Similarity (1.0 = identical structure, 0.0 = completely different) - **PSNR**: Peak Signal-to-Noise Ratio in dB (higher = better quality, less noise) - **MSE**: Mean Squared Error (lower = more similar pixel values) - **pHash**: Perceptual Hash similarity (1.0 = visually identical) - **Color Histogram**: Color distribution correlation (1.0 = identical color patterns) - **Sharpness**: Laplacian variance per video (higher = sharper/more detailed images) - **Overall Quality**: Combined metric averaging SSIM, min-max normalized PSNR, and pHash """) with gr.Column() as info_section: status_output = gr.Textbox( label="Status", interactive=False, lines=16 ) with gr.Row(): with gr.Column(): gr.Markdown(""" ### šŸŽÆ Quality Assessment Scale (Research-Based Thresholds) **SSIM Scale** (based on human perception studies): - 🟢 **Excellent (≄0.9)**: Visually indistinguishable differences - šŸ”µ **Good (≄0.8)**: Minor visible differences, still high quality - 🟔 **Fair (≄0.6)**: Noticeable differences, acceptable quality - šŸ”“ **Poor (<0.6)**: Significant visible artifacts and differences **PSNR Scale** (standard video quality benchmarks): - 🟢 **Excellent (≄40dB)**: Professional broadcast quality - šŸ”µ **Good (≄30dB)**: High consumer video quality - 🟔 **Fair (≄20dB)**: Acceptable for web streaming - šŸ”“ **Poor (<20dB)**: Low quality with visible compression artifacts **MSE Scale** (pixel difference thresholds): - 🟢 **Very Similar (≤50)**: Minimal pixel-level differences - šŸ”µ **Similar (≤100)**: Small differences, good quality preservation - 🟔 **Moderately Different (≤200)**: Noticeable but acceptable differences - šŸ”“ **Very Different (>200)**: Significant pixel-level changes """) with gr.Column(): gr.Markdown(""" ### šŸ” Understanding Comparisons **Comparison Analysis**: Shows how similar/different the videos are - Most metrics indicate **similarity** - not which video "wins" - Higher SSIM/PSNR/pHash/Color = more similar videos - Lower MSE = more similar videos **Individual Quality**: Shows the quality of each video separately - Sharpness comparison shows which video has more detail - Significance levels: šŸ”“ MAJOR (>20%), 🟔 MODERATE (10-20%), šŸ”µ MINOR (5-10%), 🟢 NEGLIGIBLE (<5%) **Overall Quality**: Combines multiple metrics to provide a single similarity score - **Formula**: Average of [SSIM + normalized_PSNR + pHash] - **PSNR Normalization**: PSNR values are divided by 50dB and capped at 1.0 - **Range**: 0.0 to 1.0 (higher = more similar videos overall) - **Purpose**: Provides a single metric when you need one overall assessment - **Limitation**: Different metrics may disagree; check individual metrics for details """) # Individual metric plots with gr.Row(): ssim_plot = gr.Plot(label="SSIM", show_label=True) psnr_plot = gr.Plot(label="PSNR", show_label=True) with gr.Row(): mse_plot = gr.Plot(label="MSE", show_label=True) phash_plot = gr.Plot(label="pHash", show_label=True) with gr.Row(): color_plot = gr.Plot(label="Color Histogram", show_label=True) sharpness_plot = gr.Plot(label="Sharpness", show_label=True) # Connect examples to auto-loading if example_pairs: # Use a manual approach to handle examples click def examples_manual_handler(video1, video2): print("DEBUG: Examples clicked manually!") return load_videos_handler(video1, video2) # Since we can't directly attach to examples, we'll use the change events # Event handlers def load_videos_handler(video1, video2): print( f"DEBUG: load_videos_handler called with video1={video1}, video2={video2}" ) status, max_frames, frame1, frame2, info, plots = comparator.load_videos( video1, video2 ) # Update slider slider_update = gr.Slider( minimum=0, maximum=max_frames, step=1, value=0, interactive=True if max_frames > 0 else False, ) # Show/hide sections based on whether videos were loaded successfully videos_loaded = max_frames > 0 # Extract individual plots from the plots dictionary ssim_fig = plots.get("ssim") if plots else None psnr_fig = plots.get("psnr") if plots else None mse_fig = plots.get("mse") if plots else None phash_fig = plots.get("phash") if plots else None color_fig = plots.get("color_hist") if plots else None sharpness_fig = plots.get("sharpness") if plots else None overall_fig = plots.get("overall") if plots else None return ( status, # status_output slider_update, # frame_slider frame1, # frame1_output (frame1, frame2), # image_slider frame2, # frame2_output info, # frame_info ssim_fig, # ssim_plot psnr_fig, # psnr_plot mse_fig, # mse_plot phash_fig, # phash_plot color_fig, # color_plot sharpness_fig, # sharpness_plot overall_fig, # overall_plot gr.Row(visible=videos_loaded), # frame_controls gr.Row(visible=videos_loaded), # frame_display gr.Row(visible=videos_loaded and plots is not None), # metrics_section gr.Row(visible=videos_loaded), # info_section ) def update_frames(frame_index): if comparator.max_frames == 0: return ( None, None, None, "No videos loaded", None, None, None, None, None, None, None, ) frame1, frame2 = comparator.get_frames_at_index(frame_index) info = comparator.get_current_frame_info(frame_index) plots = comparator.get_updated_plot(frame_index) # Extract individual plots from the plots dictionary ssim_fig = plots.get("ssim") if plots else None psnr_fig = plots.get("psnr") if plots else None mse_fig = plots.get("mse") if plots else None phash_fig = plots.get("phash") if plots else None color_fig = plots.get("color_hist") if plots else None sharpness_fig = plots.get("sharpness") if plots else None overall_fig = plots.get("overall") if plots else None return ( frame1, (frame1, frame2), frame2, info, ssim_fig, psnr_fig, mse_fig, phash_fig, color_fig, sharpness_fig, overall_fig, ) # Auto-load when examples populate the inputs def auto_load_when_examples_change(video1, video2): print( f"DEBUG: auto_load_when_examples_change called with video1={video1}, video2={video2}" ) # Only auto-load if both inputs are provided (from examples) if video1 and video2: print("DEBUG: Both videos present, calling load_videos_handler") return load_videos_handler(video1, video2) # If only one or no videos, return default empty state print("DEBUG: Not both videos present, returning default state") return ( "Please upload videos or select an example", # status_output gr.Slider( minimum=0, maximum=0, step=1, value=0, interactive=False ), # frame_slider None, # frame1_output (None, None), # image_slider None, # frame2_output "", # frame_info None, # ssim_plot None, # psnr_plot None, # mse_plot None, # phash_plot None, # color_plot None, # sharpness_plot None, # overall_plot gr.Row(visible=True), # frame_controls gr.Row(visible=True), # frame_display gr.Row(visible=True), # metrics_section gr.Row(visible=True), # info_section ) # Enhanced auto-load function with debouncing to prevent multiple rapid calls last_processed_pair = {"video1": None, "video2": None} def enhanced_auto_load(video1, video2): print(f"DEBUG: Input change detected! video1={video1}, video2={video2}") # Simple debouncing: skip if same video pair was just processed if ( last_processed_pair["video1"] == video1 and last_processed_pair["video2"] == video2 ): print("DEBUG: Same video pair already processed, skipping...") # Return current state without recomputing return ( gr.update(), # status_output gr.update(), # frame_slider gr.update(), # frame1_output gr.update(), # image_slider gr.update(), # frame2_output gr.update(), # frame_info gr.update(), # ssim_plot gr.update(), # psnr_plot gr.update(), # mse_plot gr.update(), # phash_plot gr.update(), # color_plot gr.update(), # sharpness_plot gr.update(), # overall_plot gr.update(), # frame_controls gr.update(), # frame_display gr.update(), # metrics_section gr.update(), # info_section ) last_processed_pair["video1"] = video1 last_processed_pair["video2"] = video2 return auto_load_when_examples_change(video1, video2) # Auto-load when both video inputs change (triggered by examples) video1_input.change( fn=enhanced_auto_load, inputs=[video1_input, video2_input], outputs=[ status_output, frame_slider, frame1_output, image_slider, frame2_output, frame_info, ssim_plot, psnr_plot, mse_plot, phash_plot, color_plot, sharpness_plot, overall_plot, frame_controls, frame_display, metrics_section, info_section, ], ) video2_input.change( fn=enhanced_auto_load, inputs=[video1_input, video2_input], outputs=[ status_output, frame_slider, frame1_output, image_slider, frame2_output, frame_info, ssim_plot, psnr_plot, mse_plot, phash_plot, color_plot, sharpness_plot, overall_plot, frame_controls, frame_display, metrics_section, info_section, ], ) # Manual load button event handler with debug def debug_load_videos_handler(video1, video2): print(f"DEBUG: Load button clicked! video1={video1}, video2={video2}") return load_videos_handler(video1, video2) load_btn.click( fn=debug_load_videos_handler, inputs=[video1_input, video2_input], outputs=[ status_output, frame_slider, frame1_output, image_slider, frame2_output, frame_info, ssim_plot, psnr_plot, mse_plot, phash_plot, color_plot, sharpness_plot, overall_plot, frame_controls, frame_display, metrics_section, info_section, ], ) frame_slider.change( fn=update_frames, inputs=[frame_slider], outputs=[ frame1_output, image_slider, frame2_output, frame_info, ssim_plot, psnr_plot, mse_plot, phash_plot, color_plot, sharpness_plot, overall_plot, ], ) return app def main(): app = create_app() app.launch( server_name="0.0.0.0", server_port=7860, share=False, ) if __name__ == "__main__": main()