import json
import os
import cv2
import gradio as gr
import imagehash
import numpy as np
import plotly.graph_objects as go
from gradio_imageslider import ImageSlider
from PIL import Image
from scipy.stats import pearsonr
from skimage.metrics import mean_squared_error as mse_skimage
from skimage.metrics import peak_signal_noise_ratio as psnr_skimage
from skimage.metrics import structural_similarity as ssim
class FrameMetrics:
"""Class to compute and store frame-by-frame metrics"""
def __init__(self):
self.metrics = {}
def compute_ssim(self, frame1, frame2):
"""Compute SSIM between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Convert to grayscale for SSIM computation
gray1 = (
cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY)
if len(frame1.shape) == 3
else frame1
)
gray2 = (
cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY)
if len(frame2.shape) == 3
else frame2
)
# Ensure both frames have the same dimensions
if gray1.shape != gray2.shape:
                # Resize both frames to the smaller of each dimension
h = min(gray1.shape[0], gray2.shape[0])
w = min(gray1.shape[1], gray2.shape[1])
gray1 = cv2.resize(gray1, (w, h))
gray2 = cv2.resize(gray2, (w, h))
# Compute SSIM
ssim_value = ssim(gray1, gray2, data_range=255)
return ssim_value
except Exception as e:
print(f"SSIM computation failed: {e}")
return None
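    # Illustrative usage (a sketch; `frame_a` and `frame_b` are assumed to be
    # same-size uint8 RGB arrays, e.g. two frames from extract_frames below):
    #
    #   fm = FrameMetrics()
    #   score = fm.compute_ssim(frame_a, frame_b)
    #   # score is ~1.0 for identical frames and drops toward 0 as the
    #   # structural content diverges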
def compute_ms_ssim(self, frame1, frame2):
"""Compute Multi-Scale SSIM between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Convert to grayscale for MS-SSIM computation
gray1 = (
cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY)
if len(frame1.shape) == 3
else frame1
)
gray2 = (
cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY)
if len(frame2.shape) == 3
else frame2
)
# Ensure both frames have the same dimensions
if gray1.shape != gray2.shape:
h = min(gray1.shape[0], gray2.shape[0])
w = min(gray1.shape[1], gray2.shape[1])
gray1 = cv2.resize(gray1, (w, h))
gray2 = cv2.resize(gray2, (w, h))
# Ensure minimum size for multi-scale analysis
min_size = 32
if min(gray1.shape) < min_size:
return None
            # Approximate MS-SSIM with a single-scale SSIM whose window size
            # adapts to the image dimensions
            win_size = min(7, min(gray1.shape) // 4)
            if win_size % 2 == 0:
                win_size -= 1  # skimage requires an odd window size
            if win_size < 3:
                win_size = 3
            ms_ssim_val = ssim(gray1, gray2, data_range=255, win_size=win_size)
            return ms_ssim_val
except Exception as e:
print(f"MS-SSIM computation failed: {e}")
return None
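    # Note: the method above approximates MS-SSIM with a single-scale SSIM
    # whose window adapts to the frame size. A faithful multi-scale variant
    # (hypothetical, not used by this class) would average SSIM over a
    # dyadic pyramid, e.g.:
    #
    #   scores = []
    #   g1, g2 = gray1, gray2
    #   for _ in range(3):  # three scales as an illustration
    #       scores.append(ssim(g1, g2, data_range=255))
    #       g1, g2 = cv2.pyrDown(g1), cv2.pyrDown(g2)
    #   ms_ssim_approx = float(np.mean(scores))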
def compute_psnr(self, frame1, frame2):
"""Compute PSNR between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Ensure both frames have the same dimensions
if frame1.shape != frame2.shape:
h = min(frame1.shape[0], frame2.shape[0])
w = min(frame1.shape[1], frame2.shape[1])
c = (
min(frame1.shape[2], frame2.shape[2])
if len(frame1.shape) == 3
else 1
)
if len(frame1.shape) == 3:
frame1 = cv2.resize(frame1, (w, h))[:, :, :c]
frame2 = cv2.resize(frame2, (w, h))[:, :, :c]
else:
frame1 = cv2.resize(frame1, (w, h))
frame2 = cv2.resize(frame2, (w, h))
# Compute PSNR
return psnr_skimage(frame1, frame2, data_range=255)
except Exception as e:
print(f"PSNR computation failed: {e}")
return None
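    # PSNR relates to MSE as PSNR = 10 * log10(MAX^2 / MSE), with MAX = 255
    # for 8-bit frames. Quick sanity check (illustrative numbers):
    #
    #   mse = 100.0
    #   psnr = 10 * np.log10(255**2 / mse)  # ~28.13 dB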
def compute_mse(self, frame1, frame2):
"""Compute MSE between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Ensure both frames have the same dimensions
if frame1.shape != frame2.shape:
h = min(frame1.shape[0], frame2.shape[0])
w = min(frame1.shape[1], frame2.shape[1])
c = (
min(frame1.shape[2], frame2.shape[2])
if len(frame1.shape) == 3
else 1
)
if len(frame1.shape) == 3:
frame1 = cv2.resize(frame1, (w, h))[:, :, :c]
frame2 = cv2.resize(frame2, (w, h))[:, :, :c]
else:
frame1 = cv2.resize(frame1, (w, h))
frame2 = cv2.resize(frame2, (w, h))
# Compute MSE
return mse_skimage(frame1, frame2)
except Exception as e:
print(f"MSE computation failed: {e}")
return None
def compute_phash(self, frame1, frame2):
"""Compute perceptual hash similarity between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Convert to PIL Images for imagehash
pil1 = Image.fromarray(frame1)
pil2 = Image.fromarray(frame2)
# Compute perceptual hashes
hash1 = imagehash.phash(pil1)
hash2 = imagehash.phash(pil2)
# Calculate similarity (lower hamming distance = more similar)
hamming_distance = hash1 - hash2
# Convert to similarity score (0-1, where 1 is identical)
max_distance = len(str(hash1)) * 4 # 4 bits per hex char
similarity = 1 - (hamming_distance / max_distance)
return similarity
except Exception as e:
print(f"pHash computation failed: {e}")
return None
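    # The similarity mapping above is linear in the Hamming distance: for the
    # default 64-bit phash (16 hex chars x 4 bits), 8 flipped bits give
    # 1 - 8/64 = 0.875, and identical hashes give 1.0.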
def compute_color_histogram_correlation(self, frame1, frame2):
"""Compute color histogram correlation between two frames"""
if frame1 is None or frame2 is None:
return None
try:
# Ensure both frames have the same dimensions
if frame1.shape != frame2.shape:
h = min(frame1.shape[0], frame2.shape[0])
w = min(frame1.shape[1], frame2.shape[1])
frame1 = cv2.resize(frame1, (w, h))
frame2 = cv2.resize(frame2, (w, h))
# Compute histograms for each channel
correlations = []
if len(frame1.shape) == 3: # Color image
for i in range(3): # R, G, B channels
hist1 = cv2.calcHist([frame1], [i], None, [256], [0, 256])
hist2 = cv2.calcHist([frame2], [i], None, [256], [0, 256])
# Flatten histograms
hist1 = hist1.flatten()
hist2 = hist2.flatten()
# Compute correlation
if np.std(hist1) > 0 and np.std(hist2) > 0:
corr, _ = pearsonr(hist1, hist2)
correlations.append(corr)
# Return average correlation across channels
return np.mean(correlations) if correlations else 0.0
else: # Grayscale
hist1 = cv2.calcHist([frame1], [0], None, [256], [0, 256]).flatten()
hist2 = cv2.calcHist([frame2], [0], None, [256], [0, 256]).flatten()
if np.std(hist1) > 0 and np.std(hist2) > 0:
corr, _ = pearsonr(hist1, hist2)
return corr
else:
return 0.0
except Exception as e:
print(f"Color histogram correlation computation failed: {e}")
return None
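    # An equivalent per-channel score could be obtained directly from OpenCV
    # (a sketch using the raw, unflattened calcHist outputs;
    # cv2.HISTCMP_CORREL computes the same correlation measure):
    #
    #   corr = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)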
def compute_sharpness(self, frame):
"""Compute sharpness using Laplacian variance method"""
if frame is None:
return None
# Convert to grayscale if needed
gray = (
cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if len(frame.shape) == 3 else frame
)
# Compute Laplacian variance (higher values = sharper)
laplacian = cv2.Laplacian(gray, cv2.CV_64F)
sharpness = laplacian.var()
return sharpness
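    # Laplacian variance is a standard no-reference focus measure: the
    # Laplacian responds strongly to edges, and blurring suppresses that
    # response. Hedged illustration (absolute values depend on content):
    #
    #   blurred = cv2.GaussianBlur(frame, (9, 9), 0)
    #   # compute_sharpness(blurred) should be well below
    #   # compute_sharpness(frame) for the same image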
def compute_frame_metrics(self, frame1, frame2, frame_idx):
"""Compute all metrics for a frame pair"""
metrics = {
"frame_index": frame_idx,
"ssim": self.compute_ssim(frame1, frame2),
"psnr": self.compute_psnr(frame1, frame2),
"mse": self.compute_mse(frame1, frame2),
"phash": self.compute_phash(frame1, frame2),
"color_hist_corr": self.compute_color_histogram_correlation(frame1, frame2),
"sharpness1": self.compute_sharpness(frame1),
"sharpness2": self.compute_sharpness(frame2),
}
# Compute average sharpness for the pair
if metrics["sharpness1"] is not None and metrics["sharpness2"] is not None:
metrics["sharpness_avg"] = (
metrics["sharpness1"] + metrics["sharpness2"]
) / 2
metrics["sharpness_diff"] = abs(
metrics["sharpness1"] - metrics["sharpness2"]
)
else:
metrics["sharpness_avg"] = None
metrics["sharpness_diff"] = None
return metrics
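    # The returned dict has this shape (values illustrative only):
    #
    #   {"frame_index": 0, "ssim": 0.97, "psnr": 38.2, "mse": 9.8,
    #    "phash": 0.94, "color_hist_corr": 0.99, "sharpness1": 142.5,
    #    "sharpness2": 138.1, "sharpness_avg": 140.3, "sharpness_diff": 4.4}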
def compute_all_metrics(self, frames1, frames2):
"""Compute metrics for all frame pairs"""
all_metrics = []
max_frames = max(len(frames1), len(frames2))
for i in range(max_frames):
frame1 = frames1[i] if i < len(frames1) else None
frame2 = frames2[i] if i < len(frames2) else None
if frame1 is not None or frame2 is not None:
metrics = self.compute_frame_metrics(frame1, frame2, i)
all_metrics.append(metrics)
else:
# Handle cases where both frames are missing
all_metrics.append(
{
"frame_index": i,
"ssim": None,
"ms_ssim": None,
"psnr": None,
"mse": None,
"phash": None,
"color_hist_corr": None,
"sharpness1": None,
"sharpness2": None,
"sharpness_avg": None,
"sharpness_diff": None,
}
)
return all_metrics
def get_metric_summary(self, metrics_list):
"""Compute summary statistics for all metrics"""
metric_names = [
"ssim",
"psnr",
"mse",
"phash",
"color_hist_corr",
"sharpness1",
"sharpness2",
"sharpness_avg",
"sharpness_diff",
]
summary = {
"total_frames": len(metrics_list),
"valid_frames": len([m for m in metrics_list if m.get("ssim") is not None]),
}
# Compute statistics for each metric
for metric_name in metric_names:
valid_values = [
m[metric_name] for m in metrics_list if m.get(metric_name) is not None
]
if valid_values:
summary.update(
{
f"{metric_name}_mean": np.mean(valid_values),
f"{metric_name}_min": np.min(valid_values),
f"{metric_name}_max": np.max(valid_values),
f"{metric_name}_std": np.std(valid_values),
}
)
return summary
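    # For every metric with at least one valid value, the summary carries
    # "<name>_mean", "<name>_min", "<name>_max", and "<name>_std" alongside
    # "total_frames" and "valid_frames", e.g.:
    #
    #   summary = fm.get_metric_summary(all_metrics)
    #   summary["ssim_mean"], summary["psnr_std"]  # illustrative lookups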
def create_individual_metric_plots(self, metrics_list, current_frame=0):
"""Create individual plots for each metric with frame on x-axis"""
if not metrics_list:
return None
# Extract frame indices
frame_indices = [m["frame_index"] for m in metrics_list]
# Helper function to get valid data
def get_valid_data(metric_name):
values = [m.get(metric_name) for m in metrics_list]
valid_indices = [i for i, v in enumerate(values) if v is not None]
valid_values = [values[i] for i in valid_indices]
valid_frames = [frame_indices[i] for i in valid_indices]
return valid_frames, valid_values
# Create individual plots for each metric
plots = {}
# 1. SSIM Plot
ssim_frames, ssim_values = get_valid_data("ssim")
if ssim_values:
# Calculate dynamic y-axis range for SSIM to highlight differences
min_ssim = min(ssim_values)
max_ssim = max(ssim_values)
ssim_range = max_ssim - min_ssim
# If there's very little variation, zoom in to show differences
if ssim_range < 0.05:
# For small variations, zoom in to show differences better
center = (min_ssim + max_ssim) / 2
padding = max(
0.02, ssim_range * 2
) # At least 0.02 range or 2x actual range
y_min = max(0, center - padding)
y_max = min(1, center + padding)
else:
# For larger variations, add some padding
padding = ssim_range * 0.15 # 15% padding
y_min = max(0, min_ssim - padding)
y_max = min(1, max_ssim + padding)
fig_ssim = go.Figure()
# Add area fill to emphasize the curve
fig_ssim.add_trace(
go.Scatter(
x=ssim_frames,
y=[y_min] * len(ssim_frames),
mode="lines",
line=dict(
color="rgba(0,0,255,0)"
), # Transparent line for area base
showlegend=False,
hoverinfo="skip",
)
)
fig_ssim.add_trace(
go.Scatter(
x=ssim_frames,
y=ssim_values,
mode="lines+markers",
name="SSIM",
line=dict(color="blue", width=3),
marker=dict(
size=6, color="blue", line=dict(color="darkblue", width=1)
),
hovertemplate="Frame %{x}
SSIM: %{y:.5f}",
fill="tonexty",
fillcolor="rgba(0,0,255,0.1)", # Light blue fill
)
)
if current_frame is not None:
fig_ssim.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_ssim.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_ssim.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_ssim.update_yaxes(
title_text="SSIM",
range=[y_min, y_max],
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["ssim"] = fig_ssim
# 2. PSNR Plot
psnr_frames, psnr_values = get_valid_data("psnr")
if psnr_values:
fig_psnr = go.Figure()
fig_psnr.add_trace(
go.Scatter(
x=psnr_frames,
y=psnr_values,
mode="lines+markers",
name="PSNR",
line=dict(color="green", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
PSNR: %{y:.2f} dB",
)
)
if current_frame is not None:
fig_psnr.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_psnr.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_psnr.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_psnr.update_yaxes(
title_text="PSNR (dB)",
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["psnr"] = fig_psnr
# 3. MSE Plot
mse_frames, mse_values = get_valid_data("mse")
if mse_values:
fig_mse = go.Figure()
fig_mse.add_trace(
go.Scatter(
x=mse_frames,
y=mse_values,
mode="lines+markers",
name="MSE",
line=dict(color="red", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
MSE: %{y:.2f}",
)
)
if current_frame is not None:
fig_mse.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_mse.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_mse.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_mse.update_yaxes(
title_text="MSE", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
plots["mse"] = fig_mse
# 4. pHash Plot
phash_frames, phash_values = get_valid_data("phash")
if phash_values:
fig_phash = go.Figure()
fig_phash.add_trace(
go.Scatter(
x=phash_frames,
y=phash_values,
mode="lines+markers",
name="pHash",
line=dict(color="purple", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
pHash: %{y:.4f}",
)
)
if current_frame is not None:
fig_phash.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_phash.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_phash.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_phash.update_yaxes(
title_text="pHash Similarity",
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["phash"] = fig_phash
# 5. Color Histogram Correlation Plot
hist_frames, hist_values = get_valid_data("color_hist_corr")
if hist_values:
fig_hist = go.Figure()
fig_hist.add_trace(
go.Scatter(
x=hist_frames,
y=hist_values,
mode="lines+markers",
name="Color Histogram",
line=dict(color="orange", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
Color Histogram: %{y:.4f}",
)
)
if current_frame is not None:
fig_hist.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_hist.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_hist.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_hist.update_yaxes(
title_text="Color Histogram Correlation",
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["color_hist"] = fig_hist
# 6. Sharpness Comparison Plot
sharp1_frames, sharp1_values = get_valid_data("sharpness1")
sharp2_frames, sharp2_values = get_valid_data("sharpness2")
if sharp1_values or sharp2_values:
fig_sharp = go.Figure()
if sharp1_values:
fig_sharp.add_trace(
go.Scatter(
x=sharp1_frames,
y=sharp1_values,
mode="lines+markers",
name="Video 1",
line=dict(color="darkgreen", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
Video 1 Sharpness: %{y:.1f}",
)
)
if sharp2_values:
fig_sharp.add_trace(
go.Scatter(
x=sharp2_frames,
y=sharp2_values,
mode="lines+markers",
name="Video 2",
line=dict(color="darkblue", width=3),
marker=dict(size=6),
hovertemplate="Frame %{x}
Video 2 Sharpness: %{y:.1f}",
)
)
if current_frame is not None:
fig_sharp.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_sharp.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=True,
legend=dict(
orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5
),
dragmode=False,
hovermode="x unified",
)
fig_sharp.update_xaxes(
title_text="Frame", gridcolor="rgba(128,128,128,0.4)", fixedrange=True
)
fig_sharp.update_yaxes(
title_text="Sharpness",
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["sharpness"] = fig_sharp
# 7. Overall Quality Score Plot (Combination of metrics)
# Calculate overall quality score by combining normalized metrics
if ssim_values and psnr_values and len(ssim_values) == len(psnr_values):
# Get data for metrics that contribute to overall score
phash_frames_overall, phash_values_overall = get_valid_data("phash")
# Ensure we have the same frames for all metrics
common_frames = set(ssim_frames) & set(psnr_frames)
if phash_values_overall:
common_frames = common_frames & set(phash_frames_overall)
common_frames = sorted(list(common_frames))
if common_frames:
# Extract values for common frames
ssim_common = [
ssim_values[ssim_frames.index(f)]
for f in common_frames
if f in ssim_frames
]
psnr_common = [
psnr_values[psnr_frames.index(f)]
for f in common_frames
if f in psnr_frames
]
# Normalize PSNR to 0-1 scale using min-max normalization
if psnr_common:
psnr_min = min(psnr_common)
psnr_max = max(psnr_common)
if psnr_max > psnr_min:
psnr_normalized = [
(p - psnr_min) / (psnr_max - psnr_min) for p in psnr_common
]
else:
psnr_normalized = [0.0 for _ in psnr_common]
else:
psnr_normalized = []
# Start with SSIM and normalized PSNR
quality_components = [ssim_common, psnr_normalized]
component_names = ["SSIM", "PSNR"]
# Add pHash if available
if phash_values_overall:
phash_common = [
phash_values_overall[phash_frames_overall.index(f)]
for f in common_frames
if f in phash_frames_overall
]
if len(phash_common) == len(ssim_common):
quality_components.append(phash_common)
component_names.append("pHash")
# Calculate average across all components
overall_quality = []
for i in range(len(common_frames)):
frame_scores = [
component[i]
for component in quality_components
if i < len(component)
]
overall_quality.append(sum(frame_scores) / len(frame_scores))
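                # e.g. SSIM 0.95, normalized PSNR 0.80, and pHash 0.90
                # average to an overall score of (0.95+0.80+0.90)/3 ~ 0.8833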
# Calculate dynamic y-axis range to emphasize differences
min_quality = min(overall_quality)
max_quality = max(overall_quality)
quality_range = max_quality - min_quality
# If there's very little variation, use a smaller range to emphasize small differences
if quality_range < 0.08:
# For small variations, zoom in to show differences better
center = (min_quality + max_quality) / 2
padding = max(
0.04, quality_range * 2
) # At least 0.04 range or 2x the actual range
y_min = max(0, center - padding)
y_max = min(1, center + padding)
else:
# For larger variations, add some padding
padding = quality_range * 0.15 # 15% padding
y_min = max(0, min_quality - padding)
y_max = min(1, max_quality + padding)
fig_overall = go.Figure()
# Add area fill to emphasize the quality curve
fig_overall.add_trace(
go.Scatter(
x=common_frames,
y=[y_min] * len(common_frames),
mode="lines",
line=dict(
color="rgba(255,215,0,0)"
), # Transparent line for area base
showlegend=False,
hoverinfo="skip",
)
)
fig_overall.add_trace(
go.Scatter(
x=common_frames,
y=overall_quality,
mode="lines+markers",
name="Overall Quality",
line=dict(color="gold", width=4),
marker=dict(
size=8, color="gold", line=dict(color="orange", width=2)
),
hovertemplate="Frame %{x}
Overall Quality: %{y:.5f}
Combined from: "
+ ", ".join(component_names)
+ "",
fill="tonexty",
fillcolor="rgba(255,215,0,0.15)", # Semi-transparent gold fill
)
)
# Add quality threshold indicators if there are significant variations
if current_frame is not None:
fig_overall.add_vline(
x=current_frame,
line_dash="dash",
line_color="red",
line_width=2,
)
fig_overall.update_layout(
height=300,
margin=dict(t=20, b=40, l=60, r=20),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
dragmode=False,
hovermode="x unified",
)
fig_overall.update_xaxes(
title_text="Frame",
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
fig_overall.update_yaxes(
title_text="Overall Quality Score",
range=[y_min, y_max],
gridcolor="rgba(128,128,128,0.4)",
fixedrange=True,
)
plots["overall"] = fig_overall
return plots
def create_modern_plot(self, metrics_list, current_frame=0):
"""Create individual metric plots instead of combined dashboard"""
return self.create_individual_metric_plots(metrics_list, current_frame)
class VideoFrameComparator:
def __init__(self):
self.video1_frames = []
self.video2_frames = []
self.max_frames = 0
self.frame_metrics = FrameMetrics()
self.computed_metrics = []
self.metrics_summary = {}
def extract_frames(self, video_path):
"""Extract all frames from a video file or URL"""
if not video_path:
return []
# Check if it's a URL or local file
is_url = video_path.startswith(("http://", "https://"))
if not is_url and not os.path.exists(video_path):
print(f"Warning: Local video file not found: {video_path}")
return []
frames = []
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(
f"Error: Could not open video {'URL' if is_url else 'file'}: {video_path}"
)
return []
try:
frame_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Convert BGR to RGB for display
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame_rgb)
frame_count += 1
# Add progress feedback for URLs (which might be slower)
if is_url and frame_count % 30 == 0:
print(f"Processed {frame_count} frames from URL...")
except Exception as e:
print(f"Error processing video: {e}")
finally:
cap.release()
print(
f"Successfully extracted {len(frames)} frames from {'URL' if is_url else 'file'}: {video_path}"
)
return frames
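    # Note: every frame is kept in memory as an RGB uint8 array, so a
    # 10-second 30 fps 1080p clip needs roughly 300 * 1920 * 1080 * 3 bytes
    # (~1.9 GB). For long clips, a stride-based variant (hypothetical) could
    # subsample inside the read loop:
    #
    #   if frame_count % stride == 0:
    #       frames.append(frame_rgb)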
def is_comparison_in_data_json(
self, video1_path, video2_path, json_file_path="data.json"
):
"""Check if this video comparison exists in data.json"""
try:
with open(json_file_path, "r") as f:
data = json.load(f)
for comparison in data.get("comparisons", []):
videos = comparison.get("videos", [])
if len(videos) == 2:
# Check both orders (works for both local files and URLs)
if (videos[0] == video1_path and videos[1] == video2_path) or (
videos[0] == video2_path and videos[1] == video1_path
):
return True
return False
except Exception:
return False
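    # Expected data.json shape (a minimal sketch inferred from the lookups
    # above; only the "comparisons" list and each entry's "videos" pair are
    # read):
    #
    #   {
    #     "comparisons": [
    #       {"videos": ["videos/a.mp4", "https://example.com/b.mp4"]}
    #     ]
    #   }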
def load_videos(self, video1_path, video2_path):
"""Load both videos and extract frames"""
if not video1_path and not video2_path:
return "Please upload at least one video.", 0, None, None, "", None
# Extract frames from both videos
self.video1_frames = self.extract_frames(video1_path) if video1_path else []
self.video2_frames = self.extract_frames(video2_path) if video2_path else []
# Determine maximum number of frames
self.max_frames = max(len(self.video1_frames), len(self.video2_frames))
if self.max_frames == 0:
return (
"No valid frames found in the uploaded videos.",
0,
None,
None,
"",
None,
)
# Compute metrics if both videos are present and not in data.json
metrics_info = ""
plots = None
if (
video1_path
and video2_path
and not self.is_comparison_in_data_json(video1_path, video2_path)
):
print("Computing comprehensive frame-by-frame metrics...")
self.computed_metrics = self.frame_metrics.compute_all_metrics(
self.video1_frames, self.video2_frames
)
self.metrics_summary = self.frame_metrics.get_metric_summary(
self.computed_metrics
)
# Build metrics info string
metrics_info = "\n\nš Computed Metrics Summary:\n"
metric_display = {
"ssim": ("SSIM", ".4f", "", "ā Higher=Better"),
"psnr": ("PSNR", ".2f", " dB", "ā Higher=Better"),
"mse": ("MSE", ".2f", "", "ā Lower=Better"),
"phash": ("pHash", ".4f", "", "ā Higher=Better"),
"color_hist_corr": ("Color Hist", ".4f", "", "ā Higher=Better"),
"sharpness_avg": ("Sharpness", ".1f", "", "ā Higher=Better"),
}
for metric_key, (
display_name,
format_str,
unit,
direction,
) in metric_display.items():
if self.metrics_summary.get(f"{metric_key}_mean") is not None:
mean_val = self.metrics_summary[f"{metric_key}_mean"]
std_val = self.metrics_summary[f"{metric_key}_std"]
metrics_info += f"{display_name}: μ={mean_val:{format_str}}{unit}, Ļ={std_val:{format_str}}{unit} ({direction})\n"
metrics_info += f"Valid Frames: {self.metrics_summary['valid_frames']}/{self.metrics_summary['total_frames']}"
# Generate initial plot
plots = self.frame_metrics.create_individual_metric_plots(
self.computed_metrics, 0
)
else:
self.computed_metrics = []
self.metrics_summary = {}
if video1_path and video2_path:
metrics_info = "\n\nš Note: This comparison is predefined in data.json (metrics not computed)"
# Get initial frames
frame1 = (
self.video1_frames[0]
if self.video1_frames
else np.zeros((480, 640, 3), dtype=np.uint8)
)
frame2 = (
self.video2_frames[0]
if self.video2_frames
else np.zeros((480, 640, 3), dtype=np.uint8)
)
status_msg = "Videos loaded successfully!\n"
status_msg += f"Video 1: {len(self.video1_frames)} frames\n"
status_msg += f"Video 2: {len(self.video2_frames)} frames\n"
status_msg += (
f"Use the slider to navigate through frames (0-{self.max_frames - 1})"
)
status_msg += metrics_info
return (
status_msg,
self.max_frames - 1,
frame1,
frame2,
self.get_current_frame_info(0),
plots,
)
def get_frames_at_index(self, frame_index):
"""Get frames at specific index from both videos"""
frame_index = int(frame_index)
# Get frame from video 1
if frame_index < len(self.video1_frames):
frame1 = self.video1_frames[frame_index]
else:
# Create a placeholder if frame doesn't exist
frame1 = np.zeros((480, 640, 3), dtype=np.uint8)
cv2.putText(
frame1,
f"Frame {frame_index} not available",
(50, 240),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(255, 255, 255),
2,
)
# Get frame from video 2
if frame_index < len(self.video2_frames):
frame2 = self.video2_frames[frame_index]
else:
# Create a placeholder if frame doesn't exist
frame2 = np.zeros((480, 640, 3), dtype=np.uint8)
cv2.putText(
frame2,
f"Frame {frame_index} not available",
(50, 240),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(255, 255, 255),
2,
)
return frame1, frame2
def get_current_frame_info(self, frame_index):
"""Get information about the current frame including metrics"""
frame_index = int(frame_index)
info = f"Current Frame: {frame_index} / {self.max_frames - 1}"
# Add metrics info if available
if self.computed_metrics and frame_index < len(self.computed_metrics):
metrics = self.computed_metrics[frame_index]
# === COMPARISON METRICS (Between Videos) ===
comparison_metrics = []
# SSIM with quality assessment
if metrics.get("ssim") is not None:
ssim_val = metrics["ssim"]
                if ssim_val >= 0.9:
                    quality = "🟢 Excellent"
                elif ssim_val >= 0.8:
                    quality = "🔵 Good"
                elif ssim_val >= 0.6:
                    quality = "🟡 Fair"
                else:
                    quality = "🔴 Poor"
comparison_metrics.append(
f"SSIM: {ssim_val:.4f} ({quality} similarity)"
)
# PSNR with quality indicator
if metrics.get("psnr") is not None:
psnr_val = metrics["psnr"]
                if psnr_val >= 40:
                    psnr_quality = "🟢 Excellent"
                elif psnr_val >= 30:
                    psnr_quality = "🔵 Good"
                elif psnr_val >= 20:
                    psnr_quality = "🟡 Fair"
                else:
                    psnr_quality = "🔴 Poor"
comparison_metrics.append(
f"PSNR: {psnr_val:.1f}dB ({psnr_quality} signal quality)"
)
# MSE with quality indicator (lower is better)
if metrics.get("mse") is not None:
mse_val = metrics["mse"]
                if mse_val <= 50:
                    mse_quality = "🟢 Very Similar"
                elif mse_val <= 100:
                    mse_quality = "🔵 Similar"
                elif mse_val <= 200:
                    mse_quality = "🟡 Moderately Different"
                else:
                    mse_quality = "🔴 Very Different"
comparison_metrics.append(f"MSE: {mse_val:.1f} ({mse_quality})")
# pHash with quality indicator
if metrics.get("phash") is not None:
phash_val = metrics["phash"]
                if phash_val >= 0.95:
                    phash_quality = "🟢 Nearly Identical"
                elif phash_val >= 0.9:
                    phash_quality = "🔵 Very Similar"
                elif phash_val >= 0.8:
                    phash_quality = "🟡 Somewhat Similar"
                else:
                    phash_quality = "🔴 Different"
comparison_metrics.append(
f"pHash: {phash_val:.3f} ({phash_quality} perceptually)"
)
# Color Histogram Correlation
if metrics.get("color_hist_corr") is not None:
color_val = metrics["color_hist_corr"]
                if color_val >= 0.9:
                    color_quality = "🟢 Very Similar Colors"
                elif color_val >= 0.8:
                    color_quality = "🔵 Similar Colors"
                elif color_val >= 0.6:
                    color_quality = "🟡 Moderate Color Diff"
                else:
                    color_quality = "🔴 Different Colors"
comparison_metrics.append(f"Color: {color_val:.3f} ({color_quality})")
# Add comparison metrics to info
if comparison_metrics:
info += "\nš Comparison Analysis: " + " | ".join(comparison_metrics)
# === INDIVIDUAL VIDEO QUALITY ===
individual_metrics = []
# Individual Sharpness for each video
if metrics.get("sharpness1") is not None:
sharp1 = metrics["sharpness1"]
                if sharp1 >= 200:
                    sharp1_quality = "🟢 Sharp"
                elif sharp1 >= 100:
                    sharp1_quality = "🔵 Moderate"
                elif sharp1 >= 50:
                    sharp1_quality = "🟡 Soft"
                else:
                    sharp1_quality = "🔴 Blurry"
individual_metrics.append(
f"V1 Sharpness: {sharp1:.0f} ({sharp1_quality})"
)
if metrics.get("sharpness2") is not None:
sharp2 = metrics["sharpness2"]
                if sharp2 >= 200:
                    sharp2_quality = "🟢 Sharp"
                elif sharp2 >= 100:
                    sharp2_quality = "🔵 Moderate"
                elif sharp2 >= 50:
                    sharp2_quality = "🟡 Soft"
                else:
                    sharp2_quality = "🔴 Blurry"
individual_metrics.append(
f"V2 Sharpness: {sharp2:.0f} ({sharp2_quality})"
)
# Sharpness comparison
if (
metrics.get("sharpness1") is not None
and metrics.get("sharpness2") is not None
):
sharp1 = metrics["sharpness1"]
sharp2 = metrics["sharpness2"]
                # Calculate the difference as a percentage of the larger value
                # (guarding against division by zero for flat frames)
                diff_pct = abs(sharp1 - sharp2) / max(sharp1, sharp2, 1e-6) * 100
                # Determine significance with clearer labels
                if diff_pct > 20:
                    significance = "🔴 MAJOR difference"
                elif diff_pct > 10:
                    significance = "🟡 MODERATE difference"
                elif diff_pct > 5:
                    significance = "🔵 MINOR difference"
                else:
                    significance = "🟢 NEGLIGIBLE difference"
# Determine which is sharper
if sharp1 > sharp2:
comparison = "V1 is sharper"
elif sharp2 > sharp1:
comparison = "V2 is sharper"
else:
comparison = "Equal sharpness"
individual_metrics.append(f"Sharpness: {comparison} ({significance})")
# Add individual metrics to info
if individual_metrics:
info += "\nšÆ Individual Quality: " + " | ".join(individual_metrics)
# === OVERALL QUALITY ASSESSMENT ===
# Calculate combined quality score from multiple metrics
quality_score = 0
quality_count = 0
metric_contributions = []
# SSIM contribution
if metrics.get("ssim") is not None:
quality_score += metrics["ssim"]
quality_count += 1
metric_contributions.append(f"SSIM({metrics['ssim']:.3f})")
# PSNR contribution (normalized to 0-1 scale)
if metrics.get("psnr") is not None:
psnr_norm = min(metrics["psnr"] / 50, 1.0)
quality_score += psnr_norm
quality_count += 1
metric_contributions.append(f"PSNR({psnr_norm:.3f})")
# pHash contribution
if metrics.get("phash") is not None:
quality_score += metrics["phash"]
quality_count += 1
metric_contributions.append(f"pHash({metrics['phash']:.3f})")
if quality_count > 0:
avg_quality = quality_score / quality_count
# Add overall assessment with formula explanation
                if avg_quality >= 0.9:
                    overall = "✨ Excellent Overall"
                    quality_indicator = "🟢"
                elif avg_quality >= 0.8:
                    overall = "✅ Good Overall"
                    quality_indicator = "🔵"
                elif avg_quality >= 0.6:
                    overall = "⚠️ Fair Overall"
                    quality_indicator = "🟡"
                else:
                    overall = "❌ Poor Overall"
                    quality_indicator = "🔴"
# Calculate quality variation across all frames to show differences
quality_variation = ""
if self.computed_metrics and len(self.computed_metrics) > 1:
# Calculate overall quality for all frames to show variation
all_quality_scores = []
for metric in self.computed_metrics:
frame_quality = 0
frame_quality_count = 0
if metric.get("ssim") is not None:
frame_quality += metric["ssim"]
frame_quality_count += 1
if metric.get("psnr") is not None:
frame_quality += min(metric["psnr"] / 50, 1.0)
frame_quality_count += 1
if metric.get("phash") is not None:
frame_quality += metric["phash"]
frame_quality_count += 1
if frame_quality_count > 0:
all_quality_scores.append(
frame_quality / frame_quality_count
)
if len(all_quality_scores) > 1:
min_qual = min(all_quality_scores)
max_qual = max(all_quality_scores)
variation = max_qual - min_qual
                        if variation > 0.08:
                            quality_variation = (
                                f" | 📈 High Variation (Δ{variation:.4f})"
                            )
                        elif variation > 0.04:
                            quality_variation = (
                                f" | 📊 Moderate Variation (Δ{variation:.4f})"
                            )
                        elif variation > 0.02:
                            quality_variation = (
                                f" | 📉 Low Variation (Δ{variation:.4f})"
                            )
                        else:
                            quality_variation = (
                                f" | ➡️ Stable Quality (Δ{variation:.4f})"
                            )
info += f"\nšÆ Overall Quality: {quality_indicator} {avg_quality:.5f} ({overall}){quality_variation}"
info += f"\n š” Formula: Average of {' + '.join(metric_contributions)} = {avg_quality:.5f}"
return info
def get_updated_plot(self, frame_index):
"""Get updated plot with current frame highlighted"""
if self.computed_metrics:
return self.frame_metrics.create_individual_metric_plots(
self.computed_metrics, int(frame_index)
)
return None
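# Typical programmatic use of the comparator (a sketch; the two paths are
# hypothetical local clips):
#
#   comparator = VideoFrameComparator()
#   status, max_idx, f1, f2, info, plots = comparator.load_videos(
#       "videos/a.mp4", "videos/b.mp4"
#   )
#   f1, f2 = comparator.get_frames_at_index(5)
#   print(comparator.get_current_frame_info(5))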
def load_examples_from_json(json_file_path="data.json"):
"""Load example video pairs from JSON configuration file"""
try:
with open(json_file_path, "r") as f:
data = json.load(f)
examples = []
# Extract video pairs from the comparisons
for comparison in data.get("comparisons", []):
videos = comparison.get("videos", [])
# Validate that video files/URLs exist or are accessible
valid_videos = []
for video_path in videos:
if video_path: # Check if not empty/None
# Check if it's a URL
if video_path.startswith(("http://", "https://")):
# For URLs, we'll assume they're valid (can't easily check without downloading)
# OpenCV will handle the validation during actual loading
valid_videos.append(video_path)
print(f"Added video URL: {video_path}")
else:
# Convert to absolute path for local files
abs_path = os.path.abspath(video_path)
if os.path.exists(abs_path):
valid_videos.append(abs_path)
print(f"Added local video file: {abs_path}")
elif os.path.exists(video_path):
# Try relative path as fallback
valid_videos.append(video_path)
print(f"Added local video file: {video_path}")
else:
print(
f"Warning: Local video file not found: {video_path} (abs: {abs_path})"
)
# Add to examples if we have valid videos
if len(valid_videos) == 2:
examples.append(valid_videos)
elif len(valid_videos) == 1:
# Single video example (compare with None)
examples.append([valid_videos[0], None])
return examples
except FileNotFoundError:
print(f"Warning: {json_file_path} not found. No examples will be loaded.")
return []
except json.JSONDecodeError as e:
print(f"Error parsing {json_file_path}: {e}")
return []
except Exception as e:
print(f"Error loading examples: {e}")
return []
def get_all_videos_from_json(json_file_path="data.json"):
"""Get list of all unique videos mentioned in the JSON file"""
try:
with open(json_file_path, "r") as f:
data = json.load(f)
all_videos = set()
# Extract all unique video paths/URLs from comparisons
for comparison in data.get("comparisons", []):
videos = comparison.get("videos", [])
for video_path in videos:
if video_path: # Only add non-empty paths
# Check if it's a URL or local file
if video_path.startswith(("http://", "https://")):
# For URLs, add them directly
all_videos.add(video_path)
elif os.path.exists(video_path):
# For local files, check existence before adding
all_videos.add(video_path)
return sorted(list(all_videos))
except FileNotFoundError:
print(f"Warning: {json_file_path} not found.")
return []
except json.JSONDecodeError as e:
print(f"Error parsing {json_file_path}: {e}")
return []
except Exception as e:
print(f"Error loading videos: {e}")
return []
def create_app():
comparator = VideoFrameComparator()
example_pairs = load_examples_from_json()
print(f"DEBUG: Loaded {len(example_pairs)} example pairs")
for i, pair in enumerate(example_pairs):
print(f" Example {i + 1}: {pair}")
with gr.Blocks(
title="Frame Arena - Video Frame Comparator",
# theme=gr.themes.Soft(),
fill_width=True,
css="""
/* Ensure plots adapt to theme */
.plotly .main-svg {
color: var(--body-text-color, #000) !important;
}
/* Grid visibility for both themes */
.plotly .gridlayer .xgrid, .plotly .gridlayer .ygrid {
stroke-opacity: 0.4 !important;
}
/* Axis text color adaptation */
.plotly .xtick text, .plotly .ytick text {
fill: var(--body-text-color, #000) !important;
}
/* Axis title color adaptation - multiple selectors for better coverage */
.plotly .g-xtitle, .plotly .g-ytitle,
.plotly .xtitle, .plotly .ytitle,
.plotly text[class*="xtitle"], .plotly text[class*="ytitle"],
.plotly .infolayer .g-xtitle, .plotly .infolayer .g-ytitle {
fill: var(--body-text-color, #000) !important;
}
/* Additional axis title selectors */
.plotly .subplot .xtitle, .plotly .subplot .ytitle,
.plotly .cartesianlayer .xtitle, .plotly .cartesianlayer .ytitle {
fill: var(--body-text-color, #000) !important;
}
/* SVG text elements in plots */
.plotly svg text {
fill: var(--body-text-color, #000) !important;
}
/* Legend text color */
.plotly .legendtext, .plotly .legend text {
fill: var(--body-text-color, #000) !important;
}
/* Hover label adaptation */
.plotly .hoverlayer .hovertext, .plotly .hovertext {
fill: var(--body-text-color, #000) !important;
color: var(--body-text-color, #000) !important;
}
/* Annotation text */
.plotly .annotation-text, .plotly .annotation {
fill: var(--body-text-color, #000) !important;
}
/* Disable plot interactions except hover */
.plotly .modebar {
display: none !important;
}
.plotly .plot-container .plotly {
pointer-events: none !important;
}
.plotly .plot-container .plotly .hoverlayer {
pointer-events: auto !important;
}
.plotly .plot-container .plotly .hovertext {
pointer-events: auto !important;
}
""",
) as app:
gr.Markdown("""
# š¬ Frame Arena: Frame by frame comparisons of any videos
> š This tool has been created to celebrate our Wan 2.2 [text-to-video](https://replicate.com/wan-video/wan-2.2-t2v-480p-fast) and [image-to-video](https://replicate.com/wan-video/wan-2.2-i2v-a14b) endpoints on Replicate. Want to know more? Check out [our blog](https://www.wan22.com/blog/video-optimization-on-replicate)!
- Upload videos in common formats with the same number of frames (MP4, AVI, MOV, etc.) or use URLs
- **7 Quality Metrics**: SSIM, PSNR, MSE, pHash, Color Histogram, Sharpness + Overall Quality
- **Individual Visualization**: Each metric gets its own dedicated plot
- **Real-time Analysis**: Navigate frames with live metric updates
- **Smart Comparisons**: Understand differences between videos per metric
**Perfect for**: Analyzing compression effects, processing artifacts, visual quality assessment, and compression algorithm comparisons.
""")
with gr.Row():
with gr.Column():
gr.Markdown("### Video 1")
video1_input = gr.File(
label="Upload Video 1",
file_types=[
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
],
type="filepath",
)
with gr.Column():
gr.Markdown("### Video 2")
video2_input = gr.File(
label="Upload Video 2",
file_types=[
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
],
type="filepath",
)
# Add examples at the top for better UX
if example_pairs:
gr.Markdown("### š Example Video Comparisons")
gr.Examples(
examples=example_pairs,
inputs=[video1_input, video2_input],
label="Click any example to load video pairs:",
examples_per_page=10,
run_on_click=False, # We'll handle this manually
)
        load_btn = gr.Button("🚀 Load Videos", variant="primary", size="lg")
        # Frame comparison section (visible by default; toggled after loading)
        frame_display = gr.Row(visible=True)
with frame_display:
with gr.Column():
gr.Markdown("### Video 1 - Current Frame")
frame1_output = gr.Image(
label="Video 1 Frame",
type="numpy",
interactive=False,
# height=400,
)
with gr.Column():
gr.Markdown("### Frame Slider (Left: Video 1, Right: Video 2)")
image_slider = ImageSlider(
label="Drag to compare frames",
type="numpy",
interactive=True,
# height=400,
)
with gr.Column():
gr.Markdown("### Video 2 - Current Frame")
frame2_output = gr.Image(
label="Video 2 Frame",
type="numpy",
interactive=False,
# height=400,
)
        # Frame navigation (visible by default; toggled after loading) - placed underneath the frames
        frame_controls = gr.Row(visible=True)
with frame_controls:
frame_slider = gr.Slider(
minimum=0,
maximum=0,
step=1,
value=0,
label="Frame Number",
interactive=True,
)
        # Comprehensive metrics visualization (visible by default; toggled after loading)
        metrics_section = gr.Row(visible=True)
with metrics_section:
with gr.Column():
gr.Markdown("### š Metric Analysis")
# Overall quality plot
with gr.Row():
overall_plot = gr.Plot(
label="Overall Quality (Combined Metric [SSIM + normalized_PSNR + pHash])",
show_label=True,
)
# Frame info moved below overall quality plot
frame_info = gr.Textbox(
label="Frame Information & Metrics",
interactive=False,
value="",
lines=3,
)
# Add comprehensive usage guide underneath frame information & metrics
                with gr.Accordion("📖 Usage Guide & Metrics Reference", open=False):
with gr.Row():
with gr.Column():
gr.Markdown("""
### š Metrics Explained
- **SSIM**: Structural Similarity (1.0 = identical structure, 0.0 = completely different)
- **PSNR**: Peak Signal-to-Noise Ratio in dB (higher = better quality, less noise)
- **MSE**: Mean Squared Error (lower = more similar pixel values)
- **pHash**: Perceptual Hash similarity (1.0 = visually identical)
- **Color Histogram**: Color distribution correlation (1.0 = identical color patterns)
- **Sharpness**: Laplacian variance per video (higher = sharper/more detailed images)
- **Overall Quality**: Combined metric averaging SSIM, min-max normalized PSNR, and pHash
""")
with gr.Column() as info_section:
status_output = gr.Textbox(
label="Status", interactive=False, lines=16
)
with gr.Row():
with gr.Column():
gr.Markdown("""
### šÆ Quality Assessment Scale (Research-Based Thresholds)
**SSIM Scale** (based on human perception studies):
- š¢ **Excellent (ā„0.9)**: Visually indistinguishable differences
- šµ **Good (ā„0.8)**: Minor visible differences, still high quality
- š” **Fair (ā„0.6)**: Noticeable differences, acceptable quality
- š“ **Poor (<0.6)**: Significant visible artifacts and differences
**PSNR Scale** (standard video quality benchmarks):
- š¢ **Excellent (ā„40dB)**: Professional broadcast quality
- šµ **Good (ā„30dB)**: High consumer video quality
- š” **Fair (ā„20dB)**: Acceptable for web streaming
- š“ **Poor (<20dB)**: Low quality with visible compression artifacts
**MSE Scale** (pixel difference thresholds):
- š¢ **Very Similar (ā¤50)**: Minimal pixel-level differences
- šµ **Similar (ā¤100)**: Small differences, good quality preservation
- š” **Moderately Different (ā¤200)**: Noticeable but acceptable differences
- š“ **Very Different (>200)**: Significant pixel-level changes
""")
with gr.Column():
gr.Markdown("""
### š Understanding Comparisons
**Comparison Analysis**: Shows how similar/different the videos are
- Most metrics indicate **similarity** - not which video "wins"
- Higher SSIM/PSNR/pHash/Color = more similar videos
- Lower MSE = more similar videos
**Individual Quality**: Shows the quality of each video separately
- Sharpness comparison shows which video has more detail
- Significance levels: š“ MAJOR (>20%), š” MODERATE (10-20%), šµ MINOR (5-10%), š¢ NEGLIGIBLE (<5%)
**Overall Quality**: Combines multiple metrics to provide a single similarity score
- **Formula**: Average of [SSIM + normalized_PSNR + pHash]
- **PSNR Normalization**: PSNR values are divided by 50dB and capped at 1.0
- **Range**: 0.0 to 1.0 (higher = more similar videos overall)
- **Purpose**: Provides a single metric when you need one overall assessment
- **Limitation**: Different metrics may disagree; check individual metrics for details
""")
# Individual metric plots
with gr.Row():
ssim_plot = gr.Plot(label="SSIM", show_label=True)
psnr_plot = gr.Plot(label="PSNR", show_label=True)
with gr.Row():
mse_plot = gr.Plot(label="MSE", show_label=True)
phash_plot = gr.Plot(label="pHash", show_label=True)
with gr.Row():
color_plot = gr.Plot(label="Color Histogram", show_label=True)
sharpness_plot = gr.Plot(label="Sharpness", show_label=True)
# Connect examples to auto-loading
if example_pairs:
            # gr.Examples does not expose a click handler directly, so this
            # handler is unused; auto-loading is instead wired through the
            # video inputs' change events below
            def examples_manual_handler(video1, video2):
                print("DEBUG: Examples clicked manually!")
                return load_videos_handler(video1, video2)
# Event handlers
def load_videos_handler(video1, video2):
print(
f"DEBUG: load_videos_handler called with video1={video1}, video2={video2}"
)
status, max_frames, frame1, frame2, info, plots = comparator.load_videos(
video1, video2
)
# Update slider
slider_update = gr.Slider(
minimum=0,
maximum=max_frames,
step=1,
value=0,
                interactive=max_frames > 0,
)
# Show/hide sections based on whether videos were loaded successfully
videos_loaded = max_frames > 0
# Extract individual plots from the plots dictionary
ssim_fig = plots.get("ssim") if plots else None
psnr_fig = plots.get("psnr") if plots else None
mse_fig = plots.get("mse") if plots else None
phash_fig = plots.get("phash") if plots else None
color_fig = plots.get("color_hist") if plots else None
sharpness_fig = plots.get("sharpness") if plots else None
overall_fig = plots.get("overall") if plots else None
return (
status, # status_output
slider_update, # frame_slider
frame1, # frame1_output
(frame1, frame2), # image_slider
frame2, # frame2_output
info, # frame_info
ssim_fig, # ssim_plot
psnr_fig, # psnr_plot
mse_fig, # mse_plot
phash_fig, # phash_plot
color_fig, # color_plot
sharpness_fig, # sharpness_plot
overall_fig, # overall_plot
gr.Row(visible=videos_loaded), # frame_controls
gr.Row(visible=videos_loaded), # frame_display
gr.Row(visible=videos_loaded and plots is not None), # metrics_section
gr.Row(visible=videos_loaded), # info_section
)
def update_frames(frame_index):
if comparator.max_frames == 0:
return (
None,
None,
None,
"No videos loaded",
None,
None,
None,
None,
None,
None,
None,
)
frame1, frame2 = comparator.get_frames_at_index(frame_index)
info = comparator.get_current_frame_info(frame_index)
plots = comparator.get_updated_plot(frame_index)
# Extract individual plots from the plots dictionary
ssim_fig = plots.get("ssim") if plots else None
psnr_fig = plots.get("psnr") if plots else None
mse_fig = plots.get("mse") if plots else None
phash_fig = plots.get("phash") if plots else None
color_fig = plots.get("color_hist") if plots else None
sharpness_fig = plots.get("sharpness") if plots else None
overall_fig = plots.get("overall") if plots else None
return (
frame1,
(frame1, frame2),
frame2,
info,
ssim_fig,
psnr_fig,
mse_fig,
phash_fig,
color_fig,
sharpness_fig,
overall_fig,
)
# Auto-load when examples populate the inputs
def auto_load_when_examples_change(video1, video2):
print(
f"DEBUG: auto_load_when_examples_change called with video1={video1}, video2={video2}"
)
# Only auto-load if both inputs are provided (from examples)
if video1 and video2:
print("DEBUG: Both videos present, calling load_videos_handler")
return load_videos_handler(video1, video2)
# If only one or no videos, return default empty state
print("DEBUG: Not both videos present, returning default state")
return (
"Please upload videos or select an example", # status_output
gr.Slider(
minimum=0, maximum=0, step=1, value=0, interactive=False
), # frame_slider
None, # frame1_output
(None, None), # image_slider
None, # frame2_output
"", # frame_info
None, # ssim_plot
None, # psnr_plot
None, # mse_plot
None, # phash_plot
None, # color_plot
None, # sharpness_plot
None, # overall_plot
gr.Row(visible=True), # frame_controls
gr.Row(visible=True), # frame_display
gr.Row(visible=True), # metrics_section
gr.Row(visible=True), # info_section
)
        # Deduplicating auto-load: both change events fire when an example
        # populates the two inputs, so skip re-processing an identical pair
        last_processed_pair = {"video1": None, "video2": None}
def enhanced_auto_load(video1, video2):
print(f"DEBUG: Input change detected! video1={video1}, video2={video2}")
            # Skip if the same video pair was just processed
if (
last_processed_pair["video1"] == video1
and last_processed_pair["video2"] == video2
):
print("DEBUG: Same video pair already processed, skipping...")
# Return current state without recomputing
return (
gr.update(), # status_output
gr.update(), # frame_slider
gr.update(), # frame1_output
gr.update(), # image_slider
gr.update(), # frame2_output
gr.update(), # frame_info
gr.update(), # ssim_plot
gr.update(), # psnr_plot
gr.update(), # mse_plot
gr.update(), # phash_plot
gr.update(), # color_plot
gr.update(), # sharpness_plot
gr.update(), # overall_plot
gr.update(), # frame_controls
gr.update(), # frame_display
gr.update(), # metrics_section
gr.update(), # info_section
)
last_processed_pair["video1"] = video1
last_processed_pair["video2"] = video2
return auto_load_when_examples_change(video1, video2)
# Auto-load when both video inputs change (triggered by examples)
video1_input.change(
fn=enhanced_auto_load,
inputs=[video1_input, video2_input],
outputs=[
status_output,
frame_slider,
frame1_output,
image_slider,
frame2_output,
frame_info,
ssim_plot,
psnr_plot,
mse_plot,
phash_plot,
color_plot,
sharpness_plot,
overall_plot,
frame_controls,
frame_display,
metrics_section,
info_section,
],
)
video2_input.change(
fn=enhanced_auto_load,
inputs=[video1_input, video2_input],
outputs=[
status_output,
frame_slider,
frame1_output,
image_slider,
frame2_output,
frame_info,
ssim_plot,
psnr_plot,
mse_plot,
phash_plot,
color_plot,
sharpness_plot,
overall_plot,
frame_controls,
frame_display,
metrics_section,
info_section,
],
)
# Manual load button event handler with debug
def debug_load_videos_handler(video1, video2):
print(f"DEBUG: Load button clicked! video1={video1}, video2={video2}")
return load_videos_handler(video1, video2)
load_btn.click(
fn=debug_load_videos_handler,
inputs=[video1_input, video2_input],
outputs=[
status_output,
frame_slider,
frame1_output,
image_slider,
frame2_output,
frame_info,
ssim_plot,
psnr_plot,
mse_plot,
phash_plot,
color_plot,
sharpness_plot,
overall_plot,
frame_controls,
frame_display,
metrics_section,
info_section,
],
)
frame_slider.change(
fn=update_frames,
inputs=[frame_slider],
outputs=[
frame1_output,
image_slider,
frame2_output,
frame_info,
ssim_plot,
psnr_plot,
mse_plot,
phash_plot,
color_plot,
sharpness_plot,
overall_plot,
],
)
return app
def main():
app = create_app()
app.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
)
if __name__ == "__main__":
main()