File size: 6,500 Bytes
3f33ff0
39c0356
3f33ff0
 
 
 
 
 
 
 
39c0356
3f33ff0
 
 
 
39c0356
3f33ff0
 
 
 
 
 
 
 
39c0356
3f33ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c0356
3f33ff0
 
 
 
39c0356
3f33ff0
39c0356
 
 
 
3f33ff0
 
 
39c0356
aa6ce0d
3f33ff0
aa6ce0d
3f33ff0
 
 
 
 
39c0356
3f33ff0
 
 
 
 
 
 
afacd96
39c0356
 
afacd96
39c0356
afacd96
39c0356
3f33ff0
39c0356
3f33ff0
39c0356
afacd96
 
39c0356
 
 
 
 
 
 
 
afacd96
39c0356
aa6ce0d
afacd96
aa6ce0d
 
 
3f33ff0
 
 
39c0356
3f33ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c0356
3f33ff0
39c0356
3f33ff0
 
39c0356
3f33ff0
 
 
 
39c0356
3f33ff0
 
39c0356
3f33ff0
 
 
 
39c0356
 
3f33ff0
39c0356
3f33ff0
 
 
39c0356
3f33ff0
 
 
 
39c0356
 
 
 
 
 
 
3f33ff0
 
 
 
 
 
39c0356
3f33ff0
 
 
39c0356
3f33ff0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import io
import os
import glob
import base64
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
import uvicorn
from PIL import Image, ImageOps
import torchvision.transforms as T

# ================= 1. MODEL ARCHITECTURE =================
class UNetGenerator(nn.Module):
    """U-Net style encoder-decoder generator for image inpainting.

    Input is a masked RGB image (3 channels) concatenated with a binary
    keep-mask (1 channel); output is an RGB image squashed into [0, 1] by a
    final Sigmoid. Four stride-2 encoder stages plus a stride-2 bottleneck
    give a total downsampling factor of 32, so input H/W must be divisible
    by 32 for the skip connections to line up.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 4 -> 64 -> 128 -> 256 -> 512 channels, halving H/W each stage.
        # NOTE: attribute names and creation order must stay stable — checkpoint
        # state-dict keys depend on them.
        self.e1 = self.down_block(4, 64, bn=False)
        self.e2 = self.down_block(64, 128)
        self.e3 = self.down_block(128, 256)
        self.e4 = self.down_block(256, 512)
        # Bottleneck: one more stride-2 conv (no batch norm).
        self.b = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1), nn.ReLU(inplace=True)
        )
        # Decoder: input channel counts double after each skip concatenation.
        self.up1 = self.up_block(512, 512)
        self.up2 = self.up_block(1024, 256)
        self.up3 = self.up_block(512, 128)
        self.up4 = self.up_block(256, 64)
        # Head: upsample back to input resolution, project to RGB, squash to [0,1].
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, 3, 1, 1),
            nn.Sigmoid()
        )

    def down_block(self, in_c, out_c, bn=True):
        """One stride-2 encoder stage: Conv -> (optional BatchNorm) -> LeakyReLU."""
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1)
        if bn:
            return nn.Sequential(conv, nn.BatchNorm2d(out_c), nn.LeakyReLU(0.2))
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def up_block(self, in_c, out_c):
        """One stride-2 decoder stage: ConvTranspose -> BatchNorm -> ReLU."""
        deconv = nn.ConvTranspose2d(in_c, out_c, 4, 2, 1)
        return nn.Sequential(deconv, nn.BatchNorm2d(out_c), nn.ReLU())

    def forward(self, masked_img, mask):
        """Inpaint. masked_img: (N,3,H,W); mask: (N,1,H,W); H and W % 32 == 0."""
        feats = torch.cat([masked_img, mask], dim=1)
        # Encoder path, keeping each stage's output for its skip connection.
        s1 = self.e1(feats)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)
        bottom = self.b(s4)
        # Decoder path: upsample, then concatenate the matching encoder features.
        u = torch.cat([self.up1(bottom), s4], dim=1)
        u = torch.cat([self.up2(u), s3], dim=1)
        u = torch.cat([self.up3(u), s2], dim=1)
        u = torch.cat([self.up4(u), s1], dim=1)
        return self.final(u)

class MagicEraserGAN(pl.LightningModule):
    """Lightning wrapper around the U-Net generator, used here for inference only.

    No discriminator is defined in this module; checkpoints saved from the full
    GAN are loaded elsewhere with strict=False so extra weights are ignored.
    """

    # FIX: Accept **kwargs to ignore 'lr' and other training params
    def __init__(self, **kwargs):
        # Accept arbitrary hyperparameters (e.g. 'lr') stored in the checkpoint;
        # save_hyperparameters() records them so load_from_checkpoint can
        # re-instantiate this module without knowing the training config.
        super().__init__()
        self.save_hyperparameters() 
        self.generator = UNetGenerator()

    def forward(self, masked_img, mask):
        # Delegate straight to the generator: returns the inpainted RGB image.
        return self.generator(masked_img, mask)

# ================= 2. SERVER SETUP =================
app = FastAPI()
# Prefer GPU when available; the model and all request tensors are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Populated by load_model() at startup; stays None if loading fails, which
# makes /erase return a 500 instead of crashing.
model = None

@app.on_event("startup")
def load_model():
    """Fetch the Lightning checkpoint from the Hugging Face Hub and load it.

    Runs once at server startup. On any failure the error is printed and
    `model` is left as None, so the server stays up and /erase reports the
    problem instead of the process dying.
    """
    global model
    
    # === CONFIGURATION ===
    # Replace with your Space ID! (e.g. "ayushpandey/magic-eraser")
    REPO_ID = "ayushpfullstack/MagicEraser" 
    # Filename inside the checkpoints folder
    FILENAME = "checkpoints/magic-eraser-a100-epoch=19.ckpt" 
    # =====================

    print(f"⬇️ Checking for model at {REPO_ID}/{FILENAME}...")
    try:
        # Download (or find in cache) the model file
        checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="space")
        
        # Check size to prevent "Pointer File" errors: a Git-LFS pointer is a
        # few hundred bytes of text, while a real checkpoint is tens of MB+.
        file_size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
        if file_size_mb < 10:
            print(f"❌ FATAL ERROR: The model file is too small ({file_size_mb:.4f} MB).")
            print("It is likely a Git LFS pointer. Please delete and upload the real file via the website.")
            return

        print(f"📦 Loading model from {checkpoint_path} ({file_size_mb:.1f} MB)...")
        
        # Load with strict=False to ignore Discriminator weights
        # (the checkpoint was saved from the full GAN; this process only
        # defines the generator).
        model = MagicEraserGAN.load_from_checkpoint(
            checkpoint_path, 
            strict=False, 
            map_location=DEVICE
        )
        model.to(DEVICE)
        model.eval()  # inference mode: BatchNorm uses running stats
        print("✅ Model loaded successfully!")

    except Exception as e:
        # Broad catch is deliberate: a failed load must not kill the server;
        # /erase will answer "Model is not loaded" instead.
        print(f"❌ Failed to load model: {e}")

def pad_to_multiple(tensor, multiple=32):
    """Pad an NCHW tensor on the right/bottom so H and W divide `multiple`.

    Returns (padded_tensor, original_h, original_w) so callers can crop the
    result back with `[:, :, :h, :w]`.

    The default is 32 because the U-Net downsamples by a factor of 32
    (4 stride-2 encoder stages + a stride-2 bottleneck). The previous default
    of 16 was a bug: a dimension that is a multiple of 16 but not of 32
    (e.g. 48 or 240) produces mismatched bottleneck/skip shapes and crashes
    the forward pass.

    Reflect padding requires the pad amount to be smaller than the input
    dimension; for very small inputs we fall back to replicate padding,
    which has no such limit.
    """
    _, _, h, w = tensor.shape
    ph = (multiple - h % multiple) % multiple
    pw = (multiple - w % multiple) % multiple
    if ph > 0 or pw > 0:
        # F.pad order is (left, right, top, bottom): pad only right/bottom so
        # the original content stays at [:h, :w].
        mode = 'reflect' if ph < h and pw < w else 'replicate'
        tensor = F.pad(tensor, (0, pw, 0, ph), mode=mode)
    return tensor, h, w

def tensor_to_base64(tensor):
    """Encode a (1, 3, H, W) float tensor in [0, 1] as a PNG data-URI string."""
    clamped = torch.clamp(tensor, 0, 1)
    pil_img = T.ToPILImage()(clamped.squeeze(0).cpu())
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"

# ================= 3. API ENDPOINTS =================
@app.get("/")
def home():
    """Health check: report that the server is up and which device it uses."""
    device_name = str(DEVICE)
    return {"status": "Online", "device": device_name}

@app.post("/erase")
async def erase_object(image: UploadFile = File(...), mask: UploadFile = File(...)):
    """Inpaint (erase) the masked region of an uploaded image.

    Multipart form fields:
      image: the photo; converted to RGB.
      mask:  white (255) where the user drew over the object to remove,
             black elsewhere; resized to the photo's size if they differ.

    Returns JSON {"result_image": "data:image/png;base64,..."}.
    Raises HTTPException(500) if the model never loaded or processing fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded. Check server logs.")

    try:
        # 1. Read Files
        img_bytes = await image.read()
        mask_bytes = await mask.read()
        
        orig_pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        # Ensure Mask is Grayscale (L). React Native sends White strokes on Black BG.
        mask_pil = Image.open(io.BytesIO(mask_bytes)).convert("L") 

        # 2. Resize (mask to the photo's size; interpolation artifacts are
        # removed by the binarization below)
        if orig_pil.size != mask_pil.size:
            mask_pil = mask_pil.resize(orig_pil.size)

        # 3. Process Mask
        to_tensor = T.ToTensor()  # scales pixel values into [0, 1]
        img_tensor = to_tensor(orig_pil).unsqueeze(0).to(DEVICE)
        mask_tensor = to_tensor(mask_pil).unsqueeze(0).to(DEVICE)

        # Logic: 
        # App sends: 1 (White) where user drew.
        # Model needs: 0 (Black) for HOLE, 1 for KEEP.
        mask_binary = (mask_tensor > 0.5).float() # Clean binary
        final_mask = 1 - mask_binary              # Invert

        # 4. Inference
        # NOTE(review): pad_to_multiple's default must match the network's
        # total downsampling factor (32 for this U-Net) — verify.
        img_padded, h, w = pad_to_multiple(img_tensor)
        mask_padded, _, _ = pad_to_multiple(final_mask)

        with torch.no_grad():
            masked_input = img_padded * mask_padded  # zero out the hole
            generated = model(masked_input, mask_padded)
            # Combine: Kept parts + Generated parts
            result = img_padded * mask_padded + generated * (1 - mask_padded)
            result = result[:, :, :h, :w]  # crop the padding back off

        return JSONResponse(content={"result_image": tensor_to_base64(result)})

    except Exception as e:
        # Surface the failure to the client as a 500 with the error text.
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Script entry point: serve on all interfaces, port 7860 (the HF Spaces default).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)