|
import streamlit as st |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import pandas as pd |
|
import os |
|
|
|
|
|
# Browser-tab metadata and overall page layout for the app.
_PAGE_SETTINGS = dict(
    page_title="QDF Classifier",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
# Hugging Face Hub repository that hosts the classifier checkpoint.
MODEL_ID = "dejanseo/QDF-large"
# Optional Hub access token read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")


@st.cache_resource
def _load_model_and_tokenizer(model_id, token):
    """Load the sequence classifier and its tokenizer, cached across reruns.

    Streamlit re-executes this script on every user interaction. Without
    caching, the checkpoint would be re-instantiated on each rerun; with
    ``st.cache_resource`` the same objects are reused.

    Args:
        model_id: Hub repository id passed to ``from_pretrained``.
        token: Hub access token, or ``None``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        token=token,
        low_cpu_mem_usage=True,
    ).eval()
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    return loaded_model, loaded_tokenizer


# Module-level names are kept so the rest of the script can keep using them.
model, tokenizer = _load_model_and_tokenizer(MODEL_ID, HF_TOKEN)
|
|
|
def classify(prompt: str):
    """Classify a single prompt with the module-level ``model``.

    The prompt is tokenized with ``tokenizer`` (truncation enabled,
    ``max_length=512``), passed through ``model`` with gradient tracking
    disabled, and the resulting logits are softmax-normalized.

    Args:
        prompt: The text to classify.

    Returns:
        A ``(pred, confidence)`` tuple: the index of the highest-probability
        class and that class's probability as a Python float.
    """
    encoded = tokenizer(
        prompt,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    with torch.no_grad():
        output = model(**encoded)

    # Collapse the singleton batch dimension so indexing yields scalars.
    scores = output.logits.softmax(dim=-1).squeeze().cpu()
    best = scores.argmax().item()
    return best, scores[best].item()
|
|
|
|
|
# Global style overrides: load the Montserrat web font, apply it to the
# common HTML elements and to the generated "css-" classes, and hide the
# default page header.
_PAGE_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

html, body, div, span, input, label, textarea, button, h1, h2, p, table {
    font-family: 'Montserrat', sans-serif !important;
}

[class^="css-"], [class*=" css-"] {
    font-family: 'Montserrat', sans-serif !important;
}

header {visibility: hidden;}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page heading followed by the attribution and one-line description.
st.title("QDF Classifier")
for _intro_line in (
    "Built by [**Dejan AI**](https://dejan.ai)",
    "This classifier determines whether query deserves freshness.",
):
    st.write(_intro_line)
|
|
|
|
|
# Sample queries: displayed as the text-area placeholder and classified
# when the user clicks the button without typing anything.
example_text = """how would a cat describe a dog
how to reset a Nest thermostat
write a poem about time
is there a train strike in London today
summarize the theory of relativity
who won the champions league last year
explain quantum computing to a child
weather in tokyo tomorrow
generate a social media post for Earth Day
what is the latest iPhone model"""

user_input = st.text_area(
    "Enter one search query per line:",
    placeholder=example_text,
)


def _parse_prompts(text):
    """Return the non-blank lines of *text*, each stripped of whitespace."""
    return [line.strip() for line in text.split("\n") if line.strip()]


if st.button("Classify"):
    # Use the typed queries; fall back to the examples when the box is empty
    # or whitespace-only. Both paths share one parser so blank lines are
    # dropped consistently.
    prompts = _parse_prompts(user_input) or _parse_prompts(example_text)

    if not prompts:
        # Defensive guard: only reachable if the example text is empty too.
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []

        # Loop-invariant column settings, built once instead of per prompt.
        column_config = {
            "Confidence": st.column_config.ProgressColumn(
                label="Confidence",
                min_value=0.0,
                max_value=1.0,
                format="%.4f",
            )
        }

        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
            results.append({
                "Prompt": p,
                "QDF": "Yes" if label == 1 else "No",
                "Confidence": round(conf, 4),
            })
            # Redraw the growing table in place. A read-only dataframe is
            # used instead of an editable data_editor: editing a cell would
            # trigger a rerun in which the button reads False, wiping the
            # results from the page.
            table_placeholder.dataframe(
                pd.DataFrame(results),
                column_config=column_config,
                hide_index=True,
            )

        info_box.empty()
|
|
|
|
|
# Closing call-to-action rendered at the bottom of the page.
_CALL_URL = "https://dejan.ai/call/"
st.subheader("Working together.")
st.write(f"[**Schedule a call**]({_CALL_URL}) to see how we can help you.")
|
|