# TLSIM / streamlit_app.py
# Last change by kyrilloswahid: "Update streamlit_app.py" (commit 53cc87a, verified)
# streamlit_app.py
import streamlit as st
import numpy as np
import tensorflow as tf
import cv2
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.xception import preprocess_input as xcp_pre
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_pre
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import PlainTextResponse
import uvicorn
from io import BytesIO
# Set up Streamlit UI. set_page_config is the first Streamlit call in the
# script, as Streamlit requires.
st.set_page_config(page_title="Deepfake Image Verifier", layout="centered")
# The title emoji was stored as mojibake ("πŸ”"); restored to U+1F50D.
st.title("🔍 Deepfake Image Verifier")
st.markdown("Upload an image to classify it as **Real** or **Fake** using an ensemble of Xception and EfficientNet models.")
# Fetch both ensemble members from the Hugging Face Hub (cached by Streamlit).
@st.cache_resource
def load_models():
    """Download and deserialize the Xception and EfficientNet Keras models.

    Both ``.h5`` files are fetched first, then loaded, from the
    ``Zeyadd-Mostaffa/deepfake-image-detector_final`` repository. The
    ``st.cache_resource`` decorator keeps the loaded models cached so this
    work is not redone on every Streamlit rerun.

    Returns:
        tuple: ``(xcp_model, eff_model)`` in that order.
    """
    repo = "Zeyadd-Mostaffa/deepfake-image-detector_final"
    local_files = [
        hf_hub_download(repo_id=repo, filename=weights_name)
        for weights_name in ("xception_model.h5", "efficientnet_model.h5")
    ]
    xception, efficientnet = (load_model(path) for path in local_files)
    return xception, efficientnet


xcp_model, eff_model = load_models()
# Ensemble prediction
def run_model_prediction(image_np):
    """Label an image "Real" or "Fake" by averaging two model scores.

    Each backbone receives its own resized copy of the image. Xception gets
    299x299 with Xception preprocessing, and EfficientNet gets 224x224 with
    EfficientNet preprocessing. The first element of each model's output is
    taken as its score; presumably that is a single sigmoid probability
    (TODO confirm). The two scores are averaged.

    Args:
        image_np: image array; callers in this file pass the RGB result of
            ``cv2.imdecode`` + ``cv2.cvtColor(..., COLOR_BGR2RGB)``.

    Returns:
        str: "Real" when the averaged score is above 0.5, else "Fake".
    """
    def _score(model, size, preprocess):
        # astype() copies, so any in-place preprocessing never touches image_np.
        resized = cv2.resize(image_np, size).astype(np.float32)
        batch = np.expand_dims(preprocess(resized), axis=0)
        return model.predict(batch, verbose=0).flatten()[0]

    xcp_score = _score(xcp_model, (299, 299), xcp_pre)
    eff_score = _score(eff_model, (224, 224), eff_pre)
    mean_score = (xcp_score + eff_score) / 2
    if mean_score > 0.5:
        return "Real"
    return "Fake"
# Streamlit UI: upload -> decode -> show -> classify.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer position.
    raw_bytes = uploaded_file.getvalue()
    # cv2.imdecode raises on an empty buffer and returns None for data it
    # cannot decode; both cases are reported below instead of crashing.
    image = (
        cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if raw_bytes
        else None
    )
    if image is None:
        st.error("Could not read the uploaded file as an image.")
    else:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width — check against the pinned version.
        st.image(image_rgb, caption="Uploaded Image", use_column_width=True)
        with st.spinner("Analyzing..."):
            label = run_model_prediction(image_rgb)
        st.success(f"Prediction: **{label}**")
# FastAPI app exposing the same classifier over HTTP (e.g. for a Flask backend).
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True — per
# Starlette's CORSMiddleware behaviour this reflects any requesting Origin for
# credentialed requests (verify). /predict has no auth, so credentials look
# unnecessary; consider allow_credentials=False or an explicit origin list.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/predict")
async def predict_api(file: UploadFile = File(...)):
    """Classify an uploaded image and return the label as plain text.

    Returns:
        PlainTextResponse: "Real"/"Fake" with status 200. For an empty or
        undecodable upload, an "Error: ..." message with status 400. For any
        other failure, "Error: <exception message>" with status 500.
    """
    # NOTE(review): run_model_prediction is CPU-bound and runs directly inside
    # this async handler, so it blocks the event loop for the duration of
    # inference. Offloading to a thread would also make model calls
    # concurrent — confirm Keras predict() thread-safety before changing.
    try:
        contents = await file.read()
        if not contents:
            # cv2.imdecode asserts on an empty buffer; report it as a client error.
            return PlainTextResponse("Error: empty upload", status_code=400)
        image = cv2.imdecode(np.frombuffer(contents, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            # imdecode returns None for data that is not a decodable image.
            return PlainTextResponse("Error: could not decode image", status_code=400)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = run_model_prediction(image_rgb)
        return PlainTextResponse(label, status_code=200)
    except Exception as e:
        # Top-level HTTP boundary: surface the failure as a 500 to the caller.
        return PlainTextResponse(f"Error: {e}", status_code=500)
if __name__ == "__main__":
    # Serve the FastAPI app when the file is executed as a script.
    # NOTE(review): `streamlit run` presumably also executes this file with
    # __name__ == "__main__", which would start (and block on) uvicorn inside
    # the Streamlit script run — confirm this is intended.
    api_host, api_port = "0.0.0.0", 7860
    uvicorn.run(app, host=api_host, port=api_port)