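# NOTE: the next three lines appear to be page residue from the file viewer
# this source was copied from; they are kept as comments so the module parses.
# adamantix's picture
# Update app.py
# aa99e27 verified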
import io
import pickle
import re

import numpy as np
import open_clip
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Run inference on the GPU when torch reports one is available, else on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# step 1: load the models
TEXT_MODEL_NAME = "indobenchmark/indobert-large-p1"
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
# Move the text encoder to the target device and switch it to eval mode.
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(device).eval()

# Only the first and third values returned by open_clip are used: the model
# and the transform later applied to incoming images.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "EVA01-g-14-plus",
    pretrained="merged2b_s11b_b114k",
)
clip_model = clip_model.to(device).eval()


def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE: unpickling can execute arbitrary code; these files must be trusted.
    with open(path, "rb") as handle:
        return pickle.load(handle)


xgb_model = _load_pickle("xgb_full.pkl")
kmeans = _load_pickle("k-means.pkl")
# step 2: preprocessing
# Ordered cleanup rules used by preprocess_text: (compiled pattern, replacement).
_TEXT_CLEANUP_RULES = (
    (re.compile(r'http\S+|www\.\S+'), ''),  # remove URLs
    (re.compile(r'@\w+|#\w+'), ''),         # remove @mentions and #hashtags
    (re.compile(r'[^a-z\s]'), ' '),         # non a-z, non-whitespace -> space
    (re.compile(r'\s+'), ' '),              # collapse whitespace runs
)


def preprocess_text(text: str) -> str:
    """Normalize raw text before tokenization.

    The input is converted with ``str()`` and lower-cased, then URLs,
    ``@mentions`` and ``#hashtags`` are removed, every character that is not
    ``a``-``z`` or whitespace is replaced by a space, and whitespace is
    collapsed.

    Args:
        text: Raw input; non-string values are stringified first.

    Returns:
        The cleaned text with words separated by single spaces and no
        leading or trailing whitespace (possibly the empty string).
    """
    cleaned = str(text).lower()
    for pattern, replacement in _TEXT_CLEANUP_RULES:
        cleaned = pattern.sub(replacement, cleaned)
    # split()/join also trims the ends, so no separate strip() is needed.
    return " ".join(cleaned.split())
# step 3: feature encoding (text and image)
def encode_text(text: str):
    """Embed *text* with the module-level text model.

    The text is cleaned by ``preprocess_text``, tokenized with padding and
    truncation to 128 tokens, and run through ``text_model`` with gradient
    tracking disabled.

    Args:
        text: Raw input text.

    Returns:
        A NumPy array holding ``last_hidden_state[:, 0, :]`` (the [CLS]
        position), copied to the CPU.
    """
    # step 3.1 preprocess text
    cleaned = preprocess_text(text)

    # step 3.2 tokenize text
    encoded = tokenizer(
        cleaned,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Every tokenizer output tensor must live on the same device as the model.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        hidden_states = text_model(**model_inputs).last_hidden_state
        # take the [CLS] token (sequence position 0)
        cls_embedding = hidden_states[:, 0, :]

    return cls_embedding.cpu().numpy()
def encode_image(image_bytes):
    """Embed an encoded image with the module-level CLIP model.

    Args:
        image_bytes: Raw bytes of an image file in a format Pillow can open.

    Returns:
        A NumPy array with the output of ``clip_model.encode_image`` for a
        batch containing this single image, copied to the CPU.
    """
    # Decode the bytes and force three RGB channels.
    buffer = io.BytesIO(image_bytes)
    rgb_image = Image.open(buffer).convert("RGB")

    # Apply the CLIP transform, add a leading batch axis, move to the device.
    batch = clip_preprocess(rgb_image).unsqueeze(0)
    batch = batch.to(device)

    with torch.no_grad():
        embedding = clip_model.encode_image(batch)

    return embedding.cpu().numpy()
# Text shown as the API description in the generated docs.
_API_DESCRIPTION = (
    "Input: text + image + geospatial + time\n"
    "Model: IndoBERT + EVA-CLIP + XGBoost\n"
)

app = FastAPI(
    title="Multimodal Water Pollution Risk API",
    description=_API_DESCRIPTION,
    version="1.0.3",
)

# CORS is left fully open: any origin, HTTP method and request header.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/")
def root():
    """Return a static status payload confirming the API is running."""
    payload = dict(
        status="OK",
        message="Multimodal Water Pollution Risk API is running.",
        info="Use POST /predict with text, image, and features.",
    )
    return payload
def _score_report(text, longitude, latitude, hour, dayofweek, month, img_bytes):
    """Run the shared text + image + tabular pipeline for one report.

    Both prediction endpoints delegate here so the feature layout fed to the
    classifier is defined in exactly one place.

    Args:
        text: Raw report text.
        longitude: Longitude of the report.
        latitude: Latitude of the report.
        hour: Hour feature, passed through unchanged.
        dayofweek: Day-of-week feature, passed through unchanged.
        month: Month feature, passed through unchanged.
        img_bytes: Raw bytes of the uploaded image.

    Returns:
        A ``(proba, location_cluster)`` tuple: the first row returned by
        ``xgb_model.predict_proba`` and the K-means cluster id as an ``int``.

    Raises:
        HTTPException: With status 400 when the upload cannot be opened or
            decoded as an image.
    """
    # Decode the image first so a bad upload fails before the text model runs.
    # Pillow reports unreadable or truncated images as OSError subclasses.
    try:
        img_emb = encode_image(img_bytes)
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    text_emb = encode_text(text)

    # NOTE(review): K-means is queried as (latitude, longitude) while the
    # tabular features below are ordered (longitude, latitude, ...). Both are
    # kept exactly as before -- confirm they match the training pipeline.
    location_cluster = int(kmeans.predict([[latitude, longitude]])[0])

    add_feats = np.array(
        [[longitude, latitude, location_cluster, hour, dayofweek, month]],
        dtype=np.float32,
    )

    # Early fusion: image embedding, text embedding, then tabular features.
    fused = np.concatenate([img_emb, text_emb, add_feats], axis=1)
    proba = xgb_model.predict_proba(fused)[0]
    return proba, location_cluster


@app.post("/predict")
async def predict(
    text: str = Form(...),
    longitude: float = Form(...),
    latitude: float = Form(...),
    hour: int = Form(...),
    dayofweek: int = Form(...),
    month: int = Form(...),
    image: UploadFile = File(...),
):
    """Classify one report and return the label with class probabilities.

    Returns:
        A dict with ``prediction`` ("KRITIS" when class index 1 has the
        highest probability, otherwise "WASPADA"), ``cluster_used`` and
        ``probabilities`` (index 0 -> "WASPADA", index 1 -> "KRITIS").

    Raises:
        HTTPException: With status 400 when the uploaded image is unreadable.
    """
    img_bytes = await image.read()
    proba, location_cluster = _score_report(
        text, longitude, latitude, hour, dayofweek, month, img_bytes
    )

    pred_idx = int(np.argmax(proba))
    label = "KRITIS" if pred_idx == 1 else "WASPADA"

    return {
        "prediction": label,
        "cluster_used": location_cluster,
        "probabilities": {
            "WASPADA": float(proba[0]),
            "KRITIS": float(proba[1]),
        },
    }


@app.post("/predict_proba")
async def predict_proba(
    text: str = Form(...),
    longitude: float = Form(...),
    latitude: float = Form(...),
    hour: int = Form(...),
    dayofweek: int = Form(...),
    month: int = Form(...),
    image: UploadFile = File(...),
):
    """Return only the class probabilities and cluster id for one report.

    Returns:
        A dict with ``WASPADA`` (probability at index 0), ``KRITIS``
        (probability at index 1) and ``cluster_used``.

    Raises:
        HTTPException: With status 400 when the uploaded image is unreadable.
    """
    img_bytes = await image.read()
    proba, location_cluster = _score_report(
        text, longitude, latitude, hour, dayofweek, month, img_bytes
    )

    return {
        "WASPADA": float(proba[0]),
        "KRITIS": float(proba[1]),
        "cluster_used": location_cluster,
    }
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")