# Spaces: Paused
import os | |
import tempfile | |
import torch | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import cv2 | |
from diffusers import DiffusionPipeline | |
import cupy as cp | |
from cupyx.scipy.ndimage import label as cp_label | |
from cupyx.scipy.ndimage import binary_dilation | |
from sklearn.cluster import DBSCAN | |
import trimesh | |
class GPUSatelliteModelGenerator:
    """Convert a satellite-style image into a textured height-field mesh.

    The pipeline has three stages, each exposed as a method:

    1. ``segment_image_gpu`` labels every pixel (shadow, water, road,
       vegetation, terrain, roof or building) by HSV colour matching.
    2. ``estimate_heights_gpu`` gives every connected building/roof region a
       relative height derived from its area and the shadow around it.
    3. ``generate_mesh_gpu`` builds a textured ``trimesh.Trimesh`` grid mesh
       from the height map.

    Array math runs through CuPy; OpenCV colour conversion and the final
    mesh construction run on the host.
    """

    def __init__(self, building_height=0.05):
        """Set up reference colours, tolerances and output class colours.

        Args:
            building_height: Factor applied to height-map values when the
                mesh vertices are generated.
        """
        self.building_height = building_height

        # Reference colour tables. They are interpreted as RGB, because they
        # are converted to HSV with cv2.COLOR_RGB2HSV below.
        self.shadow_colors = cp.array([
            [31, 42, 76],
            [58, 64, 92],
            [15, 27, 56],
            [21, 22, 50],
            [76, 81, 99],
        ])
        self.road_colors = cp.array([
            [187, 182, 175],
            [138, 138, 138],
            [142, 142, 129],
            [202, 199, 189],
        ])
        self.water_colors = cp.array([
            [167, 225, 217],
            [67, 101, 97],
            [53, 83, 84],
            [47, 94, 100],
            [73, 131, 135],
        ])
        self.roof_colors = cp.array([
            [191, 148, 124],
            [190, 142, 121],
            [184, 154, 139],
            [178, 118, 118],
            [164, 109, 107],
            [155, 113, 105],
            [153, 111, 106],
            [155, 95, 96],
            [135, 82, 87],
            [117, 82, 78],
            [113, 62, 50],
            [166, 144, 135],
        ])

        # Normalised HSV versions (hue in degrees 0-360, sat/val in 0-1).
        self.roof_colors_hsv = self._reference_hsv(self.roof_colors)
        self.shadow_colors_hsv = self._reference_hsv(self.shadow_colors)
        self.road_colors_hsv = self._reference_hsv(self.road_colors)
        self.water_colors_hsv = self._reference_hsv(self.water_colors)

        # Per-class matching tolerances, in the same normalised units.
        # The roof hue tolerance is the tightest so roofs are not confused
        # with terrain.
        self.roof_tolerance = {'hue': 8, 'sat': 0.15, 'val': 0.15}
        self.shadow_tolerance = {'hue': 15, 'sat': 0.15, 'val': 0.12}
        self.road_tolerance = {'hue': 10, 'sat': 0.12, 'val': 0.15}
        self.water_tolerance = {'hue': 20, 'sat': 0.15, 'val': 0.20}

        # Output class colours in [B, G, R] order.
        self.colors = {
            'black': cp.array([0, 0, 0]),        # Shadows
            'blue': cp.array([255, 0, 0]),       # Water
            'green': cp.array([0, 255, 0]),      # Vegetation
            'gray': cp.array([128, 128, 128]),   # Roads
            'brown': cp.array([0, 140, 255]),    # Terrain
            'white': cp.array([255, 255, 255]),  # Buildings
            'salmon': cp.array([128, 128, 255]),  # Roofs
        }

        # Not read by any method of this class; kept for compatibility with
        # code that may access them.
        self.min_area_for_clustering = 1000
        self.residential_height_factor = 0.6
        self.isolation_threshold = 0.6

    @staticmethod
    def _reference_hsv(colors_rgb):
        """Convert an (N, 3) RGB colour table to normalised float32 HSV.

        The conversion goes through OpenCV's 8-bit HSV (hue 0-179, sat/val
        0-255). The result is rescaled in float32: doing it in place on the
        uint8 array would wrap hues above 255 and truncate sat/val to 0 or 1.
        """
        rgb_cpu = colors_rgb.get().reshape(-1, 1, 3).astype(np.uint8)
        hsv_cpu = cv2.cvtColor(rgb_cpu, cv2.COLOR_RGB2HSV).reshape(-1, 3)
        hsv = cp.asarray(hsv_cpu, dtype=cp.float32)
        hsv[:, 0] *= 2
        hsv[:, 1:] /= 255
        return hsv

    @staticmethod
    def _scaled_tolerance(tolerance, hue=1.0, sat=1.0, val=1.0):
        """Return a copy of ``tolerance`` with each entry multiplied by a factor."""
        return {
            'hue': tolerance['hue'] * hue,
            'sat': tolerance['sat'] * sat,
            'val': tolerance['val'] * val,
        }

    @staticmethod
    def _count_neighbors(mask):
        """Count, for every pixel, how many of its 8 neighbours are set in ``mask``."""
        rows, cols = mask.shape
        padded = cp.pad(mask.astype(cp.uint8), 1, mode='constant')
        counts = cp.zeros((rows, cols), dtype=cp.uint8)
        for dy in range(3):
            for dx in range(3):
                if dy == 1 and dx == 1:
                    continue  # the centre pixel is not its own neighbour
                counts += padded[dy:dy + rows, dx:dx + cols]
        return counts

    @staticmethod
    def gpu_color_distance_hsv(pixel_hsv, reference_hsv, tolerance):
        """Test whether pixels lie within ``tolerance`` of a reference colour.

        Declared as a static method because it reads no instance state and is
        called as ``self.gpu_color_distance_hsv(pixels, ref, tol)``.

        Args:
            pixel_hsv: HSV values in OpenCV 8-bit scale (hue 0-179, sat/val
                0-255), indexed as ``pixel_hsv[0]``, ``[1]``, ``[2]`` for the
                three channels.
            reference_hsv: Normalised reference colour (hue in degrees,
                sat/val in 0-1).
            tolerance: Dict with ``'hue'``, ``'sat'`` and ``'val'`` limits in
                normalised units.

        Returns:
            Boolean array, true where all three channel differences are
            within tolerance.
        """
        # float32 avoids uint8 wrap-around: doubled hue goes up to 358.
        pixels = cp.asarray(pixel_hsv, dtype=cp.float32)
        pixel_h = pixels[0] * 2
        pixel_s = pixels[1] / 255
        pixel_v = pixels[2] / 255

        # Hue is circular, so take the shorter way round the colour wheel.
        raw_hue_diff = cp.abs(pixel_h - reference_hsv[0])
        hue_diff = cp.minimum(raw_hue_diff, 360 - raw_hue_diff)
        sat_diff = cp.abs(pixel_s - reference_hsv[1])
        val_diff = cp.abs(pixel_v - reference_hsv[2])

        return (
            (hue_diff <= tolerance['hue'])
            & (sat_diff <= tolerance['sat'])
            & (val_diff <= tolerance['val'])
        )

    def _match_any(self, pixel_hsv, references_hsv, tolerance):
        """Return a mask of pixels matching at least one reference colour."""
        mask = cp.zeros(pixel_hsv.shape[1:], dtype=bool)
        for ref_hsv in references_hsv:
            mask |= self.gpu_color_distance_hsv(pixel_hsv, ref_hsv, tolerance)
        return mask

    def segment_image_gpu(self, img):
        """Classify every pixel of ``img`` into one of the ``self.colors`` classes.

        Args:
            img: Image array that is converted with ``cv2.COLOR_BGR2HSV``
                (so a 3-channel BGR image is expected).

        Returns:
            CuPy array with the shape and dtype of ``img`` in which every
            pixel holds the colour of its class.
        """
        gpu_img = cp.asarray(img)
        height, width = img.shape[:2]

        # Channel-major (3, H*W) float32 copy of the OpenCV-scale HSV image.
        hsv_cpu = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        hsv_channels = cp.asarray(hsv_cpu).reshape(-1, 3).T.astype(cp.float32)

        # Reference-colour matching; shadows and roads use widened tolerances.
        shadow_tol = self._scaled_tolerance(self.shadow_tolerance, hue=1.2, sat=1.1, val=1.2)
        road_tol = self._scaled_tolerance(self.road_tolerance, hue=1.3, sat=1.2)
        shadow_mask = self._match_any(hsv_channels, self.shadow_colors_hsv, shadow_tol)
        road_mask = self._match_any(hsv_channels, self.road_colors_hsv, road_tol)
        water_mask = self._match_any(hsv_channels, self.water_colors_hsv, self.water_tolerance)
        roof_mask = self._match_any(hsv_channels, self.roof_colors_hsv, self.roof_tolerance)

        # Normalised channels for the range-based rules below.
        h = hsv_channels[0] * 2   # degrees, 0-360
        s = hsv_channels[1] / 255
        v = hsv_channels[2] / 255

        vegetation_mask = (h >= 40) & (h <= 150) & (s >= 0.15)

        # Terrain hue bands, with roof pixels explicitly excluded.
        terrain_mask = (
            ((h >= 15) & (h <= 35) & (s >= 0.15) & (s <= 0.6))
            | ((h >= 25) & (h <= 40) & (s >= 0.1) & (v >= 0.5))
        ) & ~roof_mask

        # Unsaturated mid-brightness pixels not claimed by another class
        # count as road.
        gray_mask = (s <= 0.2) & (v >= 0.4) & (v <= 0.85)
        road_mask |= gray_mask & ~(
            shadow_mask | water_mask | vegetation_mask | terrain_mask | roof_mask
        )

        # Very dark pixels not claimed by water, road or roof count as shadow.
        dark_mask = v <= 0.3
        shadow_mask |= dark_mask & ~(water_mask | road_mask | roof_mask)

        # Whatever is left is treated as building.
        building_mask = ~(
            shadow_mask | water_mask | road_mask
            | vegetation_mask | terrain_mask | roof_mask
        )

        # Paint the classes; where masks overlap, the later assignment wins.
        output = cp.zeros_like(gpu_img)
        output_flat = output.reshape(-1, 3)
        output_flat[shadow_mask] = self.colors['black']
        output_flat[water_mask] = self.colors['blue']
        output_flat[road_mask] = self.colors['gray']
        output_flat[vegetation_mask] = self.colors['green']
        output_flat[terrain_mask] = self.colors['brown']
        output_flat[roof_mask] = self.colors['salmon']
        output_flat[building_mask] = self.colors['white']
        segmented = output.reshape(height, width, 3)

        # Two cleanup passes: a building pixel with enough 8-neighbours of
        # one class is absorbed into that class. Roofs use a lower threshold.
        for _ in range(2):
            for color_name, color_value in self.colors.items():
                if color_name == 'white':
                    continue
                color_mask = cp.all(segmented == color_value, axis=2)
                building_pixels = cp.all(segmented == self.colors['white'], axis=2)
                neighbor_count = self._count_neighbors(color_mask)
                threshold = 4 if color_name == 'salmon' else 5
                surrounded = (neighbor_count >= threshold) & building_pixels
                segmented[surrounded] = color_value

        return segmented

    def estimate_heights_gpu(self, img, segmented):
        """Estimate a relative height for every building/roof region.

        Each connected region of building or roof pixels gets one height,
        the product of a size factor (0.3-1.0, from its area relative to the
        largest region, times 0.8 if it contains roof pixels) and a shadow
        factor (0.2-1.0, from the share of shadow pixels in its 5x5-dilated
        footprint).

        Args:
            img: Unused; kept for interface compatibility.
            segmented: Segmentation image using the ``self.colors`` palette.

        Returns:
            NumPy float array of per-pixel heights, scaled by 0.25; zero
            outside building/roof regions.
        """
        gpu_segmented = cp.asarray(segmented)
        roof_pixels = cp.all(gpu_segmented == self.colors['salmon'], axis=2)
        buildings_mask = cp.logical_or(
            cp.all(gpu_segmented == self.colors['white'], axis=2),
            roof_pixels,
        )
        shadows_mask = cp.all(gpu_segmented == self.colors['black'], axis=2)

        labeled_array, num_features = cp_label(buildings_mask)
        height_map = cp.zeros_like(labeled_array, dtype=cp.float32)
        if num_features == 0:
            return height_map.get() * 0.25

        # Index 0 of the bincount is the background label; drop it.
        areas = cp.bincount(labeled_array.ravel())[1:]
        max_area = cp.max(areas)
        structure = cp.ones((5, 5))

        for component_id in range(1, num_features + 1):
            building_mask = labeled_array == component_id

            size_factor = 0.3 + 0.7 * (areas[component_id - 1] / max_area)
            if cp.any(roof_pixels & building_mask):
                # Regions containing roof pixels are lowered slightly.
                size_factor = size_factor * 0.8

            dilated = binary_dilation(building_mask, structure=structure)
            shadow_ratio = cp.sum(dilated & shadows_mask) / cp.sum(dilated)
            shadow_factor = 0.2 + 0.8 * shadow_ratio

            height_map[building_mask] = size_factor * shadow_factor

        return height_map.get() * 0.25

    def generate_mesh_gpu(self, height_map, texture_img):
        """Build a textured grid mesh from a height map.

        One vertex is created per height-map cell and two triangles per grid
        square. x and z are scaled by the larger image dimension and centred
        on the origin; y is ``height * building_height * 2 - 1``.

        Args:
            height_map: 2-D array of heights.
            texture_img: Texture image; 3- and 4-channel images are converted
                from BGR(A) to RGB, other shapes are used as they are.

        Returns:
            ``trimesh.Trimesh`` with UV-mapped texture visuals.
        """
        height_map_gpu = cp.asarray(height_map)
        height, width = height_map.shape

        # Vertex grid; cast explicitly so integer grid coordinates and float
        # heights stack to a single floating dtype.
        x, z = cp.meshgrid(cp.arange(width), cp.arange(height))
        y = height_map_gpu * self.building_height
        vertices = cp.stack(
            [x.astype(cp.float64), y.astype(cp.float64), z.astype(cp.float64)],
            axis=-1,
        ).reshape(-1, 3)

        scale = max(width, height)
        vertices[:, 0] = vertices[:, 0] / scale * 2 - (width / scale)
        vertices[:, 2] = vertices[:, 2] / scale * 2 - (height / scale)
        vertices[:, 1] = vertices[:, 1] * 2 - 1

        # Two triangles per grid square, indexed into the row-major grid.
        i, j = cp.meshgrid(cp.arange(height - 1), cp.arange(width - 1), indexing='ij')
        v0 = (i * width + j).flatten()
        v1 = v0 + 1
        v2 = ((i + 1) * width + j).flatten()
        v3 = v2 + 1
        faces = cp.vstack((
            cp.column_stack((v0, v2, v1)),
            cp.column_stack((v1, v2, v3)),
        ))

        # UVs span the texture; max(..., 1) avoids dividing by zero when the
        # height map is a single row or column.
        uvs = cp.zeros((vertices.shape[0], 2))
        uvs[:, 0] = x.flatten() / max(width - 1, 1)
        uvs[:, 1] = 1 - (z.flatten() / max(height - 1, 1))

        vertices_cpu = vertices.get()
        faces_cpu = faces.get()
        uvs_cpu = uvs.get()

        if texture_img.ndim == 3 and texture_img.shape[2] == 4:
            texture_img = cv2.cvtColor(texture_img, cv2.COLOR_BGRA2RGB)
        elif texture_img.ndim == 3:
            texture_img = cv2.cvtColor(texture_img, cv2.COLOR_BGR2RGB)

        return trimesh.Trimesh(
            vertices=vertices_cpu,
            faces=faces_cpu,
            visual=trimesh.visual.TextureVisuals(
                uv=uvs_cpu,
                image=Image.fromarray(texture_img),
            ),
        )
def generate_and_process_map(prompt: str) -> tuple[str | None, str | None]:
    """Generate a satellite image from ``prompt`` and turn it into a 3D model.

    Relies on the module-level ``flux_pipe`` and ``device`` being set before
    it is called.

    Args:
        prompt: Text description appended to the fixed satellite-style prefix.

    Returns:
        ``(glb_path, segmented_png_path)`` on success, both files written to
        a fresh temporary directory; ``(None, None)`` if any step raises.
    """
    try:
        width = height = 1024

        # One random seed drives torch, NumPy and the pipeline's generator.
        seed = np.random.randint(0, np.iinfo(np.int32).max)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch_generator = torch.Generator(device=device).manual_seed(seed)

        generated_image = flux_pipe(
            prompt=f"satellite view in the style of TOK, {prompt}",
            width=width,
            height=height,
            num_inference_steps=25,
            generator=torch_generator,
            guidance_scale=7.5
        ).images[0]

        # PIL gives RGB; the model generator works on OpenCV-style BGR.
        cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)

        model_generator = GPUSatelliteModelGenerator(building_height=0.09)

        print("Segmenting image using GPU...")
        segmented_img = model_generator.segment_image_gpu(cv_image)

        print("Estimating heights using GPU...")
        height_map = model_generator.estimate_heights_gpu(cv_image, segmented_img)

        print("Generating mesh using GPU...")
        mesh = model_generator.generate_mesh_gpu(height_map, cv_image)

        # The directory is intentionally not removed here: the returned
        # paths must stay readable after this function returns.
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')
        mesh.export(output_path)

        segmented_path = os.path.join(temp_dir, 'segmented.png')
        cv2.imwrite(segmented_path, segmented_img.get())

        return output_path, segmented_path
    except Exception as e:
        # Top-level UI boundary: report the failure and return empty outputs
        # instead of propagating.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
# Gradio UI: prompt row, button row, then one row holding two columns
# (3D model viewer and segmentation image).
with gr.Blocks() as demo:
    gr.Markdown("# Text to Map")
    gr.Markdown("Generate a 3D map from text!")

    with gr.Row():
        prompt_input = gr.Text(label="Enter your prompt", placeholder="classic american town")

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row():
        with gr.Column():
            model_output = gr.Model3D(label="Generated 3D Map", clear_color=[0.0] * 4)
        with gr.Column():
            segmented_output = gr.Image(label="Segmented Map", type="filepath")

    # Clicking the button runs the full pipeline and fills both outputs.
    generate_btn.click(
        generate_and_process_map,
        [prompt_input],
        [model_output, segmented_output],
        api_name="generate",
    )
if __name__ == "__main__":
    # Pipeline configuration. `device` and `flux_pipe` are module-level names
    # that generate_and_process_map reads at call time.
    repo_id = "black-forest-labs/FLUX.1-dev"
    adapter_id = "jbilcke-hf/flux-satellite"
    dtype = torch.bfloat16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the base model, attach the LoRA adapter, then move it to the device.
    flux_pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
    flux_pipe.load_lora_weights(adapter_id)
    flux_pipe = flux_pipe.to(device)

    # Start the Gradio app with request queuing enabled.
    demo.queue().launch()