# app.py
import os
import re
import io
import torch
from typing import List, Optional
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from PIL import Image, ImageEnhance, ImageOps
import torchvision.transforms as T
import gradio as gr
import uvicorn
from fastapi import FastAPI, Request
from starlette.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# ========== LOAD MODELS (once) ==========
print("Loading VinTern model...")
vintern_model = AutoModel.from_pretrained(
    "5CD-AI/Vintern-1B-v3_5",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
).eval()
vintern_tokenizer = AutoTokenizer.from_pretrained(
    "5CD-AI/Vintern-1B-v3_5",
    trust_remote_code=True,
)
print("VinTern loaded!")
print("Loading PhoBERT model...")
phobert_path = "DuyKien016/phobert-scam-detector"
phobert_tokenizer = AutoTokenizer.from_pretrained(phobert_path, use_fast=False)
phobert_model = AutoModelForSequenceClassification.from_pretrained(phobert_path).eval()
phobert_model = phobert_model.to("cuda" if torch.cuda.is_available() else "cpu")
print("PhoBERT loaded!")
# ========== UTILS ==========
def process_image_pil(pil_img: Image.Image):
    """Enhance and resize a chat screenshot, returning normalized pixel values."""
    img = pil_img.convert("RGB")
    img = ImageEnhance.Contrast(img).enhance(1.8)
    img = ImageEnhance.Sharpness(img).enhance(1.3)
    max_size = (448, 448)
    img.thumbnail(max_size, Image.Resampling.LANCZOS)
    img = ImageOps.pad(img, max_size, color=(245, 245, 245))
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Cast to the model's dtype as well as its device: with torch_dtype="auto"
    # the model may load in bfloat16, and float32 inputs can raise a
    # dtype-mismatch runtime error in the vision tower.
    pixel_values = transform(img).unsqueeze(0).to(
        device=vintern_model.device, dtype=vintern_model.dtype
    )
    return pixel_values
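# Illustrative call (file name hypothetical):
#   process_image_pil(Image.open("chat.png")) -> tensor of shape (1, 3, 448, 448)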
def extract_messages(pixel_values) -> List[str]:
    # The prompt is kept in Vietnamese on purpose: VinTern is a Vietnamese VLM,
    # and the regex below keys on its "Tin nhắn N: ..." output format. It asks
    # the model to read each chat bubble top-to-bottom, emit one
    # "Tin nhắn N: [content]" line per bubble, and drop timestamps, sender
    # names, and emoji.
    prompt = """<image>
Đọc từng tin nhắn trong ảnh và xuất ra định dạng:
Tin nhắn 1: [nội dung]
Tin nhắn 2: [nội dung]
Tin nhắn 3: [nội dung]
Quy tắc:
- Mỗi ô chat = 1 tin nhắn
- Chỉ lấy nội dung văn bản
- Bỏ thời gian, tên người, emoji
- Đọc từ trên xuống dưới
Bắt đầu:"""
    response, *_ = vintern_model.chat(
        tokenizer=vintern_tokenizer,
        pixel_values=pixel_values,
        question=prompt,
        # early_stopping dropped: it has no effect with num_beams=1 and only
        # triggers a generation warning.
        generation_config=dict(max_new_tokens=1024, do_sample=False, num_beams=1),
        history=None,
        return_history=True,
    )
    messages = re.findall(r"Tin nhắn \d+: (.+?)(?=\nTin nhắn|\Z)", response, re.S)

    def quick_clean(msg):
        msg = re.sub(r"\s+", " ", msg.strip())      # collapse runs of whitespace
        msg = re.sub(r"^\d+[\.\)\-\s]+", "", msg)   # strip any leading numbering
        return msg.strip()

    return [quick_clean(msg) for msg in messages if msg.strip()]
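# Illustrative example (hypothetical model output):
#   response = "Tin nhắn 1: Chào anh\nTin nhắn 2: Anh trúng thưởng 50 triệu"
#   -> ["Chào anh", "Anh trúng thưởng 50 triệu"]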
def predict_phobert(texts: List[str]):
    results = []
    for text in texts:
        encoded = phobert_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        encoded = {k: v.to(phobert_model.device) for k, v in encoded.items()}
        with torch.no_grad():
            logits = phobert_model(**encoded).logits
        probs = torch.softmax(logits, dim=1).squeeze()
        label = torch.argmax(probs).item()
        results.append({
            "text": text,
            # Labels stay in Vietnamese: "LỪA ĐẢO" = scam, "BÌNH THƯỜNG" = normal.
            "prediction": "LỪA ĐẢO" if label == 1 else "BÌNH THƯỜNG",
            "confidence": f"{probs[label] * 100:.2f}%",
        })
    return results
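# Each result entry has this shape (values illustrative):
#   {"text": "...", "prediction": "LỪA ĐẢO", "confidence": "97.31%"}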
# ========== CORE HANDLER ==========
def handle_inference(text: Optional[str], pil_image: Optional[Image.Image]):
    if not text and pil_image is None:
        return {"error": "No valid input provided"}, 400
    # Image input takes priority: OCR the screenshot into messages, then classify.
    if pil_image is not None:
        pixel_values = process_image_pil(pil_image)
        messages = extract_messages(pixel_values)
        phobert_results = predict_phobert(messages)
        return {"messages": phobert_results}, 200
    # Text only: accept a single string or a list of strings.
    texts = [text] if isinstance(text, str) else text
    if isinstance(texts, list):
        phobert_results = predict_phobert(texts)
        return {"messages": phobert_results}, 200
    return {"error": "Invalid input format"}, 400
# ========== GRADIO APP (UI + API) ==========
demo = gr.Blocks()
with demo:
    gr.Markdown("## dunkingscam backend (HF Space) - quick test")
    with gr.Row():
        txt = gr.Textbox(label="Text (optional)")
        img = gr.Image(label="Chat screenshot (optional)", type="pil")
    out = gr.JSON(label="Result")

    def ui_process(text, image):
        data, _ = handle_inference(text, image)
        return data

    btn = gr.Button("Process")
    btn.click(fn=ui_process, inputs=[txt, img], outputs=out)
# Build a FastAPI app explicitly and mount the Gradio UI onto it (see
# gr.mount_gradio_app at the bottom). This is the documented way to combine
# custom routes and CORS middleware with a Blocks app; `demo.server_app` and
# `demo.add_server_route` are not public Gradio APIs and crash the Space.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # must stay open for the Replit frontend
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Custom REST endpoint /process (accepts multipart FormData or JSON)
@app.post("/process")
async def process_endpoint(request: Request):
    try:
        ct = request.headers.get("content-type", "")
        if "multipart/form-data" in ct:
            form = await request.form()
            text = form.get("text")
            file = form.get("image")  # UploadFile or None
            pil_image = None
            if file is not None:
                # read raw bytes -> PIL image
                content = await file.read()
                pil_image = Image.open(io.BytesIO(content))
            data, status = handle_inference(text, pil_image)
        elif "application/json" in ct:
            payload = await request.json()
            text = payload.get("text")
            data, status = handle_inference(text, None)
        else:
            data, status = {"error": "Unsupported Content-Type"}, 400
        return JSONResponse(
            content=data,
            status_code=status,
            headers={"Access-Control-Allow-Origin": "*"},
        )
    except Exception as e:
        return JSONResponse(
            content={"error": f"Server error: {str(e)}"},
            status_code=500,
            headers={"Access-Control-Allow-Origin": "*"},
        )
# Mount the Gradio UI at the root path; routes registered above keep priority.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # 7860 is the default port Hugging Face Spaces expects.
    uvicorn.run(app, host="0.0.0.0", port=7860)
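# Example requests (illustrative; replace the URL with your Space's endpoint):
#   curl -X POST https://<user>-<space>.hf.space/process \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Bạn đã trúng thưởng, bấm vào link để nhận..."}'
#
#   curl -X POST https://<user>-<space>.hf.space/process \
#        -F "image=@screenshot.png"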