File size: 3,728 Bytes
d1c870a 8fdff7e d1c870a 1b55ac6 86b4e16 d1c870a 0ae91f3 d1c870a 0ae91f3 f693b94 d1c870a 643c698 d1c870a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import os
# Theme configuration - MUST BE FIRST STREAMLIT COMMAND
# (Streamlit raises if any other st.* call runs before set_page_config.)
_PAGE_CONFIG = dict(
    page_title="QDF Classifier",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)
st.set_page_config(**_PAGE_CONFIG)
MODEL_ID = "dejanseo/QDF-large"
# Token for gated/private HF repos; None is fine for public access.
HF_TOKEN = os.getenv("HF_TOKEN")


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load and cache the QDF classifier model and tokenizer.

    Streamlit re-executes this entire script on every user interaction
    (button click, text edit, ...). Without st.cache_resource the model
    weights would be re-loaded from disk on each rerun; with it, loading
    happens once per server process and the same objects are reused.

    Returns:
        (model, tokenizer): the eval-mode sequence-classification model
        and its matching tokenizer.
    """
    mdl = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    ).eval()
    tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return mdl, tok


model, tokenizer = _load_model_and_tokenizer()
def classify(prompt: str):
    """Classify a single query string with the QDF model.

    Args:
        prompt: the search query to classify.

    Returns:
        (pred, confidence): the argmax class index and its softmax
        probability.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,  # model's maximum sequence length
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(**encoded)
    distribution = F.softmax(output.logits, dim=-1).squeeze().cpu()
    best = torch.argmax(distribution).item()
    return best, distribution[best].item()
# Font and style overrides: load Montserrat, force it onto all elements
# (including Streamlit's hashed css-* classes), and hide the default header.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');
html, body, div, span, input, label, textarea, button, h1, h2, p, table {
font-family: 'Montserrat', sans-serif !important;
}
[class^="css-"], [class*=" css-"] {
font-family: 'Montserrat', sans-serif !important;
}
header {visibility: hidden;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# UI
st.title("QDF Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
# Fixed grammar in the user-facing description ("whether query" -> "whether a query").
st.write("This classifier determines whether a query deserves freshness.")

# Example prompts: shown greyed-out as the text-area placeholder, and also
# used as the actual input when "Classify" is pressed with an empty box.
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""

user_input = st.text_area(
    "Enter one search query per line:",
    placeholder=example_text,
)
if st.button("Classify"):
    raw_input = user_input.strip()
    if raw_input:
        # splitlines() also handles "\r\n" from pasted Windows text, which
        # split("\n") would leave as a trailing "\r" on every prompt.
        prompts = [line.strip() for line in raw_input.splitlines() if line.strip()]
    else:
        # Empty box: fall back to the example prompts shown as the placeholder.
        prompts = [line.strip() for line in example_text.splitlines() if line.strip()]
    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []
        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            results.append({
                "Prompt": p,
                "QDF": "Yes" if label == 1 else "No",
                "Confidence": round(conf, 4),
            })
            # Re-render the table after every prompt so results stream in
            # one by one, as the info box promises. st.dataframe (not
            # st.data_editor) is used: the results are read-only, and
            # data_editor is an editable widget that can raise duplicate-
            # element errors when re-instantiated inside a loop.
            table_placeholder.dataframe(
                pd.DataFrame(results),
                column_config={
                    "Confidence": st.column_config.ProgressColumn(
                        label="Confidence",
                        min_value=0.0,
                        max_value=1.0,
                        format="%.4f",
                    )
                },
                hide_index=True,
            )
        info_box.empty()
        # Promo message shown only after results
        st.subheader("Working together.")
        st.write("[**Schedule a call**](https://dejan.ai/call/) to see how we can help you.")
|