# Hugging Face Spaces app (Space status: Running)
import hashlib
import os
import re
import time
import warnings

import easyocr
import gradio as gr
import pandas as pd
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models
from torchvision import transforms
from torchvision.transforms import functional as F
from transformers import AutoTokenizer
# --- Setup ---
# Silence all Python warnings so the Space logs only show this app's output.
warnings.filterwarnings("ignore")

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# IndoBERT tokenizer for the text branch (presumably the same tokenizer the
# fusion model was trained with -- keep them in sync).
tokenizer = AutoTokenizer.from_pretrained('indobenchmark/indobert-base-p1')
# Image transformation
class ResizePadToSquare:
    """Fit an image inside a square canvas, padding the leftover area with black.

    The image is converted to RGB and shrunk in place with ``thumbnail`` (which
    preserves aspect ratio and never enlarges), then centred on a
    ``target_size`` x ``target_size`` canvas using constant zero padding.
    """

    def __init__(self, target_size=300):
        self.target_size = target_size

    def __call__(self, img):
        side = self.target_size
        rgb = img.convert("RGB")
        rgb.thumbnail((side, side), Image.BILINEAR)

        width, height = rgb.size
        pad_x = side - width
        pad_y = side - height
        left = pad_x // 2
        top = pad_y // 2
        # Any odd leftover pixel goes to the right/bottom edge.
        border = (left, top, pad_x - left, pad_y - top)
        return F.pad(rgb, border, fill=0, padding_mode='constant')
# Preprocessing shared by both image paths: square resize/pad to 300 px,
# conversion to a float tensor, then per-channel normalization with the
# standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        ResizePadToSquare(300),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Folder where take_screenshot() writes captured PNGs; created up front.
SCREENSHOT_DIR = "screenshots"
os.makedirs(SCREENSHOT_DIR, exist_ok=True)

# Local EasyOCR reader for Indonesian text ('id'), used by easyocr_extract().
reader = easyocr.Reader(['id'])
print("OCR reader initialized.")
# --- Model ---
class LateFusionModel(nn.Module):
    """Late-fusion classifier that mixes an image model's and a text model's logits.

    The two branch logits are blended with a softmax over two learnable
    scalars, so the mixing weights are positive and sum to 1.

    NOTE(review): the fusion checkpoint appears to be a pickled module
    instance (it is loaded with ``weights_only=False``); unpickling resolves
    this class by name, so do not rename it.
    """

    def __init__(self, image_model, text_model):
        super().__init__()
        self.image_model = image_model
        self.text_model = text_model
        # Raw (pre-softmax) fusion weights; equal values give a 50/50 mix.
        self.image_weight = nn.Parameter(torch.tensor(0.5))
        self.text_weight = nn.Parameter(torch.tensor(0.5))

    def forward(self, images, input_ids, attention_mask):
        """Return ``(fused_logits, image_logits, text_logits, weights)``.

        Both backbones run under ``torch.no_grad()``; only the fusion weights
        remain in the autograd graph. Each branch output is squeezed on dim 1
        (assumes shape ``(batch, 1)`` -- confirm against the backbones).
        """
        with torch.no_grad():
            image_logits = self.image_model(images).squeeze(1)
            text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
            text_logits = text_output.logits.squeeze(1)

        raw_weights = torch.stack([self.image_weight, self.text_weight])
        weights = torch.softmax(raw_weights, dim=0)
        image_share = weights[0]
        text_share = weights[1]
        fused_logits = image_share * image_logits + text_share * text_logits
        return fused_logits, image_logits, text_logits, weights
# Load model
# Prefer a checkpoint bundled with the Space; otherwise fetch it from the Hub.
model_path = "models/best_fusion_model.pt"
if not os.path.exists(model_path):
    model_path = hf_hub_download(repo_id="azzandr/gambling-fusion-model", filename="best_fusion_model.pt")

# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, which
# can execute code embedded in the file. This is presumably acceptable only
# because the checkpoint comes from the project's own repository -- never
# point this at an untrusted checkpoint.
fusion_model = torch.load(model_path, map_location=device, weights_only=False)
fusion_model.to(device)
fusion_model.eval()
print("Fusion model loaded successfully!")
# Load Image-Only Model
# Fallback classifier used when no usable OCR text is available:
# EfficientNet-B3 with a single-logit head, restored from a saved state_dict
# (local copy first, Hugging Face Hub otherwise).
image_model_path = "models/best_image_model_Adam_lr0.0001_bs32_state_dict.pt"
image_model_downloaded = not os.path.exists(image_model_path)
if image_model_downloaded:
    print("Image-only model not found locally. Downloading from Hugging Face Hub...")
    image_model_path = hf_hub_download(repo_id="azzandr/gambling-image-model", filename="best_image_model_Adam_lr0.0001_bs32_state_dict.pt")

# weights=None: load_state_dict below is strict (the default), so every
# parameter and buffer is overwritten by the checkpoint. Fetching the
# ImageNet-pretrained weights first would only be a wasted download.
image_only_model = models.efficientnet_b3(weights=None)
num_features = image_only_model.classifier[1].in_features
image_only_model.classifier = nn.Linear(num_features, 1)
# The file is a plain state_dict (tensors only), so the restricted, safe
# unpickler is sufficient.
state_dict = torch.load(image_model_path, map_location=device, weights_only=True)
image_only_model.load_state_dict(state_dict)
image_only_model.to(device)
image_only_model.eval()
if image_model_downloaded:
    print("Image-only model downloaded and loaded successfully!")
else:
    print("Image-only model loaded from state_dict successfully!")
# --- Functions ---
def clean_text(text):
    """Normalize raw OCR text into the word sequence fed to the text model.

    Steps: strip URLs; keep only ASCII letters and apostrophes; collapse
    whitespace and lowercase; drop words of two letters or fewer (except
    "di", "ke", "ya"); drop all-vowel words, all-consonant words and words of
    20+ characters. NOTE(review): presumably this mirrors the training-time
    preprocessing, so change it only together with the model.

    Returns the cleaned text, or "" when fewer than five words remain. The
    empty string signals callers to fall back to the image-only model.
    """
    keep_short = {"di", "ke", "ya"}

    # Basic cleaning: remove URLs, turn newlines and every non-letter
    # (apostrophes excepted) into spaces.
    for pattern, replacement in (
        (r"http\S+", ""),
        (r"\n", " "),
        (r"[^a-zA-Z']", " "),
    ):
        text = re.sub(pattern, replacement, text)
    text = re.sub(r"\s{2,}", " ", text).strip().lower()

    # Filtering: keep words longer than two characters, or whitelisted ones.
    text = " ".join(word for word in text.split() if len(word) > 2 or word in keep_short)

    # Remove unwanted patterns: all-vowel words, all-consonant words and very
    # long tokens (>= 20 characters), then tidy whitespace.
    for pattern in (r'\b[aeiou]+\b', r'\b[^aeiou\s]+\b', r'\b\w{20,}\b'):
        text = re.sub(pattern, '', text)
    text = re.sub(r'\s+', ' ', text).strip()

    word_count = len(text.split())
    if word_count < 5:
        print(f"Cleaned text too short ({word_count} words). Ignoring text.")
        return ""
    return text
# APIFlash access key, read from the environment so it never lives in source.
SCREENSHOT_API_KEY = os.getenv("SCREENSHOT_API_KEY")

# Phrases in an error response that indicate a Cloudflare interstitial;
# take_screenshot() retries with gentler settings when one is present.
CLOUDFLARE_CHECK_KEYWORDS = [
    "Checking your browser",
    "Just a moment",
    "Cloudflare",
]
def ensure_http(url):
    """Return `url` with surrounding whitespace removed and an explicit scheme.

    URLs that already start with http:// or https:// (in any letter case)
    are returned as-is; anything else gets "http://" prepended. The check is
    case-insensitive so "HTTPS://example.com" is not turned into
    "http://HTTPS://example.com".
    """
    url = url.strip()
    if url.lower().startswith(('http://', 'https://')):
        return url
    return 'http://' + url
def sanitize_filename(url, max_length=200):
    """Turn `url` into a filesystem-safe file name stem.

    Every character other than word characters, '-', '_', '.' and space is
    replaced with '_'. Most filesystems cap a name at 255 bytes, and open()
    would fail with OSError on longer names. So when the UTF-8 encoding
    exceeds `max_length` bytes, the name is truncated and suffixed with a
    short hash of the full URL to keep distinct URLs distinct. Shorter names
    are returned unchanged.
    """
    safe = re.sub(r'[^\w\-_\. ]', '_', url)
    encoded = safe.encode("utf-8")
    if len(encoded) <= max_length:
        return safe
    digest = hashlib.sha1(url.encode("utf-8")).hexdigest()[:16]
    # Cut on a byte budget; 'ignore' drops a multi-byte char split in half.
    prefix = encoded[: max_length - len(digest) - 1].decode("utf-8", "ignore")
    return f"{prefix}_{digest}"
def _save_image_response(response, filepath):
    """Write `response`'s body to `filepath` if it has an image content type; return True if written."""
    if not response.headers.get('content-type', '').startswith('image'):
        return False
    with open(filepath, 'wb') as f:
        f.write(response.content)
    return True


def take_screenshot(url, timeout=60):
    """Capture a PNG screenshot of `url` via the APIFlash API.

    Args:
        url: Website address; "http://" is prepended when no scheme is given.
        timeout: Seconds to wait for each APIFlash request before giving up,
            so a stalled request cannot hang the app indefinitely.

    Returns:
        Path of the saved PNG inside SCREENSHOT_DIR, or None if the API key
        is missing or the capture failed for any reason. If the first attempt
        fails with a Cloudflare-looking error, one retry is made with
        wait_until="load" and a longer delay.
    """
    url = ensure_http(url)
    filepath = os.path.join(SCREENSHOT_DIR, sanitize_filename(url) + '.png')

    if not SCREENSHOT_API_KEY:
        print("SCREENSHOT_API_KEY not found in environment.")
        return None

    api_url = "https://api.apiflash.com/v1/urltoimage"
    # Base parameters - only using supported parameters
    params = {
        "access_key": SCREENSHOT_API_KEY,
        "url": url,
        "format": "png",
        "wait_until": "network_idle",
        "delay": 2,
        "fail_on_status": "400,401,402,403,404,500,502,503,504",
        "fresh": "true",  # Don't use cached version
        "response_type": "image",
        "wait_for": "body"  # Wait for body to be present
    }
    try:
        print(f"Taking screenshot of: {url}")
        response = requests.get(api_url, params=params, timeout=timeout)
        if response.status_code == 200:
            if _save_image_response(response, filepath):
                print(f"Screenshot taken successfully for URL: {url}")
                return filepath
            print("API returned non-image content")
            return None

        error_msg = response.text
        print(f"Screenshot failed: {error_msg}")
        if not any(keyword.lower() in error_msg.lower() for keyword in CLOUDFLARE_CHECK_KEYWORDS):
            return None

        print("Cloudflare challenge detected, retrying with different parameters...")
        params.update({"wait_until": "load", "delay": 5})
        response = requests.get(api_url, params=params, timeout=timeout)
        if response.status_code == 200 and _save_image_response(response, filepath):
            print("Screenshot taken successfully after Cloudflare retry")
            return filepath
        print(f"Cloudflare retry failed (HTTP {response.status_code})")
        return None
    except Exception as e:
        # Best-effort boundary: callers expect None on failure. requests
        # exception messages can embed the full request URL, which includes
        # the access_key query parameter, so redact the key before logging.
        message = str(e).replace(SCREENSHOT_API_KEY, "***")
        print(f"Error taking screenshot: {message}")
        return None
def resize_if_needed(image_path, max_mb=1, target_height=720):
    """Shrink an oversized image file in place.

    Only acts when the file is larger than `max_mb` megabytes AND the image
    is taller than `target_height` pixels. The image is then resized to
    `target_height` (width scaled proportionally, LANCZOS) and saved back
    over `image_path`. Errors are logged rather than raised.
    """
    size_mb = os.path.getsize(image_path) / (1024 * 1024)
    if size_mb <= max_mb:
        return
    try:
        with Image.open(image_path) as original:
            width, height = original.size
            if height <= target_height:
                return
            scale = target_height / float(height)
            new_width = int(float(width) * scale)
            resized = original.resize((new_width, target_height), Image.Resampling.LANCZOS)
            resized.save(image_path, optimize=True, quality=85)
            print(f"Image resized to {new_width}x{target_height}")
    except Exception as e:
        print(f"Resize error: {e}")
def easyocr_extract(image_path):
    """Run the local EasyOCR reader on `image_path` and return its text.

    Recognized fragments are joined with single spaces and the result is
    stripped. Any OCR failure is logged and reported as an empty string.
    """
    try:
        fragments = reader.readtext(image_path, detail=0)
        joined = " ".join(fragments)
        print(f"OCR text extracted from EasyOCR: {len(joined)} characters")
    except Exception as e:
        print(f"EasyOCR error: {e}")
        return ""
    return joined.strip()
def _ocr_space_extract(image_path, api_key, timeout=60):
    """Send `image_path` to the OCR.Space API and return the parsed text.

    Returns None (after logging why) when the request fails, the response is
    not valid JSON, the API reports a processing error, or no parsed result
    comes back. None means "try another OCR engine"; an empty string means
    the API succeeded but found no text.
    """
    payload = {
        'isOverlayRequired': False,
        'apikey': api_key,
        'language': 'eng'
    }
    try:
        with open(image_path, 'rb') as f:
            r = requests.post('https://api.ocr.space/parse/image',
                              files={'filename': f},
                              data=payload,
                              timeout=timeout)
        result = r.json()
    except (requests.RequestException, ValueError) as e:
        print(f"OCR.Space request failed: {e}")
        return None

    if not isinstance(result, dict):
        print("OCR.Space returned an unexpected response")
        return None
    if result.get('IsErroredOnProcessing', False):
        print(f"OCR.Space API Error: {result.get('ErrorMessage')}")
        return None
    parsed_results = result.get('ParsedResults') or []
    if not parsed_results:
        print("OCR.Space returned no parsed results")
        return None
    text = parsed_results[0].get('ParsedText') or ''
    print(f"OCR text extracted from OCR.Space: {len(text)} characters")
    return text.strip()


def extract_text_from_image(image_path):
    """Extract text from a screenshot, preferring OCR.Space and falling back to EasyOCR.

    The image is first shrunk in place by resize_if_needed(). Files under
    1 MB go to OCR.Space when OCR_SPACE_API_KEY is set. Larger files, a
    missing key, or any OCR.Space failure use the local EasyOCR reader.

    Returns the extracted text, or "" on unexpected errors. An empty result
    makes the caller fall back to the image-only model.
    """
    try:
        resize_if_needed(image_path, max_mb=1, target_height=720)
        file_size = os.path.getsize(image_path) / (1024 * 1024)  # size in MB
        if file_size >= 1:
            print(f"Using EasyOCR for image ({file_size:.2f} MB)")
            return easyocr_extract(image_path)

        print(f"Using OCR.Space API for image ({file_size:.2f} MB)")
        api_key = os.getenv("OCR_SPACE_API_KEY")
        if not api_key:
            print("OCR_SPACE_API_KEY not found in environment. Using EasyOCR as fallback.")
            return easyocr_extract(image_path)

        text = _ocr_space_extract(image_path, api_key)
        if text is None:
            return easyocr_extract(image_path)
        return text
    except Exception as e:
        # Best-effort boundary: an empty string routes to the image-only model.
        print(f"OCR error: {e}")
        return ""
def prepare_data_for_model(image_path, text):
    """Build the fusion model's inputs from a screenshot and its raw OCR text.

    Returns ``(image_tensor, input_ids, attention_mask)``, all moved to
    `device`. The image goes through `transform` and gets a leading batch
    dimension of 1. The text is cleaned with clean_text() and tokenized to
    exactly 128 tokens (padded or truncated).
    """
    # The context manager closes the file handle deterministically once the
    # pixels have been read by the transform.
    with Image.open(image_path) as image:
        image_tensor = transform(image).unsqueeze(0).to(device)

    encoding = tokenizer.encode_plus(
        clean_text(text),
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    return image_tensor, input_ids, attention_mask
# Probability above which a page is labelled "Gambling" (both model paths).
GAMBLING_THRESHOLD = 0.6


def _label_from_probability(prob, threshold=GAMBLING_THRESHOLD):
    """Map a gambling probability to ``(label, confidence in that label)``."""
    if prob > threshold:
        return "Gambling", prob
    return "Non-Gambling", 1 - prob


def _predict_image_only(screenshot_path):
    """Classify a screenshot with the image-only model; returns ``(label, confidence)``."""
    with Image.open(screenshot_path) as image:
        image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_logits = image_only_model(image_tensor).squeeze(1)
    prob = torch.sigmoid(image_logits)[0].item()
    return _label_from_probability(prob)


def predict_single_url(url):
    """Screenshot `url`, OCR it and classify the page as gambling or not.

    The fusion (image + text) model is used when OCR yields text that
    survives clean_text(). Otherwise the image-only model is used, including
    when the OCR text is non-empty but too short after cleaning.

    Returns a 5-tuple ``(label, confidence_text, screenshot_path, raw_text,
    cleaned_text)`` matching the Gradio outputs. On screenshot failure,
    label holds an error message and screenshot_path is None.
    """
    print(f"Processing URL: {url}")
    screenshot_path = take_screenshot(url)
    if not screenshot_path:
        return f"❌ Error: Unable to capture screenshot for {url}. This may be due to:\n• Too many redirects\n• Website blocking automated access\n• Network connectivity issues\n• Invalid URL", "Screenshot capture failed", None, "", ""

    raw_text = extract_text_from_image(screenshot_path)
    # Clean once. An empty result means the text branch has nothing usable.
    clean_text_data = clean_text(raw_text) if raw_text.strip() else ""

    if not clean_text_data:
        if raw_text.strip():
            print(f"OCR text for {url} too short after cleaning. Using Image-Only Model.")
        else:
            print(f"No OCR text found for {url}. Using Image-Only Model.")
        label, confidence = _predict_image_only(screenshot_path)
        print(f"[Image-Only] URL: {url}")
        print(f"Prediction: {label} | Confidence: {confidence:.2f}\n")
        return label, f"Confidence: {confidence:.2f} (Image-Only Model)", screenshot_path, raw_text, ""

    image_tensor, input_ids, attention_mask = prepare_data_for_model(screenshot_path, raw_text)
    with torch.no_grad():
        fused_logits, image_logits, text_logits, _weights = fusion_model(image_tensor, input_ids, attention_mask)
    fused_prob = torch.sigmoid(fused_logits)[0].item()
    image_prob = torch.sigmoid(image_logits)[0].item()
    text_prob = torch.sigmoid(text_logits)[0].item()
    label, confidence = _label_from_probability(fused_prob)

    print(f"[Fusion Model] URL: {url}")
    print(f"Image Model Prediction Probability: {image_prob:.2f}")
    print(f"Text Model Prediction Probability: {text_prob:.2f}")
    print(f"Fusion Final Prediction: {label} | Confidence: {confidence:.2f}\n")
    return label, f"Confidence: {confidence:.2f} (Fusion Model)", screenshot_path, raw_text, clean_text_data
# Column order of the batch results table.
_BATCH_COLUMNS = ["url", "label", "confidence", "screenshot_path", "raw_text", "cleaned_text"]


def _read_uploaded_text(file_obj):
    """Return the text of an upload: a path string, a tempfile wrapper with ``.name``, or a readable object.

    A leading UTF-8 byte-order mark is dropped so it does not end up in the
    first URL.
    """
    if file_obj is None:
        return ""
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = getattr(file_obj, "name", None)
    if isinstance(path, (str, os.PathLike)) and os.path.isfile(path):
        # Re-open by path: Gradio may hand over an already-closed tempfile.
        with open(path, "rb") as f:
            data = f.read()
    else:
        data = file_obj.read()
    if isinstance(data, bytes):
        return data.decode("utf-8-sig")
    return data[1:] if data.startswith("\ufeff") else data


def predict_batch_urls(file_obj):
    """Run predict_single_url() on every non-blank line of an uploaded text file.

    `file_obj` may be a filepath string (newer Gradio), a tempfile wrapper
    (older Gradio), any readable object, or None. A URL whose prediction
    raises is recorded with an "Error: ..." label instead of aborting the
    whole batch.

    Returns a DataFrame with one row per URL and the columns in
    `_BATCH_COLUMNS`. It is empty, but keeps those columns, when there are
    no URLs.
    """
    content = _read_uploaded_text(file_obj)
    urls = [line.strip() for line in content.splitlines() if line.strip()]

    results = []
    for url in urls:
        try:
            label, confidence, screenshot_path, raw_text, cleaned_text = predict_single_url(url)
        except Exception as e:
            print(f"Prediction failed for {url}: {e}")
            label, confidence, screenshot_path, raw_text, cleaned_text = f"Error: {e}", "", None, "", ""
        results.append({
            "url": url,
            "label": label,
            "confidence": confidence,
            "screenshot_path": screenshot_path,
            "raw_text": raw_text,
            "cleaned_text": cleaned_text
        })

    df = pd.DataFrame(results, columns=_BATCH_COLUMNS)
    print(f"Batch prediction completed for {len(urls)} URLs.")
    return df
# --- Gradio App ---
with gr.Blocks() as app:
    gr.Markdown("# 🕵️ Gambling Website Detection (URL Based)")

    # Single-URL tab: prediction, screenshot and the OCR text behind it.
    with gr.Tab("Single URL"):
        url_input = gr.Textbox(label="Enter Website URL")
        predict_button = gr.Button("Predict")

        with gr.Row():
            with gr.Column():
                label_output = gr.Label()
                confidence_output = gr.Textbox(label="Confidence", interactive=False)
            with gr.Column():
                screenshot_output = gr.Image(label="Screenshot", type="filepath")

        with gr.Row():
            with gr.Column():
                raw_text_output = gr.Textbox(label="Raw OCR Text", lines=5)
            with gr.Column():
                cleaned_text_output = gr.Textbox(label="Cleaned Text", lines=5)

        # Order must match the 5-tuple returned by predict_single_url().
        single_url_outputs = [
            label_output,
            confidence_output,
            screenshot_output,
            raw_text_output,
            cleaned_text_output,
        ]
        predict_button.click(fn=predict_single_url, inputs=url_input, outputs=single_url_outputs)

    # Batch tab: a .txt upload with one URL per line; results shown as a table.
    with gr.Tab("Batch URLs"):
        file_input = gr.File(label="Upload .txt file with URLs (one per line)")
        batch_predict_button = gr.Button("Batch Predict")
        batch_output = gr.DataFrame()
        batch_predict_button.click(fn=predict_batch_urls, inputs=file_input, outputs=batch_output)

app.launch()