import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
import html  # used to escape user text before rendering it as HTML
import logging # Optional: Add logging for better debugging
# Set up logging (optional but helpful)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Streamlit page config
st.set_page_config(
    page_title="AI Article Detection by DEJAN",
    page_icon="🧠",
    layout="wide",
)
# Logo as provided
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
)
# Font styling
st.markdown("""
""", unsafe_allow_html=True)
@st.cache_resource # Cache the model and tokenizer to avoid reloading on every interaction
def load_model_and_tokenizer(model_name):
"""Loads the model and tokenizer."""
logger.info(f"Loading tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (device.type == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
logger.info(f"Using device: {device} with dtype: {dtype}")
logger.info(f"Loading model: {model_name}")
model = AutoModelForSequenceClassification.from_pretrained(model_name, torch_dtype=dtype)
model.to(device)
model.eval()
logger.info("Model loaded successfully.")
return tokenizer, model, device
MODEL_NAME = "dejanseo/ai-detection-small"
try:
    tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
except Exception as e:
    st.error(f"Error loading model: {e}")
    logger.error(f"Failed to load model or tokenizer: {e}", exc_info=True)
    st.stop()
# Labels
LABELS = ["AI Content", "Human Content"]
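# Index 0 = AI, index 1 = Human; assumed to match the checkpoint's label order
# (the per-class coloring below relies on this convention).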
# Regex-based sentence splitter
def sent_tokenize(text):
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s for s in sentences if s]
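# Example: sent_tokenize("One. Two! Three?") -> ["One.", "Two!", "Three?"]
# Caveat: abbreviations ("e.g. this") also trigger a split.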
# UI
st.title("AI Article Detection")
text = st.text_area("Enter text to classify", height=200, placeholder="Paste your text here...")
if st.button("Classify", type="primary"):
    if not text or not text.strip():
        st.warning("Please enter some text.")
    else:
        # Split the text before entering the try block: st.stop() raises a
        # StopException (a subclass of Exception), so calling it inside the
        # try below would be swallowed by the generic handler.
        sentences = sent_tokenize(text)
        if not sentences:
            st.warning("No sentences detected.")
            st.stop()
        with st.spinner("Analyzing... Please wait."):
            try:
                # Tokenize all sentences as one padded, truncated batch
                inputs = tokenizer(
                    sentences,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=model.config.max_position_embeddings,  # cap at the model's positional limit
                ).to(device)
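                # inputs now holds input_ids and attention_mask (plus
                # token_type_ids for some model families), each shaped
                # [n_sentences, max_tokens_in_batch]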
                # Inference
                with torch.no_grad():
                    outputs = model(**inputs)
                    logits = outputs.logits
                probs = F.softmax(logits, dim=-1).cpu()  # [n_sentences, 2]
                preds = torch.argmax(probs, dim=-1)  # probs is already on CPU
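                # Example: a probs row of [0.93, 0.07] assigns 93% to class 0
                # ("AI Content"); argmax yields 0 and the span below is tinted
                # red at ~0.93 opacity.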
                # Build inline styled text
                styled_chunks = []
                for i, sent in enumerate(sentences):
                    pred = preds[i].item()
                    # red for AI (class 0), green for Human (class 1)
                    r, g = (255, 0) if pred == 0 else (0, 255)
                    confidence = probs[i, pred].item()  # 0.0–1.0
                    alpha = confidence  # opacity tracks confidence
                    # rgba-tinted span; tooltip carries the label and confidence
                    span = (
                        f'<span style="background-color: rgba({r}, {g}, 0, {alpha:.2f});" '
                        f'title="{LABELS[pred]}: {confidence:.1%}">'
                        f"{html.escape(sent)}"
                        f"</span>"
                    )
                    styled_chunks.append(span)
                full_text_html = " ".join(styled_chunks)  # space between sentences
                st.markdown(full_text_html, unsafe_allow_html=True)
                # Overall AI likelihood (class 0)
                avg_probs = torch.mean(probs, dim=0)
                ai_likelihood = avg_probs[0].item() * 100
                st.subheader(f"🤖 AI Likelihood: {ai_likelihood:.1f}%")
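                # Note: this is an unweighted mean over sentence probabilities,
                # so short and long sentences count equally toward the score.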
            except Exception as e:
                st.error(f"An error occurred during analysis: {e}")
                logger.error("Analysis failed", exc_info=True)