Spaces:
Sleeping
Sleeping
File size: 5,000 Bytes
8d272fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Visualization utilities for LLaVA.
"""
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from typing import List, Dict, Tuple, Optional, Union
import cv2
def display_image_with_caption(
    image_path: str,
    caption: str,
    figsize: Tuple[int, int] = (10, 10)
) -> None:
    """
    Show an image in a new matplotlib figure with ``caption`` as its title.

    Args:
        image_path: Path to the image file; it is converted to RGB before display.
        caption: Text used as the figure title.
        figsize: Figure size passed to ``plt.figure``.
    """
    rgb_image = Image.open(image_path).convert('RGB')

    plt.figure(figsize=figsize)
    plt.imshow(rgb_image)
    # Hide ticks/frame so only the picture and its title are visible.
    plt.axis('off')
    plt.title(caption)
    plt.tight_layout()
    plt.show()
def visualize_attention(
    image_path: str,
    attention_weights: Union[torch.Tensor, np.ndarray],
    figsize: Tuple[int, int] = (15, 5)
) -> None:
    """
    Visualize attention weights on an image.

    Shows three panels side by side: the original image, the attention map
    rendered with OpenCV's JET colormap, and that heatmap blended 50/50 over
    the image.

    Args:
        image_path: Path to the image file.
        attention_weights: Attention weights whose last two dimensions are
            treated as a 2D map. All leading dimensions (for example layers
            and/or heads) are averaged away. Array-likes such as NumPy arrays
            are also accepted and converted with ``torch.as_tensor``.
        figsize: Figure size.
    """
    # Load image
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)

    # Work in float32 on the CPU: integer tensors cannot be averaged, NumPy has
    # no bfloat16 dtype, and cv2.resize rejects float16 arrays.
    weights = torch.as_tensor(attention_weights).detach().float().cpu()
    if weights.dim() > 2:
        # Collapse every leading dimension so exactly the last two remain.
        # (Averaging only dims 0 and 1 turned a 3D tensor into a 1D vector.)
        weights = weights.reshape(-1, *weights.shape[-2:]).mean(dim=0)
    attention_np = weights.numpy()

    # Min-max normalize to [0, 1]. A constant map has zero range and would
    # otherwise become all-NaN, so render it as all zeros instead.
    low = attention_np.min()
    span = attention_np.max() - low
    if span > 0:
        attention_np = (attention_np - low) / span
    else:
        attention_np = np.zeros_like(attention_np)

    # Resize attention map to image size (cv2 takes dsize as (width, height))
    attention_map = cv2.resize(attention_np, (image.width, image.height))

    # Create heatmap; OpenCV colormaps produce BGR, matplotlib expects RGB
    heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on image
    alpha = 0.5
    overlay = (heatmap * alpha + image_np * (1 - alpha)).astype(np.uint8)

    # Display original image, attention map, and overlay
    panels = (
        (image_np, 'Original Image'),
        (heatmap, 'Attention Map'),
        (overlay, 'Overlay'),
    )
    _, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
def create_comparison_grid(
    image_path: str,
    responses: List[Dict[str, str]],
    output_path: Optional[str] = None,
    figsize: Tuple[int, int] = (12, 10)
) -> None:
    """
    Render the input image above one text row per model response.

    Row 0 of the layout holds the image; row ``i + 1`` holds the text
    ``"<model>: <response>"`` for ``responses[i]``, centered and wrapped.

    Args:
        image_path: Path to the image file.
        responses: Dictionaries that must each contain 'model' and 'response' keys.
        output_path: If truthy, the figure is also saved here with a tight
            bounding box before being shown.
        figsize: Figure size.
    """
    picture = Image.open(image_path).convert('RGB')
    plt.figure(figsize=figsize)

    # Every axes spans all three grid columns, so the layout is effectively a
    # single column: one image row followed by one row per response.
    grid_shape = (len(responses) + 1, 3)

    image_ax = plt.subplot2grid(grid_shape, (0, 0), colspan=3)
    image_ax.imshow(picture)
    image_ax.set_title('Input Image')
    image_ax.axis('off')

    for row, entry in enumerate(responses, start=1):
        text_ax = plt.subplot2grid(grid_shape, (row, 0), colspan=3)
        label = f"{entry['model']}: {entry['response']}"
        text_ax.text(
            0.5, 0.5, label,
            wrap=True,
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=10,
        )
        text_ax.axis('off')

    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
    plt.show()
def _wrap_caption_lines(draw, caption: str, font, max_width: float) -> List[str]:
    """
    Split ``caption`` into lines no wider than ``max_width`` pixels.

    ``draw`` is a PIL ``ImageDraw`` used only for ``textlength`` measurement,
    and ``font`` is the font the text will be drawn with. Explicit newlines are
    honored. A paragraph that already fits is kept verbatim (preserving its
    spacing); a longer one is greedily word-wrapped. A single word wider than
    ``max_width`` is kept whole on its own line.
    """
    lines: List[str] = []
    for paragraph in caption.split('\n'):
        if draw.textlength(paragraph, font=font) <= max_width:
            lines.append(paragraph)
            continue
        current = ''
        for word in paragraph.split():
            candidate = f"{current} {word}" if current else word
            if current and draw.textlength(candidate, font=font) > max_width:
                lines.append(current)
                current = word
            else:
                current = candidate
        lines.append(current)
    return lines


def add_caption_to_image(
    image_path: str,
    caption: str,
    output_path: str,
    font_size: int = 20,
    font_color: Tuple[int, int, int] = (255, 255, 255),
    bg_color: Tuple[int, int, int] = (0, 0, 0)
) -> None:
    """
    Add a caption band below an image and save the result.

    The caption is centered horizontally. Text containing newlines or wider
    than the image is split into multiple lines, and the band grows by one
    line height per extra line. A single line that fits uses the same
    layout as before: ``font_size + 20`` pixels tall with the text 10 pixels
    below the image.

    Args:
        image_path: Path to the input image
        caption: Caption text (may contain newlines)
        output_path: Path to save the output image
        font_size: Font size
        font_color: Font color (RGB)
        bg_color: Background color (RGB)
    """
    # Load image; the context manager closes the source file promptly.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # Prefer Arial; otherwise fall back to Pillow's default font.
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        try:
            # Pillow >= 10.1 can size the default font. Older versions take no
            # argument (TypeError), and builds without FreeType cannot scale it.
            font = ImageFont.load_default(size=font_size)
        except (TypeError, OSError, ImportError):
            font = ImageFont.load_default()

    # Wrap before allocating the canvas, since its height depends on the line
    # count. textlength rejects multi-line strings, so each line is measured alone.
    measurer = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    lines = _wrap_caption_lines(measurer, caption, font, image.width)

    # One line: font_size plus 10px padding above and below (original layout).
    # Each additional line adds font_size + 4px (Pillow's default line spacing).
    line_height = font_size + 4
    caption_height = font_size + 20 + (len(lines) - 1) * line_height
    new_image = Image.new('RGB', (image.width, image.height + caption_height), bg_color)
    new_image.paste(image, (0, 0))

    # Draw each line horizontally centered
    draw = ImageDraw.Draw(new_image)
    for index, line in enumerate(lines):
        line_width = draw.textlength(line, font=font)
        position = ((image.width - line_width) // 2, image.height + 10 + index * line_height)
        draw.text(position, line, font=font, fill=font_color)

    # Save image
    new_image.save(output_path)