# QDF-large / app.py
# Author: dejanseo
# Last commit: "Update app.py" (856d741, verified)
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import os
# Theme configuration - MUST BE FIRST STREAMLIT COMMAND
# Streamlit rejects set_page_config once any other st.* element has been
# emitted, so this stays at the very top of the script.
_PAGE_CONFIG = dict(
    page_title="QDF Classifier",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)
st.set_page_config(**_PAGE_CONFIG)
MODEL_ID = "dejanseo/QDF-large"
# None when the HF_TOKEN environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")


@st.cache_resource(show_spinner="Loading QDF model...")
def _load_model_and_tokenizer(model_id: str, token):
    """Load the sequence classifier and its tokenizer once per server process.

    Streamlit re-executes this entire script on every user interaction, so
    loading at plain module level would re-read the model weights on every
    button click. ``st.cache_resource`` keeps a single shared instance that
    survives reruns.

    Args:
        model_id: Hugging Face Hub repository id to load from.
        token: Hub access token passed through to ``from_pretrained``
            (``None`` when no token is configured).

    Returns:
        ``(model, tokenizer)``, with the model already switched to eval mode.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        token=token,
        low_cpu_mem_usage=True,
    ).eval()
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(MODEL_ID, HF_TOKEN)
def classify(prompt: str, *, max_length: int = 512):
    """Classify a single query with the module-level QDF model.

    Args:
        prompt: The query text to classify.
        max_length: Maximum number of tokens fed to the model; longer inputs
            are truncated. Defaults to 512 (the previously hard-coded value).

    Returns:
        A ``(pred, confidence)`` tuple. ``pred`` is the arg-max class index
        (the UI renders index ``1`` as "QDF: Yes") and ``confidence`` is the
        softmax probability of that class as a Python float.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Select the single batch row explicitly. A bare .squeeze() would also drop
    # a size-1 label dimension, turning probs into a 0-d tensor that cannot be
    # indexed by `pred` below.
    probs = torch.softmax(logits[0], dim=-1).cpu()
    pred = torch.argmax(probs).item()
    confidence = probs[pred].item()
    return pred, confidence
# Font and style overrides: pull Montserrat from Google Fonts, force it onto
# common text elements (including Streamlit's generated "css-*" classes), and
# hide the top header bar.
_STYLE_OVERRIDES = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');
html, body, div, span, input, label, textarea, button, h1, h2, p, table {
font-family: 'Montserrat', sans-serif !important;
}
[class^="css-"], [class*=" css-"] {
font-family: 'Montserrat', sans-serif !important;
}
header {visibility: hidden;}
</style>
"""
st.markdown(_STYLE_OVERRIDES, unsafe_allow_html=True)
# UI
st.title("QDF Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
st.write("This classifier determines whether a query deserves freshness.")
# Placeholder example prompts. The Classify handler also falls back to these
# when the text area is left empty.
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""
user_input = st.text_area(
    "Enter one search query per line:",
    placeholder=example_text,
)
if st.button("Classify"):
    raw_input = user_input.strip()
    # An empty text area falls back to the placeholder examples so the demo
    # produces output on a single click.
    source_text = raw_input if raw_input else example_text
    prompts = [line.strip() for line in source_text.split("\n") if line.strip()]

    # Defensive guard: with the example fallback above this should not trigger,
    # but it keeps the loop below from rendering an empty table.
    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        # Column rendering is loop-invariant, so build it once.
        column_config = {
            "Confidence": st.column_config.ProgressColumn(
                label="Confidence",
                min_value=0.0,
                max_value=1.0,
                format="%.4f",
            )
        }
        results = []
        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            results.append({
                "Prompt": p,
                "QDF": "Yes" if label == 1 else "No",
                "Confidence": round(conf, 4),
            })
            # Re-render the growing table in place after each prediction.
            # Use read-only st.dataframe, not st.data_editor: editing a
            # data_editor triggers a rerun, and since st.button is False on
            # that rerun, the whole result table would disappear.
            table_placeholder.dataframe(
                pd.DataFrame(results),
                column_config=column_config,
                hide_index=True,
            )
        info_box.empty()
        # Promo message shown only after results
        st.subheader("Working together.")
        st.write("[**Schedule a call**](https://dejan.ai/call/) to see how we can help you.")