import io
import os
import glob
import base64
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
import uvicorn
from PIL import Image, ImageOps
import torchvision.transforms as T
# ================= 1. MODEL ARCHITECTURE =================
class UNetGenerator(nn.Module):
    """U-Net-style encoder/decoder that inpaints a masked RGB image.

    Input is the masked image (3 channels) concatenated with the mask
    (1 channel); output is a 3-channel image squashed to [0, 1] by a
    final sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Encoder: four stride-2 conv blocks (the first without batch norm,
        # as is conventional for the input layer of pix2pix-style nets).
        self.e1 = self.down_block(4, 64, bn=False)
        self.e2 = self.down_block(64, 128)
        self.e3 = self.down_block(128, 256)
        self.e4 = self.down_block(256, 512)
        # Bottleneck: a fifth stride-2 conv, so the net halves spatial
        # dims 5 times in total.
        self.b = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1), nn.ReLU(inplace=True)
        )
        # Decoder: input channel counts double where encoder skip
        # connections are concatenated in forward().
        self.up1 = self.up_block(512, 512)
        self.up2 = self.up_block(1024, 256)
        self.up3 = self.up_block(512, 128)
        self.up4 = self.up_block(256, 64)
        # Final 2x upsample + 3x3 conv to RGB.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Sigmoid()
        )

    def down_block(self, in_c, out_c, bn=True):
        """Conv(4, stride 2, pad 1) -> optional BatchNorm -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(in_c, out_c, 4, 2, 1)]
        if bn:
            modules.append(nn.BatchNorm2d(out_c))
        modules.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*modules)

    def up_block(self, in_c, out_c):
        """ConvTranspose(4, stride 2, pad 1) -> BatchNorm -> ReLU (2x upsample)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )

    def forward(self, masked_img, mask):
        """Run the U-Net; masked_img is (N,3,H,W), mask is (N,1,H,W)."""
        feats = torch.cat([masked_img, mask], dim=1)
        # Encoder pass, keeping each scale for the skip connections.
        s1 = self.e1(feats)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)
        bottom = self.b(s4)
        # Decoder pass: upsample, then concat the matching encoder scale.
        u = torch.cat([self.up1(bottom), s4], dim=1)
        u = torch.cat([self.up2(u), s3], dim=1)
        u = torch.cat([self.up3(u), s2], dim=1)
        u = torch.cat([self.up4(u), s1], dim=1)
        return self.final(u)
class MagicEraserGAN(pl.LightningModule):
    """Lightning wrapper around UNetGenerator, used here for inference only."""

    # FIX: Accept **kwargs to ignore 'lr' and other training params
    def __init__(self, **kwargs):
        # **kwargs swallows training-time hyperparameters stored in the
        # checkpoint (e.g. lr) that are irrelevant when serving.
        super().__init__()
        # NOTE(review): save_hyperparameters() inspects the caller's frame,
        # so the signature/body here is deliberately left untouched.
        self.save_hyperparameters()
        self.generator = UNetGenerator()

    def forward(self, masked_img, mask):
        """Delegate to the generator; see UNetGenerator.forward."""
        return self.generator(masked_img, mask)
# ================= 2. SERVER SETUP =================
app = FastAPI()
# Prefer GPU when available; the model and all request tensors are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Populated by load_model() on startup; endpoints must handle None.
model = None
@app.on_event("startup")
def load_model():
    """Download the checkpoint from the Hub and load it into the global `model`.

    On any failure the function logs the reason and returns, leaving
    `model` unset so the server still starts; /erase then reports the
    problem via HTTP 500.
    """
    global model

    # === CONFIGURATION ===
    # Replace with your Space ID! (e.g. "ayushpandey/magic-eraser")
    REPO_ID = "ayushpfullstack/MagicEraser"
    # Filename inside the checkpoints folder
    FILENAME = "checkpoints/magic-eraser-a100-epoch=19.ckpt"
    # =====================

    print(f"⬇️ Checking for model at {REPO_ID}/{FILENAME}...")
    try:
        # Download (or find in cache) the model file
        ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="space")

        # Check size to prevent "Pointer File" errors
        size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
        if size_mb < 10:
            print(f"❌ FATAL ERROR: The model file is too small ({size_mb:.4f} MB).")
            print("It is likely a Git LFS pointer. Please delete and upload the real file via the website.")
            return

        print(f"📦 Loading model from {ckpt_path} ({size_mb:.1f} MB)...")
        # Load with strict=False to ignore Discriminator weights
        model = MagicEraserGAN.load_from_checkpoint(
            ckpt_path, strict=False, map_location=DEVICE
        )
        model.to(DEVICE)
        model.eval()
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
def pad_to_multiple(tensor, multiple=32):
    """Pad a (N, C, H, W) tensor on the bottom/right so H and W are divisible
    by `multiple`, returning (padded_tensor, original_h, original_w).

    BUG FIX: the default was 16, but UNetGenerator halves the spatial dims
    5 times (e1..e4 plus the bottleneck), so inputs must be divisible by
    2**5 = 32.  With a multiple of 16 (e.g. a 48-px side), the decoder's
    skip-connection `torch.cat` raises a shape mismatch.  Default is now 32.

    Also guards the padding mode: 'reflect' requires the pad to be smaller
    than the corresponding input dimension, so very small images fall back
    to 'replicate'.
    """
    _, _, h, w = tensor.shape
    ph = (multiple - h % multiple) % multiple
    pw = (multiple - w % multiple) % multiple
    if ph > 0 or pw > 0:
        # 'reflect' errors when pad >= dim; replicate edge pixels instead.
        mode = 'reflect' if (ph < h and pw < w) else 'replicate'
        tensor = F.pad(tensor, (0, pw, 0, ph), mode=mode)
    return tensor, h, w
def tensor_to_base64(tensor):
    """Encode a (1, 3, H, W) float tensor in [0, 1] as a PNG data URI string."""
    # Clamp defensively: the generator's sigmoid already bounds its output,
    # but the composited result is clamped anyway before quantization.
    clamped = torch.clamp(tensor, 0, 1)
    pil_img = T.ToPILImage()(clamped.squeeze(0).cpu())
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"
# ================= 3. API ENDPOINTS =================
@app.get("/")
def home():
    """Health check: confirms the server is up and reports the device in use."""
    return {"status": "Online", "device": str(DEVICE)}
@app.post("/erase")
async def erase_object(image: UploadFile = File(...), mask: UploadFile = File(...)):
    """Inpaint the regions of `image` that are marked white in `mask`.

    Returns JSON {"result_image": <PNG data URI>}.  Responds 500 when the
    model failed to load or inference raises.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded. Check server logs.")
    try:
        # 1. Decode the uploads: RGB photo plus a grayscale stroke mask.
        # React Native sends white strokes on a black background.
        photo = Image.open(io.BytesIO(await image.read())).convert("RGB")
        strokes = Image.open(io.BytesIO(await mask.read())).convert("L")

        # 2. Bring the mask to the photo's resolution if they differ.
        if photo.size != strokes.size:
            strokes = strokes.resize(photo.size)

        # 3. To tensors on the inference device.
        to_tensor = T.ToTensor()
        img_tensor = to_tensor(photo).unsqueeze(0).to(DEVICE)
        stroke_tensor = to_tensor(strokes).unsqueeze(0).to(DEVICE)

        # App convention: 1 (white) where the user drew.
        # Model convention: 0 (black) for the HOLE, 1 for KEEP -> binarize
        # then invert.
        keep_mask = 1 - (stroke_tensor > 0.5).float()

        # 4. Pad to the network's required multiple, infer, then crop back.
        img_padded, h, w = pad_to_multiple(img_tensor)
        mask_padded, _, _ = pad_to_multiple(keep_mask)
        with torch.no_grad():
            hole_input = img_padded * mask_padded
            generated = model(hole_input, mask_padded)
            # Composite: keep original pixels, fill holes with the output.
            result = img_padded * mask_padded + generated * (1 - mask_padded)
            result = result[:, :, :h, :w]

        return JSONResponse(content={"result_image": tensor_to_base64(result)})
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 (the port Hugging Face Spaces
    # expects, consistent with repo_type="space" used above).
    uvicorn.run(app, host="0.0.0.0", port=7860)